import time
import json
import numpy as np
import faiss
import torch
import gradio as gr

from transformers import AutoTokenizer, AutoModel, AutoModelForQuestionAnswering


# -------------------------------------------------------
# CONFIG
# -------------------------------------------------------

# Embedding model for retrieval
EMBED_MODEL = "Desalegnn/Desu-snowflake-arctic-embed-l-v2.0-finetuned-amharic-45k"

# Extractive QA model (generator/reader)
QA_MODEL = "Desalegnn/afroxlmr-amharic-qa"

# Local files in the Space repo (⚠️ make sure names match what you upload)
FAISS_PATH    = "amharic_faiss.bin"    # upload this file
METADATA_PATH = "passage_meta.jsonl"     # upload this file
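
# The two files are assumed to line up row-for-row: vector i in the FAISS
# index corresponds to JSON line i in the metadata file (retrieve_top_k
# relies on this). A minimal offline build sketch (hypothetical; run once,
# not part of this app), assuming an L2 index to match the scoring below:
#
#   embs = embed_texts([m["text"] for m in metadata])   # [N, D] float32
#   index = faiss.IndexFlatL2(embs.shape[1])
#   index.add(embs)
#   faiss.write_index(index, "amharic_faiss.bin")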

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("DEVICE:", DEVICE)


# -------------------------------------------------------
# LOAD MODELS + INDEX + METADATA
# -------------------------------------------------------

# 1) Embedding model
embed_tokenizer = AutoTokenizer.from_pretrained(EMBED_MODEL)
embed_model     = AutoModel.from_pretrained(EMBED_MODEL).to(DEVICE)
embed_model.eval()

# 2) QA model
qa_tokenizer = AutoTokenizer.from_pretrained(QA_MODEL)
qa_model     = AutoModelForQuestionAnswering.from_pretrained(QA_MODEL).to(DEVICE)
qa_model.eval()

# 3) FAISS index
index = faiss.read_index(FAISS_PATH)
print("FAISS dimension:", index.d)

# 4) Passage metadata
metadata = []
with open(METADATA_PATH, "r", encoding="utf-8") as f:
    for line in f:
        line = line.strip()
        if line:
            metadata.append(json.loads(line))

print("Loaded passages:", len(metadata))


# -------------------------------------------------------
# EMBEDDING FUNCTION
# -------------------------------------------------------

@torch.no_grad()
def embed_texts(texts, batch_size=8):
    """
    Embed a list of texts using the Snowflake model (mean-pooled).
    Returns np.ndarray of shape [N, D].
    """
    all_embs = []

    for i in range(0, len(texts), batch_size):
        batch = texts[i:i + batch_size]

        enc = embed_tokenizer(
            batch,
            padding=True,
            truncation=True,
            max_length=256,
            return_tensors="pt",
        ).to(DEVICE)

        out = embed_model(**enc).last_hidden_state  # [B, T, D]
        mask = enc["attention_mask"].unsqueeze(-1)  # [B, T, 1]

        summed = (out * mask).sum(dim=1)           # [B, D]
        counts = mask.sum(dim=1).clamp(min=1e-9)   # [B, 1]
        emb = (summed / counts).cpu().numpy()      # [B, D]

        all_embs.append(emb)

    return np.vstack(all_embs).astype("float32")
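
# Example usage (shape check only; D must equal index.d for search to work):
#   embs = embed_texts(["ሰላም ዓለም"])   # -> np.ndarray, shape [1, D]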


# -------------------------------------------------------
# RETRIEVAL
# -------------------------------------------------------

def retrieve_top_k(query, k=5):
    """
    1) Embed query with Snowflake.
    2) Search FAISS index.
    3) Return top-k passages and retrieval latency (ms).
    """
    t0 = time.time()

    query_emb = embed_texts([query])  # [1, D]
    distances, indices = index.search(query_emb, k)

    ret_latency = (time.time() - t0) * 1000.0  # ms

    distances = distances[0]
    indices   = indices[0]
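
    # NOTE: the `-dist` score below assumes an L2 index (smaller distance is
    # better, so negation makes the score larger-is-better). If the index was
    # built for inner product / cosine (e.g. IndexFlatIP), use `dist` as-is.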

    results = []
    for idx, dist in zip(indices, distances):
        if 0 <= idx < len(metadata):
            meta = metadata[idx]
            results.append(
                {
                    "id": meta.get("id", idx),
                    "text": meta.get("text", ""),
                    "score": float(-dist),  # larger is better
                }
            )

    return results, ret_latency


# -------------------------------------------------------
# EXTRACTIVE QA ON ONE PASSAGE
# -------------------------------------------------------

@torch.no_grad()
def answer_on_context(question, passage):
    """
    Apply AfroXLM-R QA model to (question, passage) and return best span + score.
    """
    enc = qa_tokenizer(
        question,
        passage,
        truncation="only_second",
        max_length=384,
        padding="max_length",
        return_offsets_mapping=True,
        return_tensors="pt",
    )

    input_ids      = enc["input_ids"].to(DEVICE)
    attention_mask = enc["attention_mask"].to(DEVICE)
    offset_mapping = enc["offset_mapping"][0].tolist()
    sequence_ids   = enc.sequence_ids(0)  # 0 = question, 1 = context, None = special

    outputs = qa_model(input_ids=input_ids, attention_mask=attention_mask)

    start_logits = outputs.start_logits[0].cpu().numpy()
    end_logits   = outputs.end_logits[0].cpu().numpy()

    # mask out non-context tokens
    for i, sid in enumerate(sequence_ids):
        if sid != 1:
            start_logits[i] = -1e9
            end_logits[i]   = -1e9

    start_idx = int(np.argmax(start_logits))
    end_idx   = int(np.argmax(end_logits))
    if end_idx < start_idx:
        end_idx = start_idx

    # convert to char positions
    start_char, end_char = offset_mapping[start_idx][0], offset_mapping[end_idx][1]

    if (
        start_char is None
        or end_char is None
        or end_char <= start_char
        or start_char < 0
        or end_char > len(passage)
    ):
        answer_text = ""
    else:
        answer_text = passage[start_char:end_char]

    score = float(start_logits[start_idx] + end_logits[end_idx])

    return answer_text.strip(), score
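
# NOTE: two simplifications in answer_on_context, acceptable for a demo:
#  - span selection is greedy (independent argmax over start/end logits)
#    rather than a joint search over top-n start/end pairs;
#  - contexts longer than 384 tokens are truncated (no sliding window).
#    One hedged alternative for long passages: tokenize with
#    return_overflowing_tokens=True and a stride, run the same scoring per
#    chunk, and keep the best span overall.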


# -------------------------------------------------------
# RAG PIPELINE: RETRIEVE -> EXTRACTIVE QA
# -------------------------------------------------------

def rag_pipeline(question, k=5):
    """
    1) Retrieve top-k passages.
    2) Run AfroXLM-R QA on each passage.
    3) Select best answer by score.
    4) Return answer, retrieval latency, generator latency, passage snippet.
    """
    # 1) Retrieval
    passages, ret_lat = retrieve_top_k(question, k)

    if not passages:
        return (
            "**Answer:** αˆ˜αˆ¨αŒƒ αŠ αˆα‰°αŒˆαŠ˜αˆα’",
            f"**Retrieval Latency:** {ret_lat:.2f} ms",
            "**Generator Latency:** 0.00 ms",
            "",
        )

    # 2) QA on each passage
    t0 = time.time()

    best_answer = ""
    best_score  = -1e9
    best_passage_text = ""

    for p in passages:
        ctx = p["text"]
        if not ctx.strip():
            continue

        ans, score = answer_on_context(question, ctx)
        if ans and score > best_score:
            best_score = score
            best_answer = ans
            best_passage_text = ctx

    gen_lat = (time.time() - t0) * 1000.0  # ms

    if not best_answer:
        best_answer = "መልሡ αŠ αˆα‰°αŒˆαŠ˜αˆα’"  # "The answer was not found."

    snippet = best_passage_text[:500] + ("..." if len(best_passage_text) > 500 else "")

    return (
        f"**Answer (AfroXLM-R extractive):** {best_answer}",
        f"**Retrieval Latency:** {ret_lat:.2f} ms",
        f"**Generator Latency (QA):** {gen_lat:.2f} ms",
        snippet,
    )
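
# Quick smoke test (hypothetical question; needs the index/metadata files):
#   print(rag_pipeline("αŠ α‰£α‹­ α‹ˆαŠ•α‹ የት αŠα‹ α‹¨αˆšαˆ˜αŠαŒ¨α‹?", k=3))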


# -------------------------------------------------------
# GRADIO APP
# -------------------------------------------------------

def gradio_rag(query, k):
    query = (query or "").strip()
    if not query:
        return "Please type a question.", "", "", ""
    return rag_pipeline(query, int(k))


with gr.Blocks() as app:
    gr.Markdown("<h2>πŸ‡ͺπŸ‡Ή Amharic RAG (Snowflake + AfroXLM-R Extractive QA)</h2>")
    gr.Markdown(
        "Retrieval-Augmented Question Answering: "
        "Snowflake embeddings + FAISS for retrieval, "
        "AfroXLM-R extractive model for answer spans."
    )

    with gr.Row():
        query = gr.Textbox(
            label="Ask an Amharic question",
            lines=2,
            placeholder="ምሳሌፑ αŠ α‰£α‹­ α‹ˆαŠ•α‹ የቡ αŠα‹ α‹¨αˆšαˆ˜αŠαŒ¨α‹?"
        )
        k = gr.Slider(1, 10, value=5, step=1, label="Top-K passages")

    btn = gr.Button("Run RAG")

    out_answer  = gr.Markdown(label="Answer")
    out_retlat  = gr.Markdown(label="Retrieval latency")
    out_genlat  = gr.Markdown(label="Generator latency")
    out_passage = gr.Textbox(label="Retrieved passage snippet", lines=10)

    btn.click(
        gradio_rag,
        inputs=[query, k],
        outputs=[out_answer, out_retlat, out_genlat, out_passage],
    )
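
    # Optional: also run when the user presses Enter in the question box
    # (Textbox.submit takes the same fn/inputs/outputs as Button.click).
    query.submit(
        gradio_rag,
        inputs=[query, k],
        outputs=[out_answer, out_retlat, out_genlat, out_passage],
    )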

app.launch()