import os
import torch
import numpy as np
import pandas as pd
from sentence_transformers import SentenceTransformer, util
from transformers import AutoTokenizer, AutoModelForCausalLM
import spaces

# Load saved embeddings (can be on CPU, it's fast enough)
df = pd.read_csv("text_chunks_and_embeddings_df.csv")
df["embedding"] = df["embedding"].apply(lambda x: np.fromstring(x.strip("[]"), sep=" "))
pages_and_chunks = df.to_dict(orient="records")
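# The CSV above is assumed (it is not produced in this file) to contain at least two
# columns: "sentence_chunk" with the raw text of each chunk, and "embedding" with its
# all-mpnet-base-v2 vector stored as a bracketed, space-separated string.
# A minimal sketch of how such a file could be produced (hypothetical names):
#
#   chunks = ["first chunk of text ...", "second chunk of text ..."]
#   emb_model = SentenceTransformer("all-mpnet-base-v2")
#   vectors = emb_model.encode(chunks)
#   pd.DataFrame({
#       "sentence_chunk": chunks,
#       "embedding": [np.array2string(v, max_line_width=10**6) for v in vectors],
#   }).to_csv("text_chunks_and_embeddings_df.csv", index=False)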
# Lazy global variables for models (loaded on first request)
llm_model = None
tokenizer = None
embedding_model = None


def get_embeddings_tensor():
    # Stack the per-chunk vectors into a single (num_chunks, dim) tensor on the GPU
    return torch.tensor(np.stack(df["embedding"].tolist()), dtype=torch.float32).to("cuda")


def retrieve_relevant_resources(query, embeddings, model, k=5):
    # Embed the query and score it against every chunk with dot products
    # (equivalent to cosine similarity when the embeddings are L2-normalized),
    # then return the top-k scores and their indices
    query_emb = model.encode(query, convert_to_tensor=True).to("cuda")
    dot_scores = util.dot_score(query_emb, embeddings)[0]
    return torch.topk(dot_scores, k)
def prompt_formatter(query, context_items, tokenizer):
    # Turn the retrieved chunks into a bulleted context block and wrap it in the
    # model's chat template
    context = "- " + "\n- ".join([item["sentence_chunk"] for item in context_items])
    return tokenizer.apply_chat_template(
        [{
            "role": "user",
            "content": f"""Based on the following context items, please answer the query.
{context}
User query: {query}
Answer:""",
        }],
        tokenize=False,
        add_generation_prompt=True,
    )
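# For reference: with Gemma-style chat templates the string returned above is roughly
# "<bos><start_of_turn>user\n...content...<end_of_turn>\n<start_of_turn>model\n"
# (approximate; the exact markers come from the tokenizer's own template).
# add_generation_prompt=True is what appends the trailing model-turn header, so
# generation continues as the assistant instead of starting a new user turn.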
# The `spaces` import above suggests this Space targets ZeroGPU; the @spaces.GPU
# decorator (assumed here, since the import is otherwise unused) requests a GPU
# for the duration of each call.
@spaces.GPU
def ask(query, temperature=0.7, max_new_tokens=384):
    global llm_model, tokenizer, embedding_model

    # Device setup
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"[INFO] Using device: {device}")

    # Load HF token from Secrets (set this in your Space Settings > Secrets)
    HF_TOKEN = os.getenv("HF_TOKEN")
    model_id = "google/gemma-2-2b-it"

    # Load LLM if not already loaded
    if llm_model is None or tokenizer is None:
        print("[INFO] Loading LLM model:", model_id)
        tokenizer = AutoTokenizer.from_pretrained(model_id, token=HF_TOKEN)
        llm_model = AutoModelForCausalLM.from_pretrained(model_id, token=HF_TOKEN).to(device)

    # Load embedding model if not already loaded
    if embedding_model is None:
        print("[INFO] Loading embedding model: all-mpnet-base-v2")
        embedding_model = SentenceTransformer("all-mpnet-base-v2", device=device)

    # Retrieve relevant context
    scores, indices = retrieve_relevant_resources(query, get_embeddings_tensor(), embedding_model)
    context = [pages_and_chunks[i] for i in indices]
    prompt = prompt_formatter(query, context, tokenizer)
    # Generate answer
    input_ids = tokenizer(prompt, return_tensors="pt").to(device)
    outputs = llm_model.generate(
        **input_ids,
        temperature=temperature,
        do_sample=True,
        max_new_tokens=max_new_tokens,
    )
    # Decode only the newly generated tokens so the prompt is not echoed back
    # (stripping the prompt from the decoded string is unreliable because
    # skip_special_tokens alters it)
    generated_tokens = outputs[0][input_ids["input_ids"].shape[1]:]
    output_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)

    # Clean up output
    if "Answer:" in output_text:
        output_text = output_text.split("Answer:")[-1].strip()
    return output_text.strip()
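# Example local invocation (illustrative only: the query below is made up, and in
# the actual Space ask() would normally be wired to a UI such as Gradio rather
# than called from __main__).
if __name__ == "__main__":
    print(ask("What are the main topics covered in the source document?"))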