import os

import numpy as np
import pandas as pd
import spaces
import torch
from sentence_transformers import SentenceTransformer, util
from transformers import AutoTokenizer, AutoModelForCausalLM

# Load saved embeddings (can stay on CPU; parsing the CSV is fast enough).
# np.fromstring is deprecated, so parse the stringified vectors manually.
df = pd.read_csv("text_chunks_and_embeddings_df.csv")
df["embedding"] = df["embedding"].apply(
    lambda x: np.array(x.strip("[]").split(), dtype=np.float32)
)
pages_and_chunks = df.to_dict(orient="records")

# Lazy globals so the models load once, on first request, inside the GPU context
llm_model = None
tokenizer = None
embedding_model = None


def get_embeddings_tensor(device):
    """Stack the per-chunk embeddings into a single (n_chunks, dim) tensor."""
    return torch.tensor(np.stack(df["embedding"].tolist()), dtype=torch.float32).to(device)


def retrieve_relevant_resources(query, embeddings, model, k=5):
    """Embed the query and return the top-k (scores, indices) by dot-product similarity."""
    query_emb = model.encode(query, convert_to_tensor=True).to(embeddings.device)
    dot_scores = util.dot_score(query_emb, embeddings)[0]
    return torch.topk(dot_scores, k)


def prompt_formatter(query, context_items, tokenizer):
    """Build a single-turn chat prompt that grounds the answer in the retrieved chunks."""
    context = "- " + "\n- ".join([item["sentence_chunk"] for item in context_items])
    return tokenizer.apply_chat_template(
        [{
            "role": "user",
            "content": f"""Based on the following context items, please answer the query.

{context}

User query: {query}
Answer:""",
        }],
        tokenize=False,
        add_generation_prompt=True,
    )


@spaces.GPU(duration=30)
def ask(query, temperature=0.7, max_new_tokens=384):
    global llm_model, tokenizer, embedding_model

    # Device setup
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"[INFO] Using device: {device}")

    # Load HF token from Secrets (set this in your Space Settings > Secrets)
    HF_TOKEN = os.getenv("HF_TOKEN")
    model_id = "google/gemma-2-2b-it"

    # Load LLM if not already loaded
    if llm_model is None or tokenizer is None:
        print("[INFO] Loading LLM model:", model_id)
        tokenizer = AutoTokenizer.from_pretrained(model_id, token=HF_TOKEN)
        llm_model = AutoModelForCausalLM.from_pretrained(model_id, token=HF_TOKEN).to(device)

    # Load embedding model if not already loaded
    if embedding_model is None:
        print("[INFO] Loading embedding model: all-mpnet-base-v2")
        embedding_model = SentenceTransformer("all-mpnet-base-v2", device=device)

    # Retrieve the most relevant context chunks for the query
    scores, indices = retrieve_relevant_resources(query, get_embeddings_tensor(device), embedding_model)
    context_items = [pages_and_chunks[i] for i in indices.tolist()]
    prompt = prompt_formatter(query, context_items, tokenizer)

    # Generate an answer
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    outputs = llm_model.generate(
        **inputs,
        temperature=temperature,
        do_sample=True,
        max_new_tokens=max_new_tokens,
    )

    # Decode only the newly generated tokens so the echoed prompt is excluded;
    # skip_special_tokens strips the chat-template markers from the output
    generated_ids = outputs[0][inputs["input_ids"].shape[1]:]
    output_text = tokenizer.decode(generated_ids, skip_special_tokens=True)

    # Defensive cleanup in case the model echoes the trailing "Answer:" cue
    if "Answer:" in output_text:
        output_text = output_text.split("Answer:")[-1].strip()

    return output_text.strip()
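

# --- Usage sketch ---
# The Space presumably wires `ask` into a Gradio (or similar) UI; that wiring
# is not shown here, so the block below is only a minimal, hypothetical smoke
# test for running the script directly. The example query string is an
# assumption, not taken from the original app.
if __name__ == "__main__":
    example_query = "What are the main topics covered in the source document?"  # hypothetical query
    answer = ask(example_query, temperature=0.7, max_new_tokens=256)
    print(answer)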