ninjals committed on
Commit 1bd0027 · 1 Parent(s): 5c22f83

Add lazy LLM loading to fix ZeroGPU startup

Files changed (1)
  1. model.py +47 -37
model.py CHANGED
@@ -1,72 +1,82 @@
+import os
 import torch
 import numpy as np
 import pandas as pd
 from sentence_transformers import SentenceTransformer, util
+from transformers import AutoTokenizer, AutoModelForCausalLM
+import spaces

-import os
-
-# Load saved embeddings
+# Load saved embeddings (can be on CPU, it's fast enough)
 df = pd.read_csv("text_chunks_and_embeddings_df.csv")
 df["embedding"] = df["embedding"].apply(lambda x: np.fromstring(x.strip("[]"), sep=" "))
 pages_and_chunks = df.to_dict(orient="records")

-# Device fallback
-device = "cuda" if torch.cuda.is_available() else "cpu"
-print(f"[INFO] Using device: {device}")
-
-# Load embedding model
-embedding_model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2", device=device)
-
-# Lazy-load the LLM model
+# Lazy global variables for models
 llm_model = None
 tokenizer = None
-
-
-def load_llm():
-    global llm_model, tokenizer
-    if llm_model is None or tokenizer is None:
-        from transformers import AutoTokenizer, AutoModelForCausalLM
-        HF_TOKEN = os.getenv("HF_TOKEN")
-        model_id = "google/gemma-2-2b-it"
-        print("[INFO] Loading LLM model:", model_id)
-        tokenizer = AutoTokenizer.from_pretrained(model_id, token=HF_TOKEN)
-        llm_model = AutoModelForCausalLM.from_pretrained(
-            model_id,
-            token=HF_TOKEN,
-            torch_dtype=torch.float16,
-            attn_implementation="flash_attention_2"
-        ).to(device)
+embedding_model = None

 def get_embeddings_tensor():
-    return torch.tensor(np.stack(df["embedding"].tolist()), dtype=torch.float32).to(device)
+    return torch.tensor(np.stack(df["embedding"].tolist()), dtype=torch.float32).to("cuda")

-def retrieve_relevant_resources(query, embeddings, model=embedding_model, k=5):
-    query_emb = model.encode(query, convert_to_tensor=True).to(device)
+def retrieve_relevant_resources(query, embeddings, model, k=5):
+    query_emb = model.encode(query, convert_to_tensor=True).to("cuda")
     dot_scores = util.dot_score(query_emb, embeddings)[0]
     return torch.topk(dot_scores, k)

-def prompt_formatter(query, context_items):
+def prompt_formatter(query, context_items, tokenizer):
     context = "- " + "\n- ".join([item["sentence_chunk"] for item in context_items])
     return tokenizer.apply_chat_template([{
         "role": "user",
-        "content": f"""Base on the following context items, please answer the query.
+        "content": f"""Based on the following context items, please answer the query.
 {context}
 User query: {query}
 Answer:"""
     }], tokenize=False, add_generation_prompt=True)

+@spaces.GPU(duration=120)
 def ask(query, temperature=0.7, max_new_tokens=256):
-    load_llm()  # Ensure LLM is loaded
-    scores, indices = retrieve_relevant_resources(query, embeddings=get_embeddings_tensor())
+    global llm_model, tokenizer, embedding_model
+
+    # Device setup
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    print(f"[INFO] Using device: {device}")
+
+    # Load HF token from Secrets (set this in your Space Settings > Secrets)
+    HF_TOKEN = os.getenv("HF_TOKEN")
+    model_id = "google/gemma-2-2b-it"
+
+    # Load LLM if not already loaded
+    if llm_model is None or tokenizer is None:
+        print("[INFO] Loading LLM model:", model_id)
+        tokenizer = AutoTokenizer.from_pretrained(model_id, token=HF_TOKEN)
+        llm_model = AutoModelForCausalLM.from_pretrained(model_id, token=HF_TOKEN).to(device)
+
+    # Load embedding model if not already loaded
+    if embedding_model is None:
+        print("[INFO] Loading embedding model: all-mpnet-base-v2")
+        embedding_model = SentenceTransformer("all-mpnet-base-v2", device=device)
+
+    # Retrieve relevant context
+    scores, indices = retrieve_relevant_resources(query, get_embeddings_tensor(), embedding_model)
     context = [pages_and_chunks[i] for i in indices]
-    prompt = prompt_formatter(query, context)
+    prompt = prompt_formatter(query, context, tokenizer)
+
+    # Generate answer
     input_ids = tokenizer(prompt, return_tensors="pt").to(device)
-    outputs = llm_model.generate(**input_ids, temperature=temperature, do_sample=True, max_new_tokens=max_new_tokens)
+    outputs = llm_model.generate(
+        **input_ids,
+        temperature=temperature,
+        do_sample=True,
+        max_new_tokens=max_new_tokens
+    )
     output_text = tokenizer.decode(outputs[0], skip_special_tokens=True).replace(prompt, "")
+
+    # Clean up output
     if "Answer:" in output_text:
         output_text = output_text.split("Answer:")[-1].strip()
-    output_text = output_text.replace("model", "").strip()
     return output_text



+
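
For reference, the motivation behind this change is that a ZeroGPU Space must not touch CUDA at import time: the GPU is only attached while a function decorated with @spaces.GPU is running, which is why all model loading now happens inside ask(). A minimal sketch of how a Space entry point might call the lazily loading ask() above follows; the app.py filename, the answer() wrapper, and the Gradio wiring are illustrative assumptions, not part of this commit.

# app.py (hypothetical) - drives the lazy-loading model.py above; not part of this commit
import gradio as gr

from model import ask  # importing model.py no longer initializes CUDA at startup


def answer(query: str) -> str:
    # ask() carries the @spaces.GPU decorator, so ZeroGPU attaches a GPU only for
    # the duration of this call and the models are loaded on the first request
    return ask(query)


demo = gr.Interface(
    fn=answer,
    inputs=gr.Textbox(label="Question"),
    outputs=gr.Textbox(label="Answer"),
    title="RAG demo (illustrative)",
)

if __name__ == "__main__":
    demo.launch()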