Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import time | |
| import os | |
| import json | |
| import torch | |
| from transformers import AutoTokenizer, AutoModel | |
| from sentence_transformers import SentenceTransformer, util | |
# --- Path Configuration ---
# Absolute path of the directory containing this script. Data paths are
# resolved against it so they do not depend on the current working directory.
script_dir = os.path.dirname(os.path.abspath(__file__))
# A SPACE_ID environment variable is taken as the signal that we are running
# inside a Hugging Face Space.
is_hf_space = "SPACE_ID" in os.environ
if is_hf_space:
    # In a Space, load the model from the Hub and the data from the repo root.
    # NOTE(review): assumes this script sits at the repo root of the Space.
    model_path = os.environ.get("MODEL_REPO_ID", "philtoms/minilm-alice-base-rsft-v2")
    data_path = os.path.join(script_dir, "training_triplets.jsonl")
    print(f"Running on HF Spaces. Using model from Hub: {model_path}")
else:
    # Locally, the model and data live in sibling directories of this
    # script's folder; normpath collapses the ".." for readable logs.
    model_path = os.path.normpath(
        os.path.join(script_dir, "..", "models", "minilm-alice-base-rsft-v2", "final")
    )
    data_path = os.path.normpath(
        os.path.join(script_dir, "..", "data", "training_triplets.jsonl")
    )
    print(f"Running locally. Using local model at: {model_path}")
# --- Model Loading ---
# The configured `model_path` is currently overridden with an off-the-shelf
# checkpoint. Earlier experiments also tried
# "sentence-transformers/all-MiniLM-L6-v2" and "Qwen/Qwen3-Embedding-0.6B".
# Set EMBEDDING_MODEL_OVERRIDE to choose another checkpoint, or set it to an
# empty string to load the configured model instead.
model_override = os.environ.get(
    "EMBEDDING_MODEL_OVERRIDE", "sentence-transformers/multi-qa-mpnet-base-cos-v1"
)
if model_override:
    # Log the override so startup output names the model actually loaded.
    print(f"Overriding configured model '{model_path}' with '{model_override}'.")
    model_path = model_override
try:
    model = SentenceTransformer(model_path)
except Exception as e:
    # Chain the original exception so the full load traceback is preserved.
    raise gr.Error(f"Failed to load model from '{model_path}'. Error: {e}") from e
# --- Dataset Loading ---
# The data file is JSON Lines: one JSON object per line. The corpus is later
# built from each record's "positive" field.
if not os.path.isfile(data_path):
    raise gr.Error(f"Data file not found at '{data_path}'. Please ensure the file exists.")
dataset = []
# Decode explicitly as UTF-8 so loading does not depend on the platform's
# default locale encoding.
with open(data_path, "r", encoding="utf-8") as f:
    for line in f:
        # Skip blank lines (e.g. an extra empty line at the end of the file),
        # which json.loads would otherwise reject.
        if not line.strip():
            continue
        dataset.append(json.loads(line))
# Fail fast: an empty corpus cannot serve any search.
if not dataset:
    raise gr.Error(f"Data file '{data_path}' contains no records.")
