import os
import torch
import numpy as np
import pandas as pd
from sentence_transformers import SentenceTransformer, util
from transformers import AutoTokenizer, AutoModelForCausalLM
import spaces

# Load saved embeddings (can be on CPU, it's fast enough)
df = pd.read_csv("text_chunks_and_embeddings_df.csv")
df["embedding"] = df["embedding"].apply(lambda x: np.fromstring(x.strip("[]"), sep=" "))
pages_and_chunks = df.to_dict(orient="records")
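# Each record carries the chunk text under "sentence_chunk" and its vector under "embedding"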

# Lazy global variables for models
llm_model = None
tokenizer = None
embedding_model = None

def get_embeddings_tensor(device):
    # Stack per-row arrays into a single (num_chunks, dim) tensor on the target device
    return torch.tensor(np.stack(df["embedding"].tolist()), dtype=torch.float32).to(device)

def retrieve_relevant_resources(query, embeddings, model, k=5):
    # Embed the query on the same device as the corpus embeddings; dot-product
    # scoring matches cosine similarity here because all-mpnet-base-v2 outputs
    # unit-normalized vectors
    query_emb = model.encode(query, convert_to_tensor=True).to(embeddings.device)
    dot_scores = util.dot_score(query_emb, embeddings)[0]
    return torch.topk(dot_scores, k)

def prompt_formatter(query, context_items, tokenizer):
    context = "- " + "\n- ".join([item["sentence_chunk"] for item in context_items])
    return tokenizer.apply_chat_template([{
        "role": "user",
        "content": f"""Based on the following context items, please answer the query.
{context}
User query: {query}
Answer:"""
    }], tokenize=False, add_generation_prompt=True)
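
# For gemma-2-2b-it the rendered prompt looks roughly like:
#   <bos><start_of_turn>user\n{content}<end_of_turn>\n<start_of_turn>model\n
# (illustrative; the exact markers come from the tokenizer's chat template)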

@spaces.GPU(duration=30)
def ask(query, temperature=0.7, max_new_tokens=384):
    global llm_model, tokenizer, embedding_model

    # Device setup
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"[INFO] Using device: {device}")

    # Load HF token from Secrets (set this in your Space Settings > Secrets)
    HF_TOKEN = os.getenv("HF_TOKEN")
    model_id = "google/gemma-2-2b-it"

    # Load LLM if not already loaded
    if llm_model is None or tokenizer is None:
        print("[INFO] Loading LLM model:", model_id)
        tokenizer = AutoTokenizer.from_pretrained(model_id, token=HF_TOKEN)
        llm_model = AutoModelForCausalLM.from_pretrained(model_id, token=HF_TOKEN).to(device)

    # Load embedding model if not already loaded
    if embedding_model is None:
        print("[INFO] Loading embedding model: all-mpnet-base-v2")
        embedding_model = SentenceTransformer("all-mpnet-base-v2", device=device)

    # Retrieve relevant context
    scores, indices = retrieve_relevant_resources(query, get_embeddings_tensor(device), embedding_model)
    context_items = [pages_and_chunks[int(i)] for i in indices]
    prompt = prompt_formatter(query, context_items, tokenizer)

    # Generate answer
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    outputs = llm_model.generate(
        **inputs,
        temperature=temperature,
        do_sample=True,
        max_new_tokens=max_new_tokens
    )
    # Decode only the newly generated tokens; decoding the full sequence and
    # stripping the prompt with .replace() fails because skip_special_tokens
    # removes the chat-template markers that the raw prompt still contains
    new_tokens = outputs[0][inputs["input_ids"].shape[-1]:]
    output_text = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

    # Strip an echoed "Answer:" marker if the model repeats it
    if "Answer:" in output_text:
        output_text = output_text.split("Answer:")[-1].strip()
    return output_text
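
# Minimal local smoke test (a sketch: the query is illustrative; on a Space this
# function is normally wired to a Gradio interface rather than called directly)
if __name__ == "__main__":
    print(ask("What is this document about?"))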