jk12p committed on
Commit
da4682c
Β·
verified Β·
1 Parent(s): ea4de28

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -15
app.py CHANGED
@@ -10,19 +10,19 @@ from transformers import AutoTokenizer, AutoModelForCausalLM
10
  # --- CONFIG ---
11
  HF_TOKEN = os.environ["HF_TOKEN"] # Taken from Hugging Face Space secrets
12
 
13
- # Load tokenizer and model with token
14
- tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it", token=HF_TOKEN)
15
  model = AutoModelForCausalLM.from_pretrained(
16
- "google/gemma-2b-it",
17
  torch_dtype=torch.float16,
18
- device_map="auto",
19
  token=HF_TOKEN
20
  )
21
 
22
  # Load sentence transformer for embeddings
23
  embedder = SentenceTransformer("all-MiniLM-L6-v2")
24
 
25
- st.title("πŸ” RAG App using πŸ€– Gemma 2B")
26
 
27
  uploaded_file = st.file_uploader("πŸ“„ Upload a PDF or TXT file", type=["pdf", "txt"])
28
 
@@ -50,7 +50,7 @@ def create_faiss_index(chunks):
50
  return index, embeddings
51
 
52
  # Retrieve top-k chunks
53
- def retrieve_chunks(query, chunks, index, embeddings, k=5): # increased k
54
  query_embedding = embedder.encode([query])
55
  D, I = index.search(np.array(query_embedding), k)
56
  return [chunks[i] for i in I[0]]
@@ -73,12 +73,14 @@ if uploaded_file:
73
  with st.spinner("Thinking..."):
74
  context = "\n".join(retrieve_chunks(user_question, chunks, index, embeddings))
75
 
76
- # Improved prompt
77
  prompt = (
78
- f"You are an expert assistant. Use the following context to answer the user's question.\n"
79
- f"If the answer (e.g., a name) is mentioned anywhere in the context, extract it precisely.\n"
80
- f"If it's not found, say clearly: 'Name not found.'\n\n"
81
- f"Context:\n{context}\n\nQuestion: {user_question}\nAnswer:"
 
 
82
  )
83
 
84
  input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
@@ -86,14 +88,20 @@ if uploaded_file:
86
  with torch.no_grad():
87
  outputs = model.generate(
88
  input_ids,
89
- max_new_tokens=256, # Using max_new_tokens instead of max_length
90
  num_return_sequences=1,
91
- temperature=0.7,
92
- do_sample=False
 
93
  )
94
 
95
  generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
96
- answer = generated_text.split("Answer:")[-1].strip()
 
 
 
 
 
97
 
98
  st.markdown("### 🧠 Answer:")
99
  st.success(answer)
 
10
  # --- CONFIG ---
11
  HF_TOKEN = os.environ["HF_TOKEN"] # Taken from Hugging Face Space secrets
12
 
13
+ # Load tokenizer and model (replaced Gemma 2B with Phi-2)
14
+ tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", token=HF_TOKEN)
15
  model = AutoModelForCausalLM.from_pretrained(
16
+ "microsoft/phi-2",
17
  torch_dtype=torch.float16,
18
+ device_map="auto",
19
  token=HF_TOKEN
20
  )
21
 
22
  # Load sentence transformer for embeddings
23
  embedder = SentenceTransformer("all-MiniLM-L6-v2")
24
 
25
+ st.title("πŸ” RAG App using πŸ€– Phi-2")
26
 
27
  uploaded_file = st.file_uploader("πŸ“„ Upload a PDF or TXT file", type=["pdf", "txt"])
28
 
 
50
  return index, embeddings
51
 
52
  # Retrieve top-k chunks
53
def retrieve_chunks(query, chunks, index, embeddings, k=5):
    """Return the k chunks most relevant to *query*.

    Embeds the query with the module-level sentence-transformer
    (``embedder``), runs a nearest-neighbour search against the FAISS
    ``index``, and maps the resulting row indices back into ``chunks``.
    ``embeddings`` is accepted for interface compatibility but is not
    read here — the FAISS index already holds the chunk vectors.
    """
    encoded_query = embedder.encode([query])
    # FAISS returns (distances, indices); only the indices are needed.
    distances, neighbor_ids = index.search(np.array(encoded_query), k)
    return [chunks[pos] for pos in neighbor_ids[0]]
 
73
  with st.spinner("Thinking..."):
74
  context = "\n".join(retrieve_chunks(user_question, chunks, index, embeddings))
75
 
76
+ # Updated prompt for Phi-2's instruction style
77
  prompt = (
78
+ f"Instruction: Answer the following question using only the context provided. "
79
+ f"Extract specific information directly from the context when available. "
80
+ f"If the answer is not in the context, respond with 'Information not found.'\n\n"
81
+ f"Context:\n{context}\n\n"
82
+ f"Question: {user_question}\n\n"
83
+ f"Answer: "
84
  )
85
 
86
  input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
 
88
  with torch.no_grad():
89
  outputs = model.generate(
90
  input_ids,
91
+ max_new_tokens=256, # Keep using max_new_tokens as fixed before
92
  num_return_sequences=1,
93
+ temperature=0.2, # Lower temperature for more focused answers
94
+ do_sample=True, # Enable sampling for more natural responses
95
+ top_p=0.9, # Add top_p sampling for better quality
96
  )
97
 
98
  generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
99
+
100
+ # Extract the answer part - adapt based on Phi-2's output format
101
+ if "Answer:" in generated_text:
102
+ answer = generated_text.split("Answer:")[-1].strip()
103
+ else:
104
+ answer = generated_text.replace(prompt, "").strip()
105
 
106
  st.markdown("### 🧠 Answer:")
107
  st.success(answer)