|
|
import time |
|
|
import json |
|
|
import numpy as np |
|
|
import faiss |
|
|
import torch |
|
|
import gradio as gr |
|
|
|
|
|
from transformers import AutoTokenizer, AutoModel, AutoModelForQuestionAnswering |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Fine-tuned Snowflake Arctic embedding model used to embed queries/passages.
EMBED_MODEL = "Desalegnn/Desu-snowflake-arctic-embed-l-v2.0-finetuned-amharic-45k"

# AfroXLM-R extractive QA model that selects answer spans from passages.
QA_MODEL = "Desalegnn/afroxlmr-amharic-qa"

# Prebuilt FAISS index over the passage embeddings.
FAISS_PATH = "amharic_faiss.bin"

# JSONL file: one JSON object per line; entries are read as dicts and the
# code below uses their "id" and "text" keys.
METADATA_PATH = "passage_meta.jsonl"

# Use the GPU when available; both models are moved to this device.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

print("DEVICE:", DEVICE)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# --- Embedding model (retrieval side) ---
embed_tokenizer = AutoTokenizer.from_pretrained(EMBED_MODEL)
embed_model = AutoModel.from_pretrained(EMBED_MODEL).to(DEVICE)
embed_model.eval()

# --- Extractive QA model (reader side) ---
qa_tokenizer = AutoTokenizer.from_pretrained(QA_MODEL)
qa_model = AutoModelForQuestionAnswering.from_pretrained(QA_MODEL).to(DEVICE)
qa_model.eval()

# --- FAISS index over passage embeddings ---
index = faiss.read_index(FAISS_PATH)
print("FAISS dimension:", index.d)

# --- Passage metadata: parse every non-blank JSONL line ---
with open(METADATA_PATH, "r", encoding="utf-8") as f:
    metadata = [json.loads(row) for row in (line.strip() for line in f) if row]

