import asyncio
import os
import time
from pathlib import Path
from typing import Any, Dict, List, Tuple

import numpy as np
import requests
import torch
from datasets import load_dataset
from dotenv import load_dotenv
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
|
| load_dotenv() |
|
|
| |
| |
| |
| REPO_ID = os.getenv("HF_DATASET_REPO", "AshJem/NOLEJ1") |
|
|
| FAISS_URL = os.getenv( |
| "FAISS_URL", |
| "https://huggingface.co/datasets/AshJem/NOLEJ1/resolve/main/global.faiss" |
| ) |
|
|
| TOP_K = int(os.getenv("TOP_K", "8")) |
| RERANK_TOP_N = int(os.getenv("RERANK_TOP_N", "4")) |
| MIN_CONTEXT_CHARS = int(os.getenv("MIN_CONTEXT_CHARS", "250")) |
|
|
| RERANK = os.getenv("RERANK", "1") == "1" |
| USE_OPENAI = os.getenv("USE_OPENAI", "1") == "1" |
|
|
| |
| |
| |
| from openai import AsyncOpenAI |
| _openai_client = AsyncOpenAI(api_key=os.environ.get("OPENAI_API_KEY")) |
|
|
| |
| |
| |
| WRITE_DIR = Path(os.getenv("WRITE_DIR", "./tmp")) |
| WRITE_DIR.mkdir(parents=True, exist_ok=True) |
| FAISS_LOCAL = WRITE_DIR / "global.faiss" |
|
|
|
|
| def _download_faiss_if_needed(): |
| if FAISS_LOCAL.exists(): |
| return |
| with requests.get(FAISS_URL, stream=True, timeout=120) as r: |
| r.raise_for_status() |
| with open(FAISS_LOCAL, "wb") as f: |
| for chunk in r.iter_content(chunk_size=8192): |
| f.write(chunk) |
|
|
|
|
| _download_faiss_if_needed() |
|
|
| dataset = load_dataset(REPO_ID, split="train", keep_in_memory=True) |
| dataset.load_faiss_index("embedding", str(FAISS_LOCAL)) |
|
|
| |
| embedder = SentenceTransformer("BAAI/bge-m3") |
|
|
| if RERANK: |
| |
| reranker_model_id = "BAAI/bge-reranker-v2-m3" |
| reranker_tokenizer = AutoTokenizer.from_pretrained(reranker_model_id) |
| reranker_model = AutoModelForSequenceClassification.from_pretrained(reranker_model_id) |
| reranker_model.eval() |
| else: |
| reranker_tokenizer = None |
| reranker_model = None |
|
|
|
|
| def embed_query(text: str) -> np.ndarray: |
| return embedder.encode(text, normalize_embeddings=True) |
|
|
|
|
| def retrieve_topk(query: str, k: int = TOP_K) -> List[Dict[str, Any]]: |
| qv = embed_query(query) |
| scores, ret = dataset.get_nearest_examples("embedding", qv, k=k) |
|
|
| items = [] |
| for i in range(len(ret["text"])): |
| s = float(scores[i]) if scores is not None else None |
| items.append({ |
| "text": ret["text"][i], |
| "metadata": ret["metadata"][i], |
| "score": s, |
| "final_score": s, |
| }) |
| return items |
|
|
|
|
| def rerank(query: str, items: List[Dict[str, Any]], top_n: int = RERANK_TOP_N) -> List[Dict[str, Any]]: |
| if (not RERANK) or (not items): |
| return items[:top_n] |
|
|
| inputs = [f"{query} [SEP] {it['text']}" for it in items] |
| enc = reranker_tokenizer(inputs, padding=True, truncation=True, return_tensors="pt") |
|
|
| with torch.no_grad(): |
| logits = reranker_model(**enc).logits.squeeze() |
|
|
| rr_scores = [float(logits)] if logits.ndim == 0 else [float(x) for x in logits.tolist()] |
|
|
| for it, s in zip(items, rr_scores): |
| it["rerank_score"] = s |
| it["final_score"] = s |
|
|
| items = sorted(items, key=lambda x: x.get("rerank_score", -1e9), reverse=True) |
| return items[:top_n] |
|
|
|
|
| def build_context(items: List[Dict[str, Any]]) -> Tuple[str, List[Dict[str, Any]]]: |
| """ |
| New index metadata schema (minimal): |
| source_file, source_path, doc_id, file_hash, chunk_index, chunk_id |
| """ |
| citations = [] |
| parts = [] |
|
|
| for it in items: |
| meta = it.get("metadata", {}) or {} |
| parts.append(it["text"]) |
|
|
| citations.append({ |
| "source_file": meta.get("source_file"), |
| "source_path": meta.get("source_path"), |
| "doc_id": meta.get("doc_id"), |
| "file_hash": meta.get("file_hash"), |
| "chunk_index": meta.get("chunk_index"), |
| "chunk_id": meta.get("chunk_id"), |
| |
| "faiss_score": it.get("score"), |
| "rerank_score": it.get("rerank_score"), |
| "final_score": it.get("final_score"), |
| }) |
|
|
| return "\n\n---\n\n".join(parts), citations |
|
|
|
|
| def should_answer(context: str) -> bool: |
| return isinstance(context, str) and len(context.strip()) >= MIN_CONTEXT_CHARS |
|
|
|
|
| async def generate_answer_openai(question: str, context: str) -> str: |
| system = ( |
| "You are a QA assistant.\n" |
| "Use the CONVERSATION CONTEXT only to understand the user's question.\n" |
| "Answer ONLY using the provided DOCUMENT CONTEXT.\n" |
| "If the answer is not in the document context, reply exactly:\n" |
| "\"I don't know based on the provided documents\".\n" |
| "Be concise and factual." |
| ) |
|
|
| user = f"QUESTION:\n{question}\n\nCONTEXT:\n{context}" |
|
|
| resp = await _openai_client.chat.completions.create( |
| model=os.getenv("OPENAI_MODEL", "gpt-4.1-mini"), |
| messages=[ |
| {"role": "system", "content": system}, |
| {"role": "user", "content": user}, |
| ], |
| ) |
| return resp.choices[0].message.content.strip() |
|
|
|
|
| async def generate_answer_together(question: str, context: str) -> str: |
| from llama_index.llms.together import TogetherLLM |
|
|
| llm = TogetherLLM( |
| model=os.getenv("TOGETHER_MODEL", "openai/gpt-oss-20b"), |
| api_key=os.environ.get("TOGETHER_API_KEY"), |
| streaming=False, |
| ) |
|
|
| prompt = ( |
| "Answer ONLY using the CONTEXT below.\n" |
| "If you cannot find the answer in the context, reply exactly:\n" |
| "\"I don't know based on the provided documents\".\n\n" |
| f"QUESTION:\n{question}\n\nCONTEXT:\n{context}\n" |
| ) |
| out = llm.complete(prompt) |
| return (out.text or "").strip() |
|
|
|
|
| async def answer_question(question: str) -> Dict[str, Any]: |
| t0 = time.perf_counter() |
| retrieved = retrieve_topk(question, k=TOP_K) |
| t1 = time.perf_counter() |
|
|
| reranked = rerank(question, retrieved, top_n=RERANK_TOP_N) |
| t2 = time.perf_counter() |
|
|
| context, citations = build_context(reranked) |
|
|
| if not should_answer(context): |
| t3 = time.perf_counter() |
| print( |
| f"[timing] retrieve={t1 - t0:.2f}s " |
| f"rerank={t2 - t1:.2f}s " |
| f"llm=0.00s total={t3 - t0:.2f}s (refusal)" |
| ) |
| return { |
| "answer": "I don't know based on the provided documents", |
| "citations": citations, |
| "used_chunks": reranked, |
| } |
|
|
| answer = await (generate_answer_openai(question, context) if USE_OPENAI else generate_answer_together(question, context)) |
| t3 = time.perf_counter() |
|
|
| print( |
| f"[timing] retrieve={t1 - t0:.2f}s " |
| f"rerank={t2 - t1:.2f}s " |
| f"llm={t3 - t2:.2f}s total={t3 - t0:.2f}s" |
| ) |
|
|
| return { |
| "answer": answer, |
| "citations": citations, |
| "used_chunks": reranked, |
| } |