Abdourakib committed on
Commit
9b22de8
·
verified ·
1 Parent(s): 1b37111

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -115
app.py CHANGED
@@ -1,25 +1,12 @@
1
  import gradio as gr
2
  import numpy as np
3
  import faiss
4
- import torch
5
- from transformers import AutoTokenizer, AutoModelForCausalLM
6
  from sentence_transformers import SentenceTransformer
7
 
8
  # -------------------------------
9
- # Load models
10
  # -------------------------------
11
- LM_MODEL_NAME = "distilgpt2"
12
  EMBED_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
13
-
14
- tokenizer = AutoTokenizer.from_pretrained(LM_MODEL_NAME)
15
- model = AutoModelForCausalLM.from_pretrained(LM_MODEL_NAME)
16
- tokenizer.pad_token = tokenizer.eos_token
17
- model.config.pad_token_id = tokenizer.eos_token_id
18
-
19
- device = "cuda" if torch.cuda.is_available() else "cpu"
20
- model.to(device)
21
- model.eval()
22
-
23
  embed_model = SentenceTransformer(EMBED_MODEL_NAME)
24
 
25
  # -------------------------------
@@ -65,7 +52,7 @@ examples = [
65
  {
66
  "question": "What is dynamic programming?",
67
  "answer": "Dynamic programming is a problem-solving technique that breaks a problem into overlapping subproblems, stores the results of smaller subproblems, and reuses them to avoid repeated work."
68
- },
69
  ]
70
 
71
  texts = [f"Question: {ex['question']}\nAnswer: {ex['answer']}" for ex in examples]
@@ -76,79 +63,38 @@ texts = [f"Question: {ex['question']}\nAnswer: {ex['answer']}" for ex in example
76
  embeddings = embed_model.encode(texts, convert_to_numpy=True, normalize_embeddings=True)
77
  dimension = embeddings.shape[1]
78
 
79
- # Use inner product on normalized embeddings ~ cosine similarity
80
  index = faiss.IndexFlatIP(dimension)
81
  index.add(np.array(embeddings, dtype=np.float32))
82
 
83
  # -------------------------------
84
  # Retrieval threshold
85
  # -------------------------------
86
- # Higher = stricter. You can tune between 0.35 and 0.60
87
  SIMILARITY_THRESHOLD = 0.45
88
 
89
-
90
  # -------------------------------
91
- # Helpers
92
  # -------------------------------
93
- def retrieve_context(question: str, k: int = 3):
94
  question_embedding = embed_model.encode(
95
  [question],
96
  convert_to_numpy=True,
97
  normalize_embeddings=True
98
  )
99
 
100
- scores, indices = index.search(np.array(question_embedding, dtype=np.float32), k)
101
-
102
- retrieved = []
103
- for score, idx in zip(scores[0], indices[0]):
104
- idx = int(idx)
105
- retrieved.append({
106
- "score": float(score),
107
- "question": examples[idx]["question"],
108
- "answer": examples[idx]["answer"],
109
- "text": texts[idx]
110
- })
111
-
112
- return retrieved
113
-
114
-
115
- def clean_answer(text: str) -> str:
116
- if "Answer:" in text:
117
- text = text.split("Answer:")[-1].strip()
118
 
119
- lines = [line.strip() for line in text.splitlines() if line.strip()]
120
- cleaned_lines = []
121
- seen_lines = set()
122
 
123
- for line in lines:
124
- norm = line.lower()
125
- if norm not in seen_lines:
126
- seen_lines.add(norm)
127
- cleaned_lines.append(line)
128
-
129
- text = " ".join(cleaned_lines)
130
-
131
- sentences = [s.strip() for s in text.split(".") if s.strip()]
132
- unique_sentences = []
133
- seen_sentences = set()
134
-
135
- for s in sentences:
136
- norm = s.lower()
137
- if norm not in seen_sentences:
138
- seen_sentences.add(norm)
139
- unique_sentences.append(s)
140
-
141
- if unique_sentences:
142
- text = ". ".join(unique_sentences) + "."
143
-
144
- return text.strip()
145
 
146
 
147
  def fallback_message() -> str:
148
  return (
149
  "I do not have enough reliable information in my current knowledge base to answer that question well. "
150
- "Please ask about basic computer science topics like recursion, stacks, queues, arrays, linked lists, "
151
- "binary search, Big O notation, processes, threads, or hash tables."
152
  )
153
 
154
 
@@ -157,59 +103,12 @@ def cs_tutor_app(question: str) -> str:
157
  if not question:
158
  return "Please enter a computer science question."
159
 
160
- retrieved = retrieve_context(question, k=3)
161
- best_score = retrieved[0]["score"]
162
 
163
- # If best match is too weak, do not hallucinate
164
  if best_score < SIMILARITY_THRESHOLD:
165
  return fallback_message()
166
 
167
- context = "\n\n".join(
168
- [f"Question: {item['question']}\nAnswer: {item['answer']}" for item in retrieved]
169
- )
170
-
171
- prompt = f"""You are a helpful computer science tutor.
172
-
173
- Use the examples below to answer the user's question clearly and simply.
174
- Write a short beginner-friendly answer in 2 to 4 sentences.
175
- Do not repeat yourself.
176
- Do not include unrelated information.
177
- Only answer if the examples are relevant.
178
-
179
- Examples:
180
- {context}
181
-
182
- Question: {question}
183
-
184
- Answer:"""
185
-
186
- inputs = tokenizer(
187
- prompt,
188
- return_tensors="pt",
189
- truncation=True,
190
- max_length=1024
191
- ).to(device)
192
-
193
- with torch.no_grad():
194
- output = model.generate(
195
- **inputs,
196
- max_new_tokens=80,
197
- do_sample=True,
198
- temperature=0.7,
199
- top_p=0.9,
200
- repetition_penalty=1.2,
201
- no_repeat_ngram_size=3,
202
- pad_token_id=tokenizer.eos_token_id,
203
- eos_token_id=tokenizer.eos_token_id
204
- )
205
-
206
- response = tokenizer.decode(output[0], skip_special_tokens=True)
207
- response = clean_answer(response)
208
-
209
- if len(response) < 20:
210
- return fallback_message()
211
-
212
- return response
213
 
214
 
215
  # -------------------------------
@@ -234,4 +133,3 @@ demo = gr.Interface(
234
  )
235
 
236
  demo.launch()
237
-
 
1
  import gradio as gr
2
  import numpy as np
3
  import faiss
 
 
4
  from sentence_transformers import SentenceTransformer
5
 
6
# -------------------------------
# Embedding model
# -------------------------------
EMBED_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"

# Loaded once at module import; the same encoder is reused for both the
# corpus texts and incoming user queries so their vectors are comparable.
embed_model = SentenceTransformer(EMBED_MODEL_NAME)
11
 
12
  # -------------------------------
 
52
  {
53
  "question": "What is dynamic programming?",
54
  "answer": "Dynamic programming is a problem-solving technique that breaks a problem into overlapping subproblems, stores the results of smaller subproblems, and reuses them to avoid repeated work."
55
+ }
56
  ]
