"""Amharic RAG demo: Snowflake embedding retrieval + AfroXLM-R extractive QA.

Pipeline: embed the question, search a FAISS index of passage embeddings,
run an extractive QA reader over the top-k passages, and keep the span with
the highest (start_logit + end_logit) score. Served through a Gradio UI.
"""

import json
import time

import faiss
import gradio as gr
import numpy as np
import torch
from transformers import AutoModel, AutoModelForQuestionAnswering, AutoTokenizer

# -------------------------------------------------------
# CONFIG
# -------------------------------------------------------
# Embedding model for retrieval
EMBED_MODEL = "Desalegnn/Desu-snowflake-arctic-embed-l-v2.0-finetuned-amharic-45k"
# Extractive QA model (generator/reader)
QA_MODEL = "Desalegnn/afroxlmr-amharic-qa"

# Local files in the Space repo (names must match the uploaded files).
FAISS_PATH = "amharic_faiss.bin"
METADATA_PATH = "passage_meta.jsonl"

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("DEVICE:", DEVICE)

# -------------------------------------------------------
# LOAD MODELS + INDEX + METADATA
# -------------------------------------------------------
# 1) Embedding model (mean-pooled sentence embeddings for retrieval).
embed_tokenizer = AutoTokenizer.from_pretrained(EMBED_MODEL)
embed_model = AutoModel.from_pretrained(EMBED_MODEL).to(DEVICE)
embed_model.eval()

# 2) Extractive QA reader.
qa_tokenizer = AutoTokenizer.from_pretrained(QA_MODEL)
qa_model = AutoModelForQuestionAnswering.from_pretrained(QA_MODEL).to(DEVICE)
qa_model.eval()

# 3) FAISS index over passage embeddings.
index = faiss.read_index(FAISS_PATH)
print("FAISS dimension:", index.d)

# 4) Passage metadata: one JSON object per line, row-aligned with the index.
metadata = []
with open(METADATA_PATH, "r", encoding="utf-8") as f:
    for line in f:
        line = line.strip()
        if line:
            metadata.append(json.loads(line))
print("Loaded passages:", len(metadata))


# -------------------------------------------------------
# EMBEDDING FUNCTION
# -------------------------------------------------------
@torch.no_grad()
def embed_texts(texts, batch_size=8):
    """Embed a list of texts with the Snowflake model (attention-mask mean pooling).

    Args:
        texts: list of strings to embed.
        batch_size: number of texts per forward pass.

    Returns:
        np.ndarray of shape [N, D], dtype float32.
    """
    # Guard: np.vstack on an empty list would raise; return a 0-row matrix.
    if not texts:
        return np.empty((0, embed_model.config.hidden_size), dtype="float32")

    all_embs = []
    for start in range(0, len(texts), batch_size):
        batch = texts[start:start + batch_size]
        enc = embed_tokenizer(
            batch,
            padding=True,
            truncation=True,
            max_length=256,
            return_tensors="pt",
        ).to(DEVICE)
        hidden = embed_model(**enc).last_hidden_state   # [B, T, D]
        mask = enc["attention_mask"].unsqueeze(-1)      # [B, T, 1]
        summed = (hidden * mask).sum(dim=1)             # [B, D]
        counts = mask.sum(dim=1).clamp(min=1e-9)        # [B, 1]; avoid div-by-zero
        all_embs.append((summed / counts).cpu().numpy())

    # NOTE(review): embeddings are NOT L2-normalized here; this must match how
    # the FAISS index was built (L2 vs. inner-product) — verify offline.
    return np.vstack(all_embs).astype("float32")


# -------------------------------------------------------
# RETRIEVAL
# -------------------------------------------------------
def retrieve_top_k(query, k=5):
    """Embed *query* and search the FAISS index.

    Args:
        query: question string.
        k: number of passages to retrieve.

    Returns:
        (results, latency_ms) where each result is a dict with keys
        "id", "text", "score"; score = -distance so larger is better.
    """
    t0 = time.time()
    query_emb = embed_texts([query])                 # [1, D]
    distances, indices = index.search(query_emb, k)
    ret_latency = (time.time() - t0) * 1000.0        # ms

    results = []
    for idx, dist in zip(indices[0], distances[0]):
        # FAISS pads with -1 when fewer than k neighbors exist; the range
        # check also protects against index/metadata mismatch.
        if 0 <= idx < len(metadata):
            meta = metadata[idx]
            results.append(
                {
                    "id": meta.get("id", idx),
                    "text": meta.get("text", ""),
                    "score": float(-dist),  # negate distance: larger is better
                }
            )
    return results, ret_latency


