# RAG / app.py
# Author: Desalegnn — "Update app.py" (HF Space commit 1ca96e2, verified)
import time
import json
import numpy as np
import faiss
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModel, AutoModelForQuestionAnswering
# -------------------------------------------------------
# CONFIG
# -------------------------------------------------------
# Embedding model for retrieval (fine-tuned Snowflake arctic-embed variant).
EMBED_MODEL = "Desalegnn/Desu-snowflake-arctic-embed-l-v2.0-finetuned-amharic-45k"
# Extractive QA model (generator/reader)
QA_MODEL = "Desalegnn/afroxlmr-amharic-qa"
# Local files in the Space repo (⚠️ make sure names match what you upload)
FAISS_PATH = "amharic_faiss.bin" # upload this file
METADATA_PATH = "passage_meta.jsonl" # upload this file; one JSON object per line
# Prefer GPU when available; all models are moved to this device below.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("DEVICE:", DEVICE)
# -------------------------------------------------------
# LOAD MODELS + INDEX + METADATA
# -------------------------------------------------------
# 1) Embedding model (eval mode: inference only, no dropout).
embed_tokenizer = AutoTokenizer.from_pretrained(EMBED_MODEL)
embed_model = AutoModel.from_pretrained(EMBED_MODEL).to(DEVICE)
embed_model.eval()
# 2) QA model (span-extraction head).
qa_tokenizer = AutoTokenizer.from_pretrained(QA_MODEL)
qa_model = AutoModelForQuestionAnswering.from_pretrained(QA_MODEL).to(DEVICE)
qa_model.eval()
# 3) FAISS index. The printed dimension must match the embedding model's
#    hidden size, or index.search will fail at query time.
index = faiss.read_index(FAISS_PATH)
print("FAISS dimension:", index.d)
# 4) Passage metadata, loaded as a list of dicts (JSON Lines format).
#    NOTE(review): retrieval below indexes this list by FAISS row id, so the
#    line order here is assumed to match the order vectors were added to the
#    index — TODO confirm against the index-building script.
metadata = []
with open(METADATA_PATH, "r", encoding="utf-8") as f:
    for line in f:
        line = line.strip()
        if line:  # skip blank lines
            metadata.append(json.loads(line))
print("Loaded passages:", len(metadata))
# -------------------------------------------------------
# EMBEDDING FUNCTION
# -------------------------------------------------------
@torch.no_grad()
def embed_texts(texts, batch_size=8):
    """
    Embed a list of strings with the Snowflake model, mean-pooling the
    last hidden state over non-padding tokens.

    Returns an np.ndarray of shape [N, D], dtype float32.
    """
    chunks = []
    for start in range(0, len(texts), batch_size):
        encoded = embed_tokenizer(
            texts[start:start + batch_size],
            padding=True,
            truncation=True,
            max_length=256,
            return_tensors="pt",
        ).to(DEVICE)
        hidden = embed_model(**encoded).last_hidden_state      # [B, T, D]
        attn = encoded["attention_mask"].unsqueeze(-1)         # [B, T, 1]
        # Attention-masked mean pool; clamp guards against a zero token count.
        pooled = (hidden * attn).sum(dim=1) / attn.sum(dim=1).clamp(min=1e-9)
        chunks.append(pooled.cpu().numpy())
    return np.vstack(chunks).astype("float32")
# -------------------------------------------------------
# RETRIEVAL
# -------------------------------------------------------
def retrieve_top_k(query, k=5):
    """
    Embed the query, search the FAISS index, and return
    (list of passage dicts, retrieval latency in ms).
    """
    started = time.time()
    query_vec = embed_texts([query])              # shape [1, D]
    dists, ids = index.search(query_vec, k)
    latency_ms = (time.time() - started) * 1000.0
    hits = []
    for row_id, dist in zip(ids[0], dists[0]):
        # FAISS returns -1 when fewer than k vectors exist; the bounds
        # check also guards against index/metadata misalignment.
        if not (0 <= row_id < len(metadata)):
            continue
        entry = metadata[row_id]
        hits.append(
            {
                "id": entry.get("id", row_id),
                "text": entry.get("text", ""),
                # Negated distance so larger is better (assumes an L2
                # index — TODO confirm the index metric).
                "score": float(-dist),
            }
        )
    return hits, latency_ms
