aminhalvaei committed on
Commit
7b05361
·
verified ·
1 Parent(s): b56e02e

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +263 -0
app.py ADDED
@@ -0,0 +1,263 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import pickle
4
+ import threading
5
+
6
+ import numpy as np
7
+ import faiss
8
+ from sentence_transformers import SentenceTransformer
9
+ from rank_bm25 import BM25Okapi
10
+
11
+ import torch
12
+ from transformers import (
13
+ AutoTokenizer,
14
+ AutoModelForCausalLM,
15
+ BitsAndBytesConfig,
16
+ TextIteratorStreamer,
17
+ )
18
+
19
+ import gradio as gr
20
+
21
+
22
# ----------------------------
# Config (match your notebook)
# ----------------------------
EMBED_MODEL_NAME = "intfloat/multilingual-e5-large"  # embedding model used to build the FAISS index
LLM_MODEL_NAME = "Qwen/Qwen2.5-7B-Instruct"  # generator LLM, loaded 4-bit at startup below

# Retrieval artifacts produced offline in the notebook and uploaded to the Space repo.
CHUNKS_PATH = "sharif_rules_chunked.json"
FAISS_PATH = "vector_index.faiss"  # pickle-dumped faiss index (loaded in load_artifacts)
BM25_PATH = "bm25_index.pkl"  # pickle-dumped rank_bm25 BM25Okapi object

# UI defaults; the k slider in the Blocks UI allows up to 6.
DEFAULT_K = 3
DEFAULT_MAX_CTX_CHARS = 1200
35
+
36
+
37
+ # ----------------------------
38
+ # Load artifacts
39
+ # ----------------------------
40
def load_artifacts(chunks_path=None, faiss_path=None, bm25_path=None):
    """Load the retrieval artifacts (chunk list, FAISS index, BM25 index) from disk.

    Args:
        chunks_path: JSON file with the chunked regulation texts; defaults to
            the module-level CHUNKS_PATH so existing callers are unaffected.
        faiss_path: pickle-dumped FAISS index; defaults to FAISS_PATH.
        bm25_path: pickle-dumped BM25Okapi object; defaults to BM25_PATH.

    Returns:
        (chunks, vector_index, bm25) tuple.

    Raises:
        FileNotFoundError: if any artifact file is missing, with a hint that
            the files must be uploaded to the Space repo.
    """
    if chunks_path is None:
        chunks_path = CHUNKS_PATH
    if not os.path.exists(chunks_path):
        raise FileNotFoundError(
            f"Missing {chunks_path}. Upload it to the Space repo (recommended), "
            "or add code to build it at startup."
        )

    if faiss_path is None:
        faiss_path = FAISS_PATH
    if bm25_path is None:
        bm25_path = BM25_PATH
    if not os.path.exists(faiss_path) or not os.path.exists(bm25_path):
        raise FileNotFoundError(
            f"Missing {faiss_path} and/or {bm25_path}. Upload them to the Space repo."
        )

    with open(chunks_path, "r", encoding="utf-8") as f:
        chunks = json.load(f)

    # NOTE(security): pickle.load executes arbitrary code from the file; these
    # artifacts must only ever come from the trusted notebook, never from
    # user-supplied uploads.
    with open(faiss_path, "rb") as f:
        vector_index = pickle.load(f)

    with open(bm25_path, "rb") as f:
        bm25 = pickle.load(f)

    return chunks, vector_index, bm25
61
+
62
+
63
# Startup: load everything once at import time so the Space is warm before
# Gradio starts serving requests.
print("Loading embedding model...")
embed_model = SentenceTransformer(EMBED_MODEL_NAME)

print("Loading retrieval artifacts...")
chunks, vector_index, bm25 = load_artifacts()

print("Loading LLM + tokenizer...")
# 4-bit NF4 quantization with fp16 compute so the 7B model fits on a small GPU.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_quant_type="nf4",
)

tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL_NAME, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    LLM_MODEL_NAME,
    quantization_config=bnb_config,
    device_map="auto",  # let accelerate place layers on the available device(s)
    trust_remote_code=True,
)
model.eval()  # inference only (disables dropout etc.)
print("All models loaded.")
85
+
86
+
87
+ # ----------------------------
88
+ # Retrieval (match notebook)
89
+ # ----------------------------
90
def hybrid_search(query: str, k: int = 5):
    """Hybrid retrieval: dense (FAISS) + lexical (BM25), fused with RRF.

    Args:
        query: the user's question (raw text; BM25 tokenization is a plain
            whitespace split, matching how the index was built).
        k: number of candidates taken from each retriever and returned overall.

    Returns:
        Up to k chunk dicts from the module-level `chunks` list, ranked by
        Reciprocal Rank Fusion score.

    Fix over the original: FAISS pads its result with -1 when the index holds
    fewer than k vectors; those sentinel ids previously flowed into the fusion
    and would index `chunks[-1]` (the last chunk) by accident. They are now
    skipped. Indices are also normalized to plain int.
    """
    # 1) Dense vector search (embeddings normalized at encode time).
    query_embedding = embed_model.encode([query], normalize_embeddings=True)
    v_scores, v_indices = vector_index.search(query_embedding, k)

    # 2) BM25 lexical search.
    tokenized_query = query.split()
    bm25_scores = bm25.get_scores(tokenized_query)
    bm25_indices = np.argsort(bm25_scores)[::-1][:k]

    # 3) Reciprocal Rank Fusion; 60 is the conventional RRF smoothing constant.
    fusion_scores = {}

    for rank, idx in enumerate(v_indices[0]):
        if idx < 0:
            continue  # FAISS uses -1 as "no result" padding
        idx = int(idx)
        fusion_scores[idx] = fusion_scores.get(idx, 0) + 1 / (rank + 60)

    for rank, idx in enumerate(bm25_indices):
        idx = int(idx)
        fusion_scores[idx] = fusion_scores.get(idx, 0) + 1 / (rank + 60)

    sorted_indices = sorted(fusion_scores, key=fusion_scores.get, reverse=True)[:k]
    return [chunks[i] for i in sorted_indices]
114
+
115
+
116
+ # ----------------------------
117
+ # Prompt + generation
118
+ # ----------------------------
119
# Persian system prompt. It instructs the model to: (1) answer ONLY from the
# retrieved [Context] block, (2) reply with a fixed fallback sentence when the
# answer is not present, (3) answer entirely in Persian, and (4) cite the
# regulation title and article/clause number.
# (Runtime string — deliberately left untranslated; the fallback sentence must
# stay byte-identical to the one checked for in the UI copy.)
SYSTEM_PROMPT_FA = """شما یک دستیار هوشمند آموزشی برای دانشگاه صنعتی شریف هستید.
وظیفه شما پاسخ‌دهی دقیق به سوالات دانشجو بر اساس "متن قوانین" زیر است.

قوانین مهم:
1. فقط و فقط از اطلاعات موجود در بخش [Context] استفاده کنید. از دانش قبلی خود استفاده نکنید.
2. اگر پاسخ سوال در متن موجود نیست، دقیقاً بگویید: "اطلاعاتی در این مورد در آیین‌نامه‌های موجود یافت نشد."
3. پاسخ نهایی باید کاملاً به زبان فارسی باشد.
4. نام آیین‌نامه و شماره ماده یا تبصره را در پاسخ ذکر کنید.
"""
128
+
129
+
130
def build_context_text(retrieved_chunks, max_ctx_chars: int):
    """Render retrieved chunks as the numbered [Context] string for the prompt.

    Each chunk contributes one "Document N (Source: ..., Article: ...)" section
    whose stripped text is truncated to max_ctx_chars characters. Missing
    metadata falls back to "Unknown" / "N/A". Returns "" for an empty list.
    """
    limit = int(max_ctx_chars)
    sections = []
    for pos, doc in enumerate(retrieved_chunks, start=1):
        # chunk["metadata"] may be absent or None; treat both as empty
        meta = doc.get("metadata", {}) or {}
        body = (doc.get("text", "") or "").strip()[:limit]
        sections.append(
            f"Document {pos} (Source: {meta.get('title', 'Unknown')}, "
            f"Article: {meta.get('article', 'N/A')}):\n{body}\n\n"
        )
    return "".join(sections)
141
+
142
+
143
def generate_answer_stream(query: str, retrieved_chunks, max_ctx_chars: int = 1200):
    """Stream the LLM answer token-by-token via TextIteratorStreamer.

    Builds the RAG prompt (system rules + question + [Context]), launches
    `model.generate` on a background thread, and yields the growing partial
    answer string as tokens arrive.

    Args:
        query: the user's question (Persian).
        retrieved_chunks: chunk dicts from hybrid_search.
        max_ctx_chars: per-chunk character cap for the context block.

    Fix over the original: `temperature`/`top_p` were passed without
    `do_sample=True`, so transformers silently ignored them and decoded
    greedily (with a warning). `do_sample=True` makes the configured
    low-temperature nucleus sampling actually take effect.
    """
    context_text = build_context_text(retrieved_chunks, max_ctx_chars=max_ctx_chars)

    user_prompt = f"""سوال: {query}

[Context]:
{context_text}

پاسخ:"""

    messages = [
        {"role": "system", "content": SYSTEM_PROMPT_FA},
        {"role": "user", "content": user_prompt},
    ]

    # Let the tokenizer apply the model's own chat template.
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )

    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

    streamer = TextIteratorStreamer(
        tokenizer,
        skip_special_tokens=True,
        # keep prompt out of the stream (we only want the assistant answer)
        skip_prompt=True,
    )

    gen_kwargs = dict(
        **model_inputs,
        max_new_tokens=512,
        do_sample=True,  # required for temperature/top_p to be honored
        temperature=0.1,
        top_p=0.9,
        streamer=streamer,
    )

    # generate() blocks, so it runs on a thread while we drain the streamer.
    thread = threading.Thread(target=model.generate, kwargs=gen_kwargs)
    thread.start()

    partial = ""
    for token_text in streamer:
        partial += token_text
        yield partial

    thread.join()