# -------------------------------------------------------
# EXTRACTIVE QA ON ONE PASSAGE
# -------------------------------------------------------
@torch.no_grad()
def answer_on_context(question, passage):
    """Apply the AfroXLM-R QA model to (question, passage).

    Args:
        question: question string.
        passage: context string to extract the answer from.

    Returns:
        (answer_text, score): the best context span ("" if no valid span)
        and its start_logit + end_logit score.
    """
    enc = qa_tokenizer(
        question,
        passage,
        truncation="only_second",   # never truncate the question
        max_length=384,
        padding="max_length",
        return_offsets_mapping=True,
        return_tensors="pt",
    )
    input_ids = enc["input_ids"].to(DEVICE)
    attention_mask = enc["attention_mask"].to(DEVICE)
    offset_mapping = enc["offset_mapping"][0].tolist()
    # sequence_ids: 0 = question, 1 = context, None = specials/padding.
    sequence_ids = enc.sequence_ids(0)

    outputs = qa_model(input_ids=input_ids, attention_mask=attention_mask)
    start_logits = outputs.start_logits[0].cpu().numpy()
    end_logits = outputs.end_logits[0].cpu().numpy()

    # Restrict the span to context tokens only.
    for i, sid in enumerate(sequence_ids):
        if sid != 1:
            start_logits[i] = -1e9
            end_logits[i] = -1e9

    start_idx = int(np.argmax(start_logits))
    # Pick the best end AT OR AFTER the start. (Previously the global end
    # argmax was clamped to start_idx when it fell before the start, which
    # collapsed valid answers to a single token.)
    end_idx = start_idx + int(np.argmax(end_logits[start_idx:]))

    # Convert token span to character positions in the original passage.
    start_char = offset_mapping[start_idx][0]
    end_char = offset_mapping[end_idx][1]
    if (
        start_char is None
        or end_char is None
        or end_char <= start_char
        or start_char < 0
        or end_char > len(passage)
    ):
        answer_text = ""
    else:
        answer_text = passage[start_char:end_char]

    score = float(start_logits[start_idx] + end_logits[end_idx])
    return answer_text.strip(), score


# -------------------------------------------------------
# RAG PIPELINE: RETRIEVE -> EXTRACTIVE QA
# -------------------------------------------------------
def rag_pipeline(question, k=5):
    """Full RAG pass: retrieve top-k passages, read each, keep the best span.

    Args:
        question: question string.
        k: number of passages to retrieve.

    Returns:
        Tuple of Markdown strings: (answer, retrieval latency,
        generator latency, best-passage snippet).
    """
    # 1) Retrieval
    passages, ret_lat = retrieve_top_k(question, k)
    if not passages:
        return (
            "**Answer:** መረጃ አልተገኘም።",
            f"**Retrieval Latency:** {ret_lat:.2f} ms",
            "**Generator Latency:** 0.00 ms",
            "",
        )

    # 2) Run the reader on every passage; keep the highest-scoring answer.
    t0 = time.time()
    best_answer = ""
    best_score = -1e9
    best_passage_text = ""
    for p in passages:
        ctx = p["text"]
        if not ctx.strip():
            continue
        ans, score = answer_on_context(question, ctx)
        if ans and score > best_score:
            best_score = score
            best_answer = ans
            best_passage_text = ctx
    gen_lat = (time.time() - t0) * 1000.0  # ms

    if not best_answer:
        best_answer = "መልስ አልተገኘም።"

    snippet = best_passage_text[:500] + ("..." if len(best_passage_text) > 500 else "")
    return (
        f"**Answer (AfroXLM-R extractive):** {best_answer}",
        f"**Retrieval Latency:** {ret_lat:.2f} ms",
        f"**Generator Latency (QA):** {gen_lat:.2f} ms",
        snippet,
    )


# -------------------------------------------------------
# GRADIO APP
# -------------------------------------------------------
def gradio_rag(query, k):
    """Gradio handler: validate the query, then delegate to rag_pipeline."""
    query = (query or "").strip()
    if not query:
        return "Please type a question.", "", "", ""
    return rag_pipeline(query, int(k))


with gr.Blocks() as app:
    # FIX: the original heading was a broken multi-line string literal;
    # reconstructed as a single valid Markdown header with the same text.
    gr.Markdown("# 🇪🇹 Amharic RAG (Snowflake + AfroXLM-R Extractive QA)")
    gr.Markdown(
        "Retrieval-Augmented Question Answering: "
        "Snowflake embeddings + FAISS for retrieval, "
        "AfroXLM-R extractive model for answer spans."
    )

    with gr.Row():
        query = gr.Textbox(
            label="Ask an Amharic question",
            lines=2,
            placeholder="ምሳሌ፡ አባይ ወንዝ የት ነው የሚመነጨው?",
        )
        k = gr.Slider(1, 10, value=5, step=1, label="Top-K passages")

    btn = gr.Button("Run RAG")
    out_answer = gr.Markdown(label="Answer")
    out_retlat = gr.Markdown(label="Retrieval latency")
    out_genlat = gr.Markdown(label="Generator latency")
    out_passage = gr.Textbox(label="Retrieved passage snippet", lines=10)

    btn.click(
        gradio_rag,
        inputs=[query, k],
        outputs=[out_answer, out_retlat, out_genlat, out_passage],
    )

# Guard the launch so importing this module (e.g. for testing) does not
# start the server.
if __name__ == "__main__":
    app.launch()