Desalegnn committed on
Commit 1ca96e2 · verified · 1 Parent(s): edfe0a6

Update app.py

Files changed (1):
  1. app.py +273 -61
app.py CHANGED
@@ -1,70 +1,282 @@
 import gradio as gr
-from huggingface_hub import InferenceClient
-
-
-def respond(
-    message,
-    history: list[dict[str, str]],
-    system_message,
-    max_tokens,
-    temperature,
-    top_p,
-    hf_token: gr.OAuthToken,
-):
     """
-    For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
     """
-    client = InferenceClient(token=hf_token.token, model="openai/gpt-oss-20b")

-    messages = [{"role": "system", "content": system_message}]

-    messages.extend(history)

-    messages.append({"role": "user", "content": message})

-    response = ""

-    for message in client.chat_completion(
-        messages,
-        max_tokens=max_tokens,
-        stream=True,
-        temperature=temperature,
-        top_p=top_p,
     ):
-        choices = message.choices
-        token = ""
-        if len(choices) and choices[0].delta.content:
-            token = choices[0].delta.content
-
-        response += token
-        yield response
-
-
-"""
-For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
-"""
-chatbot = gr.ChatInterface(
-    respond,
-    type="messages",
-    additional_inputs=[
-        gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
-        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
-        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
-        gr.Slider(
-            minimum=0.1,
-            maximum=1.0,
-            value=0.95,
-            step=0.05,
-            label="Top-p (nucleus sampling)",
-        ),
-    ],
-)
-
-with gr.Blocks() as demo:
-    with gr.Sidebar():
-        gr.LoginButton()
-    chatbot.render()
-
-
-if __name__ == "__main__":
-    demo.launch()
+import time
+import json
+import numpy as np
+import faiss
+import torch
 import gradio as gr
+
+from transformers import AutoTokenizer, AutoModel, AutoModelForQuestionAnswering
+
+
+# -------------------------------------------------------
+# CONFIG
+# -------------------------------------------------------
+
+# Embedding model for retrieval
+EMBED_MODEL = "Desalegnn/Desu-snowflake-arctic-embed-l-v2.0-finetuned-amharic-45k"
+
+# Extractive QA model (generator/reader)
+QA_MODEL = "Desalegnn/afroxlmr-amharic-qa"
+
+# Local files in the Space repo (⚠️ make sure names match what you upload)
+FAISS_PATH = "amharic_faiss.bin"      # upload this file
+METADATA_PATH = "passage_meta.jsonl"  # upload this file
+
+DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+print("DEVICE:", DEVICE)
+
+
+# -------------------------------------------------------
+# LOAD MODELS + INDEX + METADATA
+# -------------------------------------------------------
+
+# 1) Embedding model
+embed_tokenizer = AutoTokenizer.from_pretrained(EMBED_MODEL)
+embed_model = AutoModel.from_pretrained(EMBED_MODEL).to(DEVICE)
+embed_model.eval()
+
+# 2) QA model
+qa_tokenizer = AutoTokenizer.from_pretrained(QA_MODEL)
+qa_model = AutoModelForQuestionAnswering.from_pretrained(QA_MODEL).to(DEVICE)
+qa_model.eval()
+
+# 3) FAISS index
+index = faiss.read_index(FAISS_PATH)
+print("FAISS dimension:", index.d)
+
+# 4) Passage metadata
+metadata = []
+with open(METADATA_PATH, "r", encoding="utf-8") as f:
+    for line in f:
+        line = line.strip()
+        if line:
+            metadata.append(json.loads(line))
+
+print("Loaded passages:", len(metadata))
+
+
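For reference, the loader above expects `passage_meta.jsonl` to hold one JSON object per line, and the retrieval code further down reads only the `id` and `text` fields. A minimal sketch of writing one record (the text value is a placeholder):

```python
import json

# One passage per line; "id" and "text" are the only fields the app
# reads later via meta.get(...). The text here is a placeholder.
record = {"id": 0, "text": "αŠ α‰£α‹­ α‹ˆαŠ•α‹ ..."}
with open("passage_meta.jsonl", "a", encoding="utf-8") as f:
    f.write(json.dumps(record, ensure_ascii=False) + "\n")
```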
58
+# -------------------------------------------------------
+# EMBEDDING FUNCTION
+# -------------------------------------------------------
+
+@torch.no_grad()
+def embed_texts(texts, batch_size=8):
+    """
+    Embed a list of texts using the Snowflake model (mean-pooled).
+    Returns np.ndarray of shape [N, D].
+    """
+    all_embs = []
+
+    for i in range(0, len(texts), batch_size):
+        batch = texts[i:i + batch_size]
+
+        enc = embed_tokenizer(
+            batch,
+            padding=True,
+            truncation=True,
+            max_length=256,
+            return_tensors="pt",
+        ).to(DEVICE)
+
+        out = embed_model(**enc).last_hidden_state  # [B, T, D]
+        mask = enc["attention_mask"].unsqueeze(-1)  # [B, T, 1]
+
+        summed = (out * mask).sum(dim=1)            # [B, D]
+        counts = mask.sum(dim=1).clamp(min=1e-9)    # [B, 1]
+        emb = (summed / counts).cpu().numpy()       # [B, D]
+
+        all_embs.append(emb)
+
+    return np.vstack(all_embs).astype("float32")
+
+
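The FAISS index itself is built offline and only loaded by this Space. A plausible build sketch reusing `embed_texts` (not part of this commit; `IndexFlatL2` is an assumption, chosen because it matches the `score = -dist` convention in `retrieve_top_k` below):

```python
# Hypothetical offline build step -- not part of this commit.
texts = [m["text"] for m in metadata]          # or the raw corpus
embs = embed_texts(texts)                      # [N, D] float32, mean-pooled
index = faiss.IndexFlatL2(int(embs.shape[1]))  # assumed metric: L2
index.add(embs)
faiss.write_index(index, "amharic_faiss.bin")
```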
93
+# -------------------------------------------------------
+# RETRIEVAL
+# -------------------------------------------------------
+
+def retrieve_top_k(query, k=5):
+    """
+    1) Embed query with Snowflake.
+    2) Search FAISS index.
+    3) Return top-k passages and retrieval latency (ms).
+    """
+    t0 = time.time()
+
+    query_emb = embed_texts([query])  # [1, D]
+    distances, indices = index.search(query_emb, k)
+
+    ret_latency = (time.time() - t0) * 1000.0  # ms
+
+    distances = distances[0]
+    indices = indices[0]
+
+    results = []
+    for idx, dist in zip(indices, distances):
+        if 0 <= idx < len(metadata):
+            meta = metadata[idx]
+            results.append(
+                {
+                    "id": meta.get("id", idx),
+                    "text": meta.get("text", ""),
+                    "score": float(-dist),  # larger is better
+                }
+            )
+
+    return results, ret_latency
+
+
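Note that `score = float(-dist)` assumes an L2 index, where smaller distances are better. If the index were instead built for inner product (e.g. over normalized embeddings), the raw value would already be a similarity. A small helper covering both cases (a sketch; `faiss_score` is not in the committed file):

```python
def faiss_score(idx, dist):
    # Orient a raw FAISS value so that larger is always better.
    if idx.metric_type == faiss.METRIC_INNER_PRODUCT:
        return float(dist)   # similarity: keep as-is
    return float(-dist)      # L2 distance: negate
```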
128
+# -------------------------------------------------------
+# EXTRACTIVE QA ON ONE PASSAGE
+# -------------------------------------------------------
+
+@torch.no_grad()
+def answer_on_context(question, passage):
     """
+    Apply AfroXLM-R QA model to (question, passage) and return best span + score.
     """
+    enc = qa_tokenizer(
+        question,
+        passage,
+        truncation="only_second",
+        max_length=384,
+        padding="max_length",
+        return_offsets_mapping=True,
+        return_tensors="pt",
+    )

+    input_ids = enc["input_ids"].to(DEVICE)
+    attention_mask = enc["attention_mask"].to(DEVICE)
+    offset_mapping = enc["offset_mapping"][0].tolist()
+    sequence_ids = enc.sequence_ids(0)  # 0 = question, 1 = context, None = special

+    outputs = qa_model(input_ids=input_ids, attention_mask=attention_mask)

+    start_logits = outputs.start_logits[0].cpu().numpy()
+    end_logits = outputs.end_logits[0].cpu().numpy()

+    # mask out non-context tokens
+    for i, sid in enumerate(sequence_ids):
+        if sid != 1:
+            start_logits[i] = -1e9
+            end_logits[i] = -1e9

+    start_idx = int(np.argmax(start_logits))
+    end_idx = int(np.argmax(end_logits))
+    if end_idx < start_idx:
+        end_idx = start_idx
+
+    # convert to char positions
+    start_char, end_char = offset_mapping[start_idx][0], offset_mapping[end_idx][1]
+
+    if (
+        start_char is None
+        or end_char is None
+        or end_char <= start_char
+        or start_char < 0
+        or end_char > len(passage)
     ):
+        answer_text = ""
+    else:
+        answer_text = passage[start_char:end_char]
+
+    score = float(start_logits[start_idx] + end_logits[end_idx])
+
+    return answer_text.strip(), score
+
+
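One limitation worth noting: `truncation="only_second"` with `max_length=384` silently drops the tail of long passages, so an answer past roughly the first 384 tokens can never be extracted. A common remedy is to slide a window over the context and keep the best-scoring span; a rough sketch built on the committed `answer_on_context` (character windows stand in for token-level striding with `return_overflowing_tokens=True`, to keep the example short):

```python
def answer_on_long_context(question, passage, window=1000, stride=500):
    # Illustrative sketch, not in this commit: scan overlapping
    # character windows and keep the highest-scoring answer span.
    best_ans, best_score = "", float("-inf")
    for start in range(0, max(len(passage) - window, 0) + 1, stride):
        ans, score = answer_on_context(question, passage[start:start + window])
        if ans and score > best_score:
            best_ans, best_score = ans, score
    return best_ans, best_score
```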
187
+# -------------------------------------------------------
+# RAG PIPELINE: RETRIEVE -> EXTRACTIVE QA
+# -------------------------------------------------------
+
+def rag_pipeline(question, k=5):
+    """
+    1) Retrieve top-k passages.
+    2) Run AfroXLM-R QA on each passage.
+    3) Select best answer by score.
+    4) Return answer, retrieval latency, generator latency, passage snippet.
+    """
+    # 1) Retrieval
+    passages, ret_lat = retrieve_top_k(question, k)
+
+    if not passages:
+        return (
+            "**Answer:** αˆ˜αˆ¨αŒƒ αŠ αˆα‰°αŒˆαŠ˜αˆα’",
+            f"**Retrieval Latency:** {ret_lat:.2f} ms",
+            "**Generator Latency:** 0.00 ms",
+            "",
+        )
+
+    # 2) QA on each passage
+    t0 = time.time()
+
+    best_answer = ""
+    best_score = -1e9
+    best_passage_text = ""
+
+    for p in passages:
+        ctx = p["text"]
+        if not ctx.strip():
+            continue
+
+        ans, score = answer_on_context(question, ctx)
+        if ans and score > best_score:
+            best_score = score
+            best_answer = ans
+            best_passage_text = ctx
+
+    gen_lat = (time.time() - t0) * 1000.0  # ms
+
+    if not best_answer:
+        best_answer = "መልሡ αŠ αˆα‰°αŒˆαŠ˜αˆα’"
+
+    snippet = best_passage_text[:500] + ("..." if len(best_passage_text) > 500 else "")
+
+    return (
+        f"**Answer (AfroXLM-R extractive):** {best_answer}",
+        f"**Retrieval Latency:** {ret_lat:.2f} ms",
+        f"**Generator Latency (QA):** {gen_lat:.2f} ms",
+        snippet,
+    )
+
+
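For orientation, the two Amharic fallback strings mean "No information was found." and "The answer was not found." A quick smoke test of the pipeline from a Python shell, reusing the UI's own example question:

```python
# Illustrative smoke test; run after the models and index have loaded.
answer, ret_lat, gen_lat, snippet = rag_pipeline("αŠ α‰£α‹­ α‹ˆαŠ•α‹ የቡ αŠα‹ α‹¨αˆšαˆ˜αŠαŒ¨α‹?", k=5)
print(answer)    # "**Answer (AfroXLM-R extractive):** ..."
print(ret_lat)   # "**Retrieval Latency:** ... ms"
print(gen_lat)   # "**Generator Latency (QA):** ... ms"
```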
242
+# -------------------------------------------------------
+# GRADIO APP
+# -------------------------------------------------------
+
+def gradio_rag(query, k):
+    query = (query or "").strip()
+    if not query:
+        return "Please type a question.", "", "", ""
+    return rag_pipeline(query, int(k))
+
+
+with gr.Blocks() as app:
+    gr.Markdown("<h2>🇪🇹 Amharic RAG (Snowflake + AfroXLM-R Extractive QA)</h2>")
+    gr.Markdown(
+        "Retrieval-Augmented Question Answering: "
+        "Snowflake embeddings + FAISS for retrieval, "
+        "AfroXLM-R extractive model for answer spans."
+    )
+
+    with gr.Row():
+        query = gr.Textbox(
+            label="Ask an Amharic question",
+            lines=2,
+            placeholder="ምሳሌፑ αŠ α‰£α‹­ α‹ˆαŠ•α‹ የቡ αŠα‹ α‹¨αˆšαˆ˜αŠαŒ¨α‹?",
+        )
+        k = gr.Slider(1, 10, value=5, step=1, label="Top-K passages")
+
+    btn = gr.Button("Run RAG")
+
+    out_answer = gr.Markdown(label="Answer")
+    out_retlat = gr.Markdown(label="Retrieval latency")
+    out_genlat = gr.Markdown(label="Generator latency")
+    out_passage = gr.Textbox(label="Retrieved passage snippet", lines=10)
+
+    btn.click(
+        gradio_rag,
+        inputs=[query, k],
+        outputs=[out_answer, out_retlat, out_genlat, out_passage],
+    )
+
+app.launch()
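Finally, for the Space to build, the repo also needs a `requirements.txt` covering the new imports. A plausible minimal set, not part of this commit (versions unpinned; `faiss-cpu` assumes a CPU Space):

```text
gradio
transformers
torch
faiss-cpu
numpy
```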