print("Loaded passages:", len(metadata))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@torch.no_grad()
def embed_texts(texts, batch_size=8):
    """
    Embed a list of texts using the Snowflake model (mean-pooled over
    non-padding tokens).

    Args:
        texts: list of strings to embed.
        batch_size: number of texts per forward pass.

    Returns:
        np.ndarray of shape [N, D], dtype float32, where N == len(texts).
    """
    # np.vstack([]) raises on an empty input; return an empty matrix with
    # the model's hidden width instead.
    if not texts:
        return np.empty((0, embed_model.config.hidden_size), dtype="float32")

    all_embs = []

    for i in range(0, len(texts), batch_size):
        batch = texts[i:i + batch_size]

        enc = embed_tokenizer(
            batch,
            padding=True,
            truncation=True,
            max_length=256,
            return_tensors="pt",
        ).to(DEVICE)

        out = embed_model(**enc).last_hidden_state  # [B, T, D]

        # Cast the integer attention mask to the hidden-state dtype before
        # pooling: keeps the multiply in floating point and makes
        # clamp(min=1e-9) an effective divide-by-zero guard (on an integer
        # tensor the 1e-9 bound would truncate to 0 or error, depending on
        # the torch version).
        mask = enc["attention_mask"].unsqueeze(-1).to(out.dtype)

        summed = (out * mask).sum(dim=1)
        counts = mask.sum(dim=1).clamp(min=1e-9)
        all_embs.append((summed / counts).cpu().numpy())

    return np.vstack(all_embs).astype("float32")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def retrieve_top_k(query, k=5):
    """
    Retrieve the k passages most similar to *query*.

    Embeds the query with the Snowflake model, searches the FAISS index,
    and returns (results, latency_ms) where each result is a dict with
    "id", "text", and "score" (the negated FAISS distance).
    """
    started = time.time()

    query_emb = embed_texts([query])
    dist_rows, id_rows = index.search(query_emb, k)

    ret_latency = (time.time() - started) * 1000.0

    results = []
    for idx, dist in zip(id_rows[0], dist_rows[0]):
        # FAISS pads missing hits with -1; also guard against stale ids
        # that fall outside the loaded metadata.
        if not (0 <= idx < len(metadata)):
            continue
        meta = metadata[idx]
        results.append(
            {
                "id": meta.get("id", idx),
                "text": meta.get("text", ""),
                "score": float(-dist),
            }
        )

    return results, ret_latency
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@torch.no_grad()
def answer_on_context(question, passage):
    """
    Run the AfroXLM-R extractive QA model on one (question, passage) pair.

    Returns (answer_text, score): the best span found inside the passage
    ("" when no valid span exists) and the sum of its start/end logits.
    """
    enc = qa_tokenizer(
        question,
        passage,
        truncation="only_second",
        max_length=384,
        padding="max_length",
        return_offsets_mapping=True,
        return_tensors="pt",
    )

    offsets = enc["offset_mapping"][0].tolist()
    seq_ids = enc.sequence_ids(0)

    outputs = qa_model(
        input_ids=enc["input_ids"].to(DEVICE),
        attention_mask=enc["attention_mask"].to(DEVICE),
    )

    start_logits = outputs.start_logits[0].cpu().numpy()
    end_logits = outputs.end_logits[0].cpu().numpy()

    # Only positions belonging to the passage (sequence id 1) may become
    # span boundaries; penalize everything else out of contention.
    for pos, sid in enumerate(seq_ids):
        if sid != 1:
            start_logits[pos] = -1e9
            end_logits[pos] = -1e9

    start_idx = int(np.argmax(start_logits))
    end_idx = max(int(np.argmax(end_logits)), start_idx)

    start_char = offsets[start_idx][0]
    end_char = offsets[end_idx][1]

    # Accept only a well-formed character span inside the passage.
    span_ok = (
        start_char is not None
        and end_char is not None
        and start_char >= 0
        and end_char > start_char
        and end_char <= len(passage)
    )
    answer_text = passage[start_char:end_char] if span_ok else ""

    score = float(start_logits[start_idx] + end_logits[end_idx])

    return answer_text.strip(), score
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def rag_pipeline(question, k=5):
    """
    Full retrieve-then-read pipeline.

    Retrieves the top-k passages, runs extractive QA on each, and keeps the
    highest-scoring non-empty span. Returns four display strings: answer,
    retrieval latency, generator latency, and a passage snippet.
    """
    passages, ret_lat = retrieve_top_k(question, k)

    # No passages retrieved: report failure with zero generator time.
    if not passages:
        return (
            "**Answer:** αα¨α α αα°αααα’",
            f"**Retrieval Latency:** {ret_lat:.2f} ms",
            "**Generator Latency:** 0.00 ms",
            "",
        )

    qa_start = time.time()

    best = {"answer": "", "score": -1e9, "context": ""}

    for passage in passages:
        context = passage["text"]
        if not context.strip():
            continue

        span, span_score = answer_on_context(question, context)
        # Keep only non-empty spans that beat the current best score.
        if span and span_score > best["score"]:
            best = {"answer": span, "score": span_score, "context": context}

    gen_lat = (time.time() - qa_start) * 1000.0

    final_answer = best["answer"] or "ααα΅ α αα°αααα’"

    ctx = best["context"]
    snippet = ctx[:500] + ("..." if len(ctx) > 500 else "")

    return (
        f"**Answer (AfroXLM-R extractive):** {final_answer}",
        f"**Retrieval Latency:** {ret_lat:.2f} ms",
        f"**Generator Latency (QA):** {gen_lat:.2f} ms",
        snippet,
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def gradio_rag(query, k):
    """Gradio callback: validate the query, then delegate to the pipeline."""
    cleaned = (query or "").strip()
    if not cleaned:
        return "Please type a question.", "", "", ""
    return rag_pipeline(cleaned, int(k))
|
|
|
|
|
|
|
|
# --- Gradio UI: wires the RAG pipeline to a simple web form ---
with gr.Blocks() as app:
    gr.Markdown("<h2>πͺπΉ Amharic RAG (Snowflake + AfroXLM-R Extractive QA)</h2>")
    gr.Markdown(
        "Retrieval-Augmented Question Answering: "
        "Snowflake embeddings + FAISS for retrieval, "
        "AfroXLM-R extractive model for answer spans."
    )

    # Inputs: free-text Amharic question plus how many passages to retrieve.
    with gr.Row():
        query = gr.Textbox(
            label="Ask an Amharic question",
            lines=2,
            placeholder="αα³αα‘ α α£α ααα α¨α΅ αα α¨αααα¨α?"
        )
        k = gr.Slider(1, 10, value=5, step=1, label="Top-K passages")

    btn = gr.Button("Run RAG")

    # Outputs match the 4-tuple returned by gradio_rag / rag_pipeline:
    # answer markdown, retrieval latency, generator latency, passage snippet.
    out_answer = gr.Markdown(label="Answer")
    out_retlat = gr.Markdown(label="Retrieval latency")
    out_genlat = gr.Markdown(label="Generator latency")
    out_passage = gr.Textbox(label="Retrieved passage snippet", lines=10)

    btn.click(
        gradio_rag,
        inputs=[query, k],
        outputs=[out_answer, out_retlat, out_genlat, out_passage],
    )

# Starts the local Gradio server (blocking call).
app.launch()
|
|
|