import os
import json
import pickle
import threading

import numpy as np
import faiss  # noqa: F401  (needed so the pickled FAISS index can be unpickled)
from sentence_transformers import SentenceTransformer
from rank_bm25 import BM25Okapi  # noqa: F401  (needed so the pickled BM25 index can be unpickled)
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    TextIteratorStreamer,
)
import gradio as gr

# ----------------------------
# Config (match your notebook)
# ----------------------------
EMBED_MODEL_NAME = "intfloat/multilingual-e5-large"  # same embedding model as the notebook
LLM_MODEL_NAME = "Qwen/Qwen2.5-7B-Instruct"  # same LLM as the notebook

CHUNKS_PATH = "sharif_rules_chunked.json"
FAISS_PATH = "vector_index.faiss"  # FAISS index pickle-dumped by the notebook
BM25_PATH = "bm25_index.pkl"  # BM25 object pickle-dumped by the notebook

# The notebook UI allowed k up to 6.
DEFAULT_K = 3
DEFAULT_MAX_CTX_CHARS = 1200

# Reciprocal Rank Fusion smoothing constant (same value as the notebook).
RRF_K = 60
# Upper bound on generated answer length.
MAX_NEW_TOKENS = 512


# ----------------------------
# Load artifacts
# ----------------------------
def load_artifacts():
    """Load the chunk list, the FAISS index and the BM25 index from disk.

    Returns:
        A tuple ``(chunks, vector_index, bm25)``: the JSON-decoded chunk list
        and the two unpickled index objects.

    Raises:
        FileNotFoundError: if any of the three artifact files is missing.
    """
    if not os.path.exists(CHUNKS_PATH):
        raise FileNotFoundError(
            f"Missing {CHUNKS_PATH}. Upload it to the Space repo (recommended), "
            "or add code to build it at startup."
        )
    missing = [p for p in (FAISS_PATH, BM25_PATH) if not os.path.exists(p)]
    if missing:
        # Name exactly which index file(s) are absent.
        raise FileNotFoundError(
            f"Missing {', '.join(missing)}. Upload them to the Space repo."
        )

    with open(CHUNKS_PATH, "r", encoding="utf-8") as f:
        chunks = json.load(f)

    # SECURITY: pickle.load can execute arbitrary code. Only ship index files
    # you built yourself; never load these paths from user uploads.
    with open(FAISS_PATH, "rb") as f:
        vector_index = pickle.load(f)
    with open(BM25_PATH, "rb") as f:
        bm25 = pickle.load(f)

    return chunks, vector_index, bm25


print("Loading embedding model...")
embed_model = SentenceTransformer(EMBED_MODEL_NAME)

print("Loading retrieval artifacts...")
chunks, vector_index, bm25 = load_artifacts()

print("Loading LLM + tokenizer...")
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_quant_type="nf4",
)

tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL_NAME, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    LLM_MODEL_NAME,
    quantization_config=bnb_config,
    device_map="auto",
    trust_remote_code=True,
)
model.eval()
print("All models loaded.")


# ----------------------------
# Retrieval (match notebook)
# ----------------------------
def hybrid_search(query: str, k: int = 5):
    """Hybrid search (vector + BM25) fused with Reciprocal Rank Fusion.

    Args:
        query: Raw user question.
        k: Number of candidates taken from each retriever and number of
            fused results returned.

    Returns:
        Up to ``k`` chunk dicts from ``chunks``, best fused score first.
        Ties keep vector-search order first, then BM25 order.
    """
    if k <= 0 or not chunks:
        return []

    # 1) Vector search. Embeddings are L2-normalised, as in the notebook.
    # NOTE(review): E5 models are usually used with "query: "/"passage: "
    # prefixes. Confirm how the notebook encoded passages before adding one.
    query_embedding = np.asarray(
        embed_model.encode([query], normalize_embeddings=True), dtype="float32"
    )
    _, v_indices = vector_index.search(query_embedding, k)
    # FAISS pads with -1 when fewer than k vectors exist. Without this filter,
    # chunks[-1] would silently return the last chunk as a fake hit.
    vector_hits = [int(i) for i in v_indices[0] if i >= 0]

    # 2) BM25 search on whitespace tokens (same tokenisation as the notebook).
    bm25_scores = bm25.get_scores(query.split())
    bm25_hits = [int(i) for i in np.argsort(bm25_scores)[::-1][:k]]

    # 3) RRF fusion: each ranking contributes 1 / (rank + RRF_K).
    fusion_scores = {}
    for ranking in (vector_hits, bm25_hits):
        for rank, idx in enumerate(ranking):
            fusion_scores[idx] = fusion_scores.get(idx, 0.0) + 1.0 / (rank + RRF_K)

    top_indices = sorted(fusion_scores, key=fusion_scores.get, reverse=True)[:k]
    return [chunks[i] for i in top_indices if 0 <= i < len(chunks)]


# ----------------------------
# Prompt + generation
# ----------------------------
SYSTEM_PROMPT_FA = """شما یک دستیار هوشمند آموزشی برای دانشگاه صنعتی شریف هستید.
وظیفه شما پاسخ‌دهی دقیق به سوالات دانشجو بر اساس "متن قوانین" زیر است.

قوانین مهم:
1. فقط و فقط از اطلاعات موجود در بخش [Context] استفاده کنید. از دانش قبلی خود استفاده نکنید.
2. اگر پاسخ سوال در متن موجود نیست، دقیقاً بگویید: "اطلاعاتی در این مورد در آیین‌نامه‌های موجود یافت نشد."
3. پاسخ نهایی باید کاملاً به زبان فارسی باشد.
4. نام آیین‌نامه و شماره ماده یا تبصره را در پاسخ ذکر کنید.
"""