57
 
58
# Serialize each Q/A example into the single string that gets embedded.
texts = [f"Question: {item['question']}\nAnswer: {item['answer']}" for item in examples]
 
63
# Embed the corpus once at startup and build the FAISS index over it.
embeddings = embed_model.encode(texts, convert_to_numpy=True, normalize_embeddings=True)
dimension = embeddings.shape[1]

# Inner product on normalized vectors ~= cosine similarity
index = faiss.IndexFlatIP(dimension)
index.add(embeddings.astype(np.float32))

# -------------------------------
# Retrieval threshold
# -------------------------------
# Minimum cosine similarity for a retrieved example to be trusted.
SIMILARITY_THRESHOLD = 0.45
74
 
 
75
# -------------------------------
# Helper functions
# -------------------------------
def retrieve_best_match(question: str):
    """Return ``(score, example)`` for the corpus entry closest to *question*.

    The query is encoded with the same normalized-embedding setup used for
    the corpus, so the FAISS inner-product score behaves like a cosine
    similarity. ``example`` is the matching dict from ``examples``.
    """
    query_vec = embed_model.encode(
        [question], convert_to_numpy=True, normalize_embeddings=True
    )

    # Single nearest neighbour is enough: the answer is returned verbatim.
    scores, indices = index.search(query_vec.astype(np.float32), 1)

    top_score = float(scores[0][0])
    top_idx = int(indices[0][0])
    return top_score, examples[top_idx]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
 
92
 
93
def fallback_message() -> str:
    """Standard reply used when retrieval confidence is below the threshold."""
    msg = (
        "I do not have enough reliable information in my current knowledge base to answer that question well. "
        "Please ask about topics like recursion, stacks, queues, arrays, linked lists, binary search, Big O notation, "
        "processes, threads, hash tables, or dynamic programming."
    )
    return msg
99
 
100
 
 
103
def cs_tutor_app(question: str) -> str:
    """Gradio handler: answer *question* from the retrieval knowledge base.

    Returns the stored answer of the closest example, or a fallback message
    when no example is similar enough (score < SIMILARITY_THRESHOLD).
    """
    # Guard against empty AND whitespace-only input before hitting the
    # index; the original `if not question` let "   " through to retrieval.
    if not question or not question.strip():
        return "Please enter a computer science question."

    best_score, best_match = retrieve_best_match(question)

    # Below the threshold, refuse rather than return an unrelated answer.
    if best_score < SIMILARITY_THRESHOLD:
        return fallback_message()

    return best_match["answer"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
 
113
 
114
  # -------------------------------
 
133
  )
134
 
135
  demo.launch()