194
+
195
+
196
+ # ----------------------------
197
+ # UI helpers (match your demo)
198
+ # ----------------------------
199
def format_sources(retrieved_docs, max_chars=300):
    """Build the debug "sources" panel text.

    One numbered entry per document: title, source/article line, and a
    single-line snippet truncated to max_chars (with an ellipsis when cut).
    Entries are separated by blank lines; an empty input yields "".
    """
    entries = []
    for num, doc in enumerate(retrieved_docs, 1):
        meta = doc.get("metadata", {}) or {}
        flat = (doc.get("text", "") or "").strip().replace("\n", " ")
        snippet = flat[:max_chars]
        if len(flat) > max_chars:
            snippet += "…"
        entries.append(
            f"{num}. {meta.get('title', '')}\n"
            f" source: {meta.get('source', '')} | ماده: {meta.get('article', '-')}\n"
            f" snippet: {snippet}"
        )
    return "\n\n".join(entries)
210
+
211
+
212
def rag_answer_ui_stream(question, k, max_ctx_chars):
    """Gradio streaming callback: yields (partial_answer, sources_text) pairs.

    Guard clauses yield a single Persian notice for an empty/whitespace
    question or an empty retrieval result; otherwise the sources panel stays
    fixed while the answer streams in.
    """
    # Empty or whitespace-only question -> prompt the user once and stop.
    if not (question or "").strip():
        yield "لطفاً سوال را وارد کنید.", ""
        return

    # 1) Retrieve supporting chunks.
    docs = hybrid_search(question, k=int(k))
    if not docs:
        yield "اطلاعاتی در این مورد در آیین‌نامه‌های موجود یافت نشد.", ""
        return

    # 2) Sources panel is computed once and re-yielded unchanged.
    sources_panel = format_sources(docs)

    # 3) Stream the growing answer alongside the static sources.
    for partial in generate_answer_stream(
        question,
        docs,
        max_ctx_chars=int(max_ctx_chars),
    ):
        yield partial, sources_panel
234
+
235
# Gradio UI: question box + (k, max-chars) sliders; the answer streams into
# answer_out while sources_out shows the retrieved chunks for debugging.
with gr.Blocks(title="Sharif RAG Demo (Streaming)") as demo:
    gr.Markdown(
        "## 🎓 Sharif Regulations RAG Bot (Streaming)\n"
        "سوال خود را وارد کنید. پاسخ فقط بر اساس متن‌های بازیابی‌شده تولید می‌شود."
    )

    with gr.Row():
        question = gr.Textbox(
            label="❓ Question (Persian)",
            placeholder="مثلاً: شرایط مهمانی در دوره روزانه؟",
            lines=2,
        )

    with gr.Row():
        # Slider bounds mirror the notebook demo (k up to 6).
        k = gr.Slider(1, 6, value=DEFAULT_K, step=1, label="🔎 Number of retrieved chunks (k)")
        max_ctx_chars = gr.Slider(300, 2500, value=DEFAULT_MAX_CTX_CHARS, step=100, label="✂️ Max chars per chunk (for generation)")

    run_btn = gr.Button("Run RAG (stream)")
    answer_out = gr.Textbox(label="🤖 Answer (streaming)", lines=10)
    sources_out = gr.Textbox(label="📚 Retrieved sources (debug)", lines=12)

    # rag_answer_ui_stream is a generator, so Gradio streams each yield
    # into (answer_out, sources_out).
    run_btn.click(
        fn=rag_answer_ui_stream,
        inputs=[question, k, max_ctx_chars],
        outputs=[answer_out, sources_out],
    )

# Spaces will call app.py; server_name makes it work in containers too
demo.queue().launch(server_name="0.0.0.0", server_port=7860)