| # Pre-compute corpus embeddings | |
| import re | |
| # def split_into_sentences(text): | |
| # """Splits a paragraph into sentences based on capitalization and punctuation.""" | |
| # # This regex looks for a capital letter, followed by anything that's not a period, | |
| # # exclamation mark, or question mark, and then ends with one of those punctuation marks. | |
| # sentences = re.findall(r'([A-Z][^.!?]*[.!?])', text) | |
| # return sentences | |
| # def create_overlapped_chunks(corpus_documents, chunk_size=2, overlap=1): | |
| # chunked_corpus = [] | |
| # for doc_idx, doc_text in enumerate(corpus_documents): | |
| # sentences = split_into_sentences(doc_text) | |
| # if not sentences: | |
| # continue | |
| # # If there are fewer sentences than chunk_size, just use the whole document as one chunk | |
| # if len(sentences) < chunk_size: | |
| # chunked_corpus.append({ | |
| # "text": doc_text, | |
| # "original_doc_idx": doc_idx, | |
| # "start_sentence_idx": 0, | |
| # "end_sentence_idx": len(sentences) - 1 | |
| # }) | |
| # continue | |
| # for i in range(0, len(sentences) - chunk_size + 1, chunk_size - overlap): | |
| # chunk_sentences = sentences[i : i + chunk_size] | |
| # chunk_text = " ".join(chunk_sentences) | |
| # chunked_corpus.append({ | |
| # "text": chunk_text, | |
| # "original_doc_idx": doc_idx, | |
| # "start_sentence_idx": i, | |
| # "end_sentence_idx": i + chunk_size - 1 | |
| # }) | |
| # return chunked_corpus | |
| # def process_documents_for_chunking(documents): | |
| # chunked_corpus_data = create_overlapped_chunks(documents) | |
| # flat_corpus_chunks = [item["text"] for item in chunked_corpus_data] | |
| # return chunked_corpus_data, flat_corpus_chunks | |
# Pre-compute corpus embeddings once at startup. Every record's "positive"
# text becomes one searchable document, so each query only has to encode the
# prompt. (Sentence-level chunking of the corpus is currently disabled.)
original_corpus = [triplet["positive"] for triplet in dataset]
corpus_embeddings = model.encode(original_corpus)
| # def find_similar(prompt, top_k): | |
| # start_time = time.time() | |
| # prompt_embedding = model.encode(prompt) | |
| # scores = util.dot_score(prompt_embedding, corpus_embeddings)[0].cpu().tolist() | |
| # # Pair scores with the chunked corpus data | |
| # scored_chunks = [] | |
| # for i, score in enumerate(scores): | |
| # scored_chunks.append({ | |
| # "score": score, | |
| # "text": chunked_corpus_data[i]["text"], | |
| # "original_doc_idx": chunked_corpus_data[i]["original_doc_idx"] | |
| # }) | |
| # # Sort by decreasing score | |
| # scored_chunks = sorted(scored_chunks, key=lambda x: x["score"], reverse=True) | |
| # results = [] | |
| # for item in scored_chunks[:top_k]: | |
| # # Return the original document text, not just the chunk | |
| # original_doc_text = original_corpus[item["original_doc_idx"]] | |
| # results.append((item["score"], original_doc_text)) | |
| # end_time = time.time() | |
| # return results, f"{(end_time - start_time) * 1000:.2f} ms" | |
| # with torch.no_grad(): | |
| # encoded_corpus = tokenizer(corpus, padding=True, truncation=True, return_tensors='pt') | |
| # corpus_embeddings = model(**encoded_corpus).last_hidden_state.mean(dim=1) | |
def find_similar(prompt, top_k):
    """Return the corpus documents most similar to ``prompt``.

    Args:
        prompt: Query text from the dropdown (selected or typed). An empty or
            missing prompt yields no results instead of an encoding error.
        top_k: Number of results to return. Gradio sliders may deliver this
            as a float (e.g. ``5.0``), and a float cannot be used as a slice
            index, so it is coerced to ``int``. Negative values give no results.

    Returns:
        A ``(rows, elapsed)`` tuple. ``rows`` is a list of ``(score, document)``
        tuples sorted by decreasing dot-product score. ``elapsed`` is the time
        spent encoding, scoring and ranking, formatted as ``"<ms> ms"``.
    """
    start_time = time.perf_counter()  # monotonic clock intended for intervals
    if not prompt:
        return [], f"{(time.perf_counter() - start_time) * 1000:.2f} ms"
    k = max(0, int(top_k))
    prompt_embedding = model.encode(prompt)
    # Row 0 holds the prompt's score against every corpus document.
    scores = util.dot_score(prompt_embedding, corpus_embeddings)[0].cpu().tolist()
    # Stable sort on score alone, so ties keep corpus order.
    ranked = sorted(zip(scores, original_corpus), key=lambda pair: pair[0], reverse=True)
    results = ranked[:k]
    elapsed_ms = (time.perf_counter() - start_time) * 1000
    return results, f"{elapsed_ms:.2f} ms"
# --- UI ---
# Input components: a prompt picker that also accepts free text, and the
# number of results to show.
prompt_input = gr.Dropdown(
    ["Alice sees White rabbit for the first time", "Alice meets caterpillar", "sad turtle story"],
    label="Select a prompt or type your own",
    allow_custom_value=True,
)
top_k_input = gr.Slider(1, 20, value=5, step=1, label="Top K")

# Output components: the ranked (score, document) table and the timing string,
# matching the two values returned by find_similar.
results_output = gr.Dataframe(headers=["Score", "Response"])
timing_output = gr.Textbox(label="Time Taken")

iface = gr.Interface(
    fn=find_similar,
    inputs=[prompt_input, top_k_input],
    outputs=[results_output, timing_output],
    title="RSFT Alice Embeddings (Transformers)",
    description="Enter a prompt to find similar sentences from the corpus.",
)

if __name__ == "__main__":
    iface.launch()