def build_context_text(retrieved_chunks, max_ctx_chars: int):
    """Render retrieved chunks as numbered ``Document i`` blocks for the prompt.

    Each chunk's text is stripped and truncated to ``max_ctx_chars``
    characters. Title and article are read from ``chunk["metadata"]``.
    """
    limit = int(max_ctx_chars)
    parts = []
    for i, chunk in enumerate(retrieved_chunks, 1):
        md = chunk.get("metadata", {}) or {}
        source = md.get("title", "Unknown")
        article = md.get("article", "N/A")
        txt = (chunk.get("text", "") or "").strip()[:limit]
        parts.append(f"Document {i} (Source: {source}, Article: {article}):\n{txt}\n\n")
    return "".join(parts)


def generate_answer_stream(query: str, retrieved_chunks, max_ctx_chars: int = DEFAULT_MAX_CTX_CHARS):
    """Stream an answer token-by-token using ``TextIteratorStreamer``.

    Yields:
        The growing answer string after each decoded piece.

    Raises:
        RuntimeError: if ``model.generate`` fails in the worker thread. The
            original exception is chained as the cause.
    """
    context_text = build_context_text(retrieved_chunks, max_ctx_chars=max_ctx_chars)
    user_prompt = f"""سوال: {query}

[Context]:
{context_text}

پاسخ:"""

    messages = [
        {"role": "system", "content": SYSTEM_PROMPT_FA},
        {"role": "user", "content": user_prompt},
    ]
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

    streamer = TextIteratorStreamer(
        tokenizer,
        skip_special_tokens=True,
        # Keep the prompt out of the stream; only the assistant answer is wanted.
        skip_prompt=True,
    )
    gen_kwargs = dict(
        **model_inputs,
        max_new_tokens=MAX_NEW_TOKENS,
        # temperature/top_p only take effect when sampling is enabled.
        do_sample=True,
        temperature=0.1,
        top_p=0.9,
        streamer=streamer,
    )

    errors = []

    def _generate():
        """Run generation in the worker thread. On failure, end the stream so the consumer loop exits."""
        try:
            model.generate(**gen_kwargs)
        except Exception as exc:  # surfaced to the caller after join()
            errors.append(exc)
            streamer.end()

    thread = threading.Thread(target=_generate)
    thread.start()

    partial = ""
    for token_text in streamer:
        partial += token_text
        yield partial

    thread.join()
    if errors:
        raise RuntimeError("Answer generation failed.") from errors[0]


# ----------------------------
# UI helpers (match your demo)
# ----------------------------
def format_sources(retrieved_docs, max_chars=300):
    """Format retrieved chunks as a numbered, human-readable debug listing.

    Each entry shows title, source, article and a single-line snippet of at
    most ``max_chars`` characters. An ellipsis is appended when truncated.
    """
    lines = []
    for i, d in enumerate(retrieved_docs, 1):
        md = d.get("metadata", {}) or {}
        title = md.get("title", "")
        src = md.get("source", "")
        art = md.get("article", "-")
        snippet = (d.get("text", "") or "").strip().replace("\n", " ")
        snippet = snippet[:max_chars] + ("…" if len(snippet) > max_chars else "")
        lines.append(f"{i}. {title}\n source: {src} | ماده: {art}\n snippet: {snippet}")
    return "\n\n".join(lines)


def rag_answer_ui_stream(question, k, max_ctx_chars):
    """Gradio handler that yields ``(answer_so_far, sources_text)`` tuples."""
    if not question or not question.strip():
        yield "لطفاً سوال را وارد کنید.", ""
        return

    # 1) Retrieve
    retrieved = hybrid_search(question, k=int(k))
    if not retrieved:
        yield "اطلاعاتی در این مورد در آیین‌نامه‌های موجود یافت نشد.", ""
        return

    # 2) Sources are static; show them right away, before the first token.
    sources_text = format_sources(retrieved)
    yield "", sources_text

    # 3) Stream answer
    for partial_answer in generate_answer_stream(
        question,
        retrieved,
        max_ctx_chars=int(max_ctx_chars),
    ):
        yield partial_answer, sources_text


with gr.Blocks(title="Sharif RAG Demo (Streaming)") as demo:
    gr.Markdown(
        "## 🎓 Sharif Regulations RAG Bot (Streaming)\n"
        "سوال خود را وارد کنید. پاسخ فقط بر اساس متن‌های بازیابی‌شده تولید می‌شود."
    )

    with gr.Row():
        question = gr.Textbox(
            label="❓ Question (Persian)",
            placeholder="مثلاً: شرایط مهمانی در دوره روزانه؟",
            lines=2,
        )

    with gr.Row():
        k = gr.Slider(1, 6, value=DEFAULT_K, step=1, label="🔎 Number of retrieved chunks (k)")
        max_ctx_chars = gr.Slider(
            300,
            2500,
            value=DEFAULT_MAX_CTX_CHARS,
            step=100,
            label="✂️ Max chars per chunk (for generation)",
        )

    run_btn = gr.Button("Run RAG (stream)")

    answer_out = gr.Textbox(label="🤖 Answer (streaming)", lines=10)
    sources_out = gr.Textbox(label="📚 Retrieved sources (debug)", lines=12)

    run_btn.click(
        fn=rag_answer_ui_stream,
        inputs=[question, k, max_ctx_chars],
        outputs=[answer_out, sources_out],
    )

# Spaces runs app.py directly; server_name="0.0.0.0" makes it reachable inside containers too.
demo.queue().launch(server_name="0.0.0.0", server_port=7860)