# Nutrition-Chatbot / model.py
import os
import torch
import numpy as np
import pandas as pd
from sentence_transformers import SentenceTransformer, util
from transformers import AutoTokenizer, AutoModelForCausalLM
import spaces

# Load saved embeddings (CPU is fine here; it's fast enough). The CSV stores each
# embedding as a space-separated string like "[0.1 0.2 ...]" alongside its
# "sentence_chunk" text.
df = pd.read_csv("text_chunks_and_embeddings_df.csv")
# np.fromstring(text, sep=...) is deprecated; parse the string explicitly instead.
df["embedding"] = df["embedding"].apply(
    lambda x: np.array(x.strip("[]").split(), dtype=np.float32)
)
pages_and_chunks = df.to_dict(orient="records")

# Lazy global variables for models
llm_model = None
tokenizer = None
embedding_model = None

def get_embeddings_tensor(device="cuda"):
    # Stack the per-row arrays into a single (num_chunks, dim) tensor on the
    # target device; parameterized so it also runs on CPU-only machines.
    return torch.tensor(np.stack(df["embedding"].tolist()), dtype=torch.float32).to(device)

def retrieve_relevant_resources(query, embeddings, model, k=5):
    # Embed the query and score it against every chunk. Moving the query
    # embedding to the embeddings' device (rather than hardcoding "cuda")
    # keeps this working when CUDA is unavailable.
    query_emb = model.encode(query, convert_to_tensor=True).to(embeddings.device)
    dot_scores = util.dot_score(query_emb, embeddings)[0]
    return torch.topk(dot_scores, k=k)

def prompt_formatter(query, context_items, tokenizer):
    # Turn the retrieved chunks into a bulleted context block inside a single user turn.
    context = "- " + "\n- ".join(item["sentence_chunk"] for item in context_items)
    content = f"""Based on the following context items, please answer the query.
{context}
User query: {query}
Answer:"""
    return tokenizer.apply_chat_template(
        [{"role": "user", "content": content}],
        tokenize=False,
        add_generation_prompt=True,
    )
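
# For reference, Gemma's chat template renders the turn roughly as
# "<bos><start_of_turn>user\n{content}<end_of_turn>\n<start_of_turn>model\n",
# so the reply begins after the trailing "model" header (exact details may
# vary by tokenizer version).
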
@spaces.GPU(duration=30)
def ask(query, temperature=0.7, max_new_tokens=384):
    global llm_model, tokenizer, embedding_model

    # Device setup
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"[INFO] Using device: {device}")

    # Load HF token from Secrets (set this in your Space Settings > Secrets).
    HF_TOKEN = os.getenv("HF_TOKEN")
    model_id = "google/gemma-2-2b-it"

    # Load the LLM once and reuse it across calls.
    if llm_model is None or tokenizer is None:
        print("[INFO] Loading LLM model:", model_id)
        tokenizer = AutoTokenizer.from_pretrained(model_id, token=HF_TOKEN)
        llm_model = AutoModelForCausalLM.from_pretrained(model_id, token=HF_TOKEN).to(device)

    # Load the embedding model once as well.
    if embedding_model is None:
        print("[INFO] Loading embedding model: all-mpnet-base-v2")
        embedding_model = SentenceTransformer("all-mpnet-base-v2", device=device)

    # Retrieve the top-k most relevant context chunks.
    scores, indices = retrieve_relevant_resources(query, get_embeddings_tensor(device), embedding_model)
    context = [pages_and_chunks[i] for i in indices.tolist()]
    prompt = prompt_formatter(query, context, tokenizer)

    # Generate an answer.
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    outputs = llm_model.generate(
        **inputs,
        temperature=temperature,
        do_sample=True,
        max_new_tokens=max_new_tokens,
    )

    # Decode only the newly generated tokens; string-replacing the prompt is
    # unreliable once special tokens have been stripped, and splitting on the
    # bare word "model" would mangle any answer that contains it.
    generated = outputs[0][inputs["input_ids"].shape[1]:]
    output_text = tokenizer.decode(generated, skip_special_tokens=True)

    # Strip an echoed "Answer:" prefix if the model repeats the template.
    if "Answer:" in output_text:
        output_text = output_text.split("Answer:")[-1]
    return output_text.strip()
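
# Minimal local smoke test (a sketch, not part of the Space runtime). Assumes
# text_chunks_and_embeddings_df.csv sits next to this file and HF_TOKEN is set
# in the environment; outside Spaces the @spaces.GPU decorator should be a
# no-op, so this may run slowly on CPU. The query is just an example.
if __name__ == "__main__":
    print(ask("What are good dietary sources of vitamin D?"))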