# -------------------------------------------------------
# EXTRACTIVE QA ON ONE PASSAGE
# -------------------------------------------------------
@torch.no_grad()
def answer_on_context(question, passage):
    """
    Apply AfroXLM-R QA model to (question, passage) and return best span + score.

    Returns:
        (answer_text, score): the extracted span (stripped, possibly "")
        and the sum of the chosen start/end logits (higher is better).
    """
    # NOTE(review): truncation="only_second" drops passage text beyond 384
    # tokens; answers past that point cannot be found (no stride windows).
    enc = qa_tokenizer(
        question,
        passage,
        truncation="only_second",
        max_length=384,
        padding="max_length",
        return_offsets_mapping=True,
        return_tensors="pt",
    )
    input_ids = enc["input_ids"].to(DEVICE)
    attention_mask = enc["attention_mask"].to(DEVICE)
    # Per-token (char_start, char_end) offsets into `passage` for context tokens.
    offset_mapping = enc["offset_mapping"][0].tolist()
    sequence_ids = enc.sequence_ids(0)  # 0 = question, 1 = context, None = special
    outputs = qa_model(input_ids=input_ids, attention_mask=attention_mask)
    start_logits = outputs.start_logits[0].cpu().numpy()
    end_logits = outputs.end_logits[0].cpu().numpy()
    # mask out non-context tokens (question, specials, padding) so argmax
    # can only select a span inside the passage
    for i, sid in enumerate(sequence_ids):
        if sid != 1:
            start_logits[i] = -1e9
            end_logits[i] = -1e9
    # NOTE(review): start and end are argmax'd independently; a joint
    # best-pair search over start<=end would be the standard refinement.
    start_idx = int(np.argmax(start_logits))
    end_idx = int(np.argmax(end_logits))
    if end_idx < start_idx:
        # Clamp an inverted span to a single token rather than failing.
        end_idx = start_idx
    # convert to char positions
    start_char, end_char = offset_mapping[start_idx][0], offset_mapping[end_idx][1]
    # Defensive validation; fast tokenizers emit (0, 0) (not None) for
    # specials, so the None checks are belt-and-braces.
    if (
        start_char is None
        or end_char is None
        or end_char <= start_char
        or start_char < 0
        or end_char > len(passage)
    ):
        answer_text = ""
    else:
        answer_text = passage[start_char:end_char]
    score = float(start_logits[start_idx] + end_logits[end_idx])
    return answer_text.strip(), score
# -------------------------------------------------------
# RAG PIPELINE: RETRIEVE -> EXTRACTIVE QA
# -------------------------------------------------------
def rag_pipeline(question, k=5):
    """
    Retrieval-augmented QA: fetch top-k passages, run the extractive
    reader on each, and keep the highest-scoring non-empty span.

    Returns a 4-tuple of Markdown strings:
    (answer, retrieval latency, generator latency, passage snippet).
    """
    # 1) Retrieval
    passages, ret_lat = retrieve_top_k(question, k)
    if not passages:
        return (
            "**Answer:** αˆ˜αˆ¨αŒƒ αŠ αˆα‰°αŒˆαŠ˜αˆα’",
            f"**Retrieval Latency:** {ret_lat:.2f} ms",
            "**Generator Latency:** 0.00 ms",
            "",
        )
    # 2) Reader pass over every candidate passage, tracking the best span.
    qa_started = time.time()
    top_answer, top_score, top_context = "", -1e9, ""
    for passage in passages:
        context = passage["text"]
        if not context.strip():
            continue
        span, span_score = answer_on_context(question, context)
        if span and span_score > top_score:
            top_answer, top_score, top_context = span, span_score, context
    gen_lat = (time.time() - qa_started) * 1000.0
    if not top_answer:
        top_answer = "መልሡ αŠ αˆα‰°αŒˆαŠ˜αˆα’"
    # Clip the supporting passage for display.
    snippet = top_context[:500] + ("..." if len(top_context) > 500 else "")
    return (
        f"**Answer (AfroXLM-R extractive):** {top_answer}",
        f"**Retrieval Latency:** {ret_lat:.2f} ms",
        f"**Generator Latency (QA):** {gen_lat:.2f} ms",
        snippet,
    )
# -------------------------------------------------------
# GRADIO APP
# -------------------------------------------------------
def gradio_rag(query, k):
    """Gradio callback: validate the question, then delegate to the pipeline."""
    cleaned = (query or "").strip()
    if not cleaned:
        # Keep the 4-output shape expected by the Gradio wiring.
        return "Please type a question.", "", "", ""
    return rag_pipeline(cleaned, int(k))
# Declarative Gradio UI: one question box, a top-k slider, and four outputs
# (answer, two latency readouts, supporting passage snippet).
with gr.Blocks() as app:
    gr.Markdown("<h2>πŸ‡ͺπŸ‡Ή Amharic RAG (Snowflake + AfroXLM-R Extractive QA)</h2>")
    gr.Markdown(
        "Retrieval-Augmented Question Answering: "
        "Snowflake embeddings + FAISS for retrieval, "
        "AfroXLM-R extractive model for answer spans."
    )
    with gr.Row():
        query = gr.Textbox(
            label="Ask an Amharic question",
            lines=2,
            placeholder="ምሳሌፑ αŠ α‰£α‹­ α‹ˆαŠ•α‹ የቡ αŠα‹ α‹¨αˆšαˆ˜αŠαŒ¨α‹?"
        )
        k = gr.Slider(1, 10, value=5, step=1, label="Top-K passages")
    btn = gr.Button("Run RAG")
    out_answer = gr.Markdown(label="Answer")
    out_retlat = gr.Markdown(label="Retrieval latency")
    out_genlat = gr.Markdown(label="Generator latency")
    out_passage = gr.Textbox(label="Retrieved passage snippet", lines=10)
    # Output order must match the 4-tuple returned by gradio_rag.
    btn.click(
        gradio_rag,
        inputs=[query, k],
        outputs=[out_answer, out_retlat, out_genlat, out_passage],
    )
# Blocking call: starts the Space's web server.
app.launch()