sadegh803211 committed
Commit 11cbf4f · verified · 1 Parent(s): 1d03178

Update app.py

Files changed (1):
  1. app.py +172 -170

app.py CHANGED
@@ -11,152 +11,158 @@ import re
 import hashlib
 from typing import List, Tuple, Optional
 
-
-
 import numpy as np
 import pandas as pd
 import torch
 import faiss
 from PIL import Image, ImageOps
 
-
-
 # Hugging Face Transformers & Sentence-Transformers
-
-from transformers import (CLIPVisionModel,CLIPImageProcessor,AutoTokenizer,AutoModel)
-
+from transformers import (CLIPVisionModel, CLIPImageProcessor, AutoTokenizer, AutoModel)
 from sentence_transformers import SentenceTransformer
+
 # Google Generative AI
 import google.generativeai as genai
 from google.generativeai.types import GenerationConfig
+
 # Gradio for Web UI
 import gradio as gr
 
 # --- CONFIGURATION CLASS ---
 class Config:
-    per_option_ctx: int = 5
-    max_text_len: int = 512
-    docstore_path: str = "indexes/docstore.parquet"
-    glot_model_hf: str = "Arshiaizd/Glot500-FineTuned"
-    mclip_text_model_hf: str = "Arshiaizd/MCLIP_FA_FineTuned"
-    clip_vision_model: str = "SajjadAyoubi/clip-fa-vision"
-    glot_index_out: str = "indexes/I_glot_text_fa.index"
-    clip_index_out: str = "indexes/I_clip_text_fa.index"
+    per_option_ctx: int = 5
+    max_text_len: int = 512
+    docstore_path: str = "indexes/docstore.parquet"
+    glot_model_hf: str = "Arshiaizd/Glot500-FineTuned"
+    mclip_text_model_hf: str = "Arshiaizd/MCLIP_FA_FineTuned"
+    clip_vision_model: str = "SajjadAyoubi/clip-fa-vision"
+    glot_index_out: str = "indexes/I_glot_text_fa.index"
+    clip_index_out: str = "indexes/I_clip_text_fa.index"
+
 # --- UTILITY CLASS ---
 class Utils:
-    @staticmethod
-    def build_context_block(hits: List[Tuple[int, float]], docstore: pd.DataFrame, count: int, max_chars=350) -> str:
-        if not hits or docstore.empty:
-            return "No relevant documents found."
-        lines = []
-        # Ensure we don't try to access indices that are out of bounds
-        valid_hits = [h for h in hits if h[0] < len(docstore)]
-        for i, score in valid_hits[:count]:
-            row = docstore.iloc[i]
-            # Ensure 'passage_text' and 'id' columns exist
-            txt = str(row.get("passage_text", "Text not available"))
-            doc_id = row.get("id", "N/A")
-            txt = (txt[:max_chars] + "…") if len(txt) > max_chars else txt
-            lines.append(f"- [doc:{doc_id}] {txt}")
-        return "\n".join(lines)
+    @staticmethod
+    def build_context_block(hits: List[Tuple[int, float]], docstore: pd.DataFrame, count: int, max_chars=350) -> str:
+        if not hits or docstore.empty:
+            return "No relevant documents found."
+        lines = []
+        # Ensure we don't try to access indices that are out of bounds
+        valid_hits = [h for h in hits if h[0] < len(docstore)]
+        for i, score in valid_hits[:count]:
+            row = docstore.iloc[i]
+            # Ensure 'passage_text' and 'id' columns exist
+            txt = str(row.get("passage_text", "Text not available"))
+            doc_id = row.get("id", "N/A")
+            txt = (txt[:max_chars] + "…") if len(txt) > max_chars else txt
+            lines.append(f"- [doc:{doc_id}] {txt}")
+        return "\n".join(lines)
 
 # --- ENCODER CLASSES ---
 class Glot500Encoder:
-    def __init__(self, model_id: str):
-        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        self.st_model = SentenceTransformer(model_id, device=str(self.device))
-        print(f"Glot-500 model '{model_id}' loaded successfully.")
-    @torch.no_grad()
-    def encode(self, texts: List[str], batch_size: int = 64) -> np.ndarray:
-        return self.st_model.encode(texts, batch_size=batch_size, show_progress_bar=False, convert_to_numpy=True, normalize_embeddings=True).astype(np.float32)
+    def __init__(self, model_id: str):
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        self.st_model = SentenceTransformer(model_id, device=str(self.device))
+        print(f"Glot-500 model '{model_id}' loaded successfully.")
+
+    @torch.no_grad()
+    def encode(self, texts: List[str], batch_size: int = 64) -> np.ndarray:
+        return self.st_model.encode(
+            texts, batch_size=batch_size, show_progress_bar=False,
+            convert_to_numpy=True, normalize_embeddings=True
+        ).astype(np.float32)
 
 class FaTextEncoder:
-    def __init__(self, model_id: str, device: torch.device, max_len: int):
-        self.device, self.max_len = device, max_len
-        self.tok = AutoTokenizer.from_pretrained(model_id)
-        self.model = AutoModel.from_pretrained(model_id).to(device).eval()
-        print(f"FaCLIP text model '{model_id}' loaded successfully.")
-
-    @torch.no_grad()
-    def encode_numpy(self, texts: List[str], batch_size: int = 128) -> np.ndarray:
-        vecs = []
-        for i in range(0, len(texts), batch_size):
-            toks = self.tok(texts[i:i+batch_size], padding=True, truncation=True, max_length=self.max_len, return_tensors="pt").to(self.device)
-            out = self.model(**toks)
-            x = out.pooler_output if hasattr(out, "pooler_output") and out.pooler_output is not None else (out.last_hidden_state * toks.attention_mask.unsqueeze(-1)).sum(1) / toks.attention_mask.sum(1).clamp(min=1)
-            x_norm = x / x.norm(p=2, dim=1, keepdim=True)
-            vecs.append(x_norm.detach().cpu().numpy())
-        return np.vstack(vecs).astype(np.float32)
+    def __init__(self, model_id: str, device: torch.device, max_len: int):
+        self.device, self.max_len = device, max_len
+        self.tok = AutoTokenizer.from_pretrained(model_id)
+        self.model = AutoModel.from_pretrained(model_id).to(device).eval()
+        print(f"FaCLIP text model '{model_id}' loaded successfully.")
+
+    @torch.no_grad()
+    def encode_numpy(self, texts: List[str], batch_size: int = 128) -> np.ndarray:
+        vecs = []
+        for i in range(0, len(texts), batch_size):
+            toks = self.tok(
+                texts[i:i+batch_size], padding=True, truncation=True,
+                max_length=self.max_len, return_tensors="pt"
+            ).to(self.device)
+            out = self.model(**toks)
+            x = (
+                out.pooler_output if hasattr(out, "pooler_output") and out.pooler_output is not None
+                else (out.last_hidden_state * toks.attention_mask.unsqueeze(-1)).sum(1)
+                / toks.attention_mask.sum(1).clamp(min=1)
+            )
+            x_norm = x / x.norm(p=2, dim=1, keepdim=True)
+            vecs.append(x_norm.detach().cpu().numpy())
+        return np.vstack(vecs).astype(np.float32)
 
 class FaVisionEncoder:
-    def __init__(self, model_id: str, device: torch.device):
-        self.device = device
-        self.model = CLIPVisionModel.from_pretrained(model_id).to(device).eval()
-        self.proc = CLIPImageProcessor.from_pretrained(model_id)
-    @torch.no_grad()
-    def encode(self, img: Image.Image) -> np.ndarray:
-        img = ImageOps.exif_transpose(img).convert("RGB")
-        batch = self.proc(images=img, return_tensors="pt").to(self.device)
-        out = self.model(**batch)
-        v = out.pooler_output if hasattr(out, "pooler_output") and out.pooler_output is not None else out.last_hidden_state[:,0]
-        v_norm = v / v.norm(p=2, dim=1, keepdim=True)
-        return v_norm[0].detach().cpu().numpy().astype(np.float32)
+    def __init__(self, model_id: str, device: torch.device):
+        self.device = device
+        self.model = CLIPVisionModel.from_pretrained(model_id).to(device).eval()
+        self.proc = CLIPImageProcessor.from_pretrained(model_id)
+
+    @torch.no_grad()
+    def encode(self, img: Image.Image) -> np.ndarray:
+        img = ImageOps.exif_transpose(img).convert("RGB")
+        batch = self.proc(images=img, return_tensors="pt").to(self.device)
+        out = self.model(**batch)
+        v = out.pooler_output if hasattr(out, "pooler_output") and out.pooler_output is not None else out.last_hidden_state[:, 0]
+        v_norm = v / v.norm(p=2, dim=1, keepdim=True)
+        return v_norm[0].detach().cpu().numpy().astype(np.float32)
 
 # --- RETRIEVER CLASSES ---
 class BaseRetriever:
-    def __init__(self, docstore: pd.DataFrame, index_path: str):
-        self.docstore, self.index_path = docstore.reset_index(drop=True), index_path
-        if os.path.isfile(self.index_path):
-            self.index = faiss.read_index(self.index_path)
-        else:
-            raise FileNotFoundError(f"Index file not found at {self.index_path}. Make sure it's uploaded to your Space.")
+    def __init__(self, docstore: pd.DataFrame, index_path: str):
+        self.docstore, self.index_path = docstore.reset_index(drop=True), index_path
+        if os.path.isfile(self.index_path):
+            self.index = faiss.read_index(self.index_path)
+        else:
+            raise FileNotFoundError(f"Index file not found at {self.index_path}. Make sure it's uploaded to your Space.")
 
-    def search(self, query_vec: np.ndarray, k: int) -> List[Tuple[int, float]]:
-        D, I = self.index.search(query_vec[None, :].astype(np.float32), k)
-        return list(zip(I[0].tolist(), D[0].tolist()))
+    def search(self, query_vec: np.ndarray, k: int) -> List[Tuple[int, float]]:
+        D, I = self.index.search(query_vec[None, :].astype(np.float32), k)
+        return list(zip(I[0].tolist(), D[0].tolist()))
 
 class Glot500Retriever(BaseRetriever):
-    def __init__(self, encoder: Glot500Encoder, docstore: pd.DataFrame, index_path: str):
-        super().__init__(docstore, index_path)
-        self.encoder = encoder
-    def topk(self, query: str, k: int) -> List[Tuple[int, float]]:
-        qv = self.encoder.encode([query], batch_size=1)[0]
-        return self.search(qv, k)
-
-class TextIndexRetriever(BaseRetriever):
-    def __init__(self, text_encoder: FaTextEncoder, docstore: pd.DataFrame, index_path: str):
-        super().__init__(docstore, index_path)
-        self.encoder = text_encoder
+    def __init__(self, encoder: Glot500Encoder, docstore: pd.DataFrame, index_path: str):
+        super().__init__(docstore, index_path)
+        self.encoder = encoder
+
+    def topk(self, query: str, k: int) -> List[Tuple[int, float]]:
+        qv = self.encoder.encode([query], batch_size=1)[0]
+        return self.search(qv, k)
+
+class TextIndexRetriever(BaseRetriever):
+    def __init__(self, text_encoder: FaTextEncoder, docstore: pd.DataFrame, index_path: str):
+        super().__init__(docstore, index_path)
+        self.encoder = text_encoder
 
 # --- GENERATION AND SYSTEM CLASSES ---
-
 class VLM_GenAI:
-    def __init__(self, api_key: str, model_name: str, temperature: float = 0.1, max_output_tokens: int = 1024):
-        if not api_key or "YOUR" in api_key:
-            raise ValueError("Gemini API Key is missing or is a placeholder. Please add it to your Hugging Face Space secrets.")
-        genai.configure(api_key=api_key)
-        self.model = genai.GenerativeModel(model_name)
-        self.generation_config = GenerationConfig(temperature=temperature, max_output_tokens=max_output_tokens)
-        self.safety_settings = {
-            "HARM_CATEGORY_HARASSMENT": "BLOCK_NONE", "HARM_CATEGORY_HATE_SPEECH": "BLOCK_NONE",
-            "HARM_CATEGORY_SEXUALLY_EXPLICIT": "BLOCK_NONE", "HARM_CATEGORY_DANGEROUS_CONTENT": "BLOCK_NONE",
-        }
+    def __init__(self, api_key: str, model_name: str, temperature: float = 0.1, max_output_tokens: int = 1024):
+        if not api_key or "YOUR" in api_key:
+            raise ValueError("Gemini API Key is missing or is a placeholder. Please add it to your Hugging Face Space secrets.")
+        genai.configure(api_key=api_key)
+        self.model = genai.GenerativeModel(model_name)
+        self.generation_config = GenerationConfig(temperature=temperature, max_output_tokens=max_output_tokens)
+        self.safety_settings = {
+            "HARM_CATEGORY_HARASSMENT": "BLOCK_NONE",
+            "HARM_CATEGORY_HATE_SPEECH": "BLOCK_NONE",
+            "HARM_CATEGORY_SEXUALLY_EXPLICIT": "BLOCK_NONE",
+            "HARM_CATEGORY_DANGEROUS_CONTENT": "BLOCK_NONE",
+        }
 
 class RAGSystem:
-    def __init__(self, cfg: Config):
-        self.docstore = pd.read_parquet(cfg.docstore_path)
-        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        self.glot_enc = Glot500Encoder(cfg.glot_model_hf)
-        self.glot_ret = Glot500Retriever(self.glot_enc, self.docstore, cfg.glot_index_out)
-
-
-        txt_enc = FaTextEncoder(cfg.mclip_text_model_hf, device, cfg.max_text_len)
-        self.mclip_ret = TextIndexRetriever(txt_enc, self.docstore, cfg.clip_index_out)
-        self.vision = FaVisionEncoder(cfg.clip_vision_model, device)
-
+    def __init__(self, cfg: Config):
+        self.docstore = pd.read_parquet(cfg.docstore_path)
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        self.glot_enc = Glot500Encoder(cfg.glot_model_hf)
+        self.glot_ret = Glot500Retriever(self.glot_enc, self.docstore, cfg.glot_index_out)
+
+        txt_enc = FaTextEncoder(cfg.mclip_text_model_hf, device, cfg.max_text_len)
+        self.mclip_ret = TextIndexRetriever(txt_enc, self.docstore, cfg.clip_index_out)
+        self.vision = FaVisionEncoder(cfg.clip_vision_model, device)
 
 # ==============================================================================
@@ -164,10 +170,7 @@ class RAGSystem:
 
 # ==============================================================================
 
-
-
 # --- 1. LOAD MODELS AND INDEXES (This runs only once when the app starts) ---
-
 print("Initializing configuration...")
 cfg = Config()
 print("Loading RAG system (models, encoders, and retrievers)...")
@@ -179,68 +182,67 @@ print("System ready.")
179
 
180
  # --- 2. DEFINE THE FUNCTION TO HANDLE USER INPUT ---
181
  def run_rag_query(question_text: str, question_image: Optional[Image.Image]) -> Tuple[str, str]:
182
-     if not question_text.strip():
183
-         return "Please ask a question.", ""
184
-     context_block = ""
185
-     # Decide which retriever to use based on input
186
-     if question_image:
187
-         print("Performing multimodal retrieval...")
188
-         img_vec = rag_system.vision.encode(question_image)
189
-         hits = rag_system.mclip_ret.search(img_vec, k=cfg.per_option_ctx)
190
-         context_block = Utils.build_context_block(hits, rag_system.docstore, cfg.per_option_ctx)
191
-     else:
192
-         print("Performing text retrieval...")
193
-         hits = rag_system.glot_ret.topk(question_text, k=cfg.per_option_ctx)
194
-         context_block = Utils.build_context_block(hits, rag_system.docstore, cfg.per_option_ctx)
195
-
196
-     # --- Augment and Generate ---
197
-     print("Generating response...")
198
-     if question_image:
199
-         prompt = f"با توجه به تصویر و اسناد زیر، به سوال پاسخ دهید.\n\nاسناد:\n{context_block}\n\nسوال: {question_text}"
200
-     else:
201
-         prompt = f"با توجه به اسناد زیر، به سوال پاسخ دهید.\n\nاسناد:\n{context_block}\n\nسوال: {question_text}"
202
-
203
-     content_parts = [question_image, prompt] if question_image else [prompt]    
204
-
205
-     try:
206
-         resp = vlm.model.generate_content(
207
-             content_parts,
208
-             generation_config=vlm.generation_config,
209
-             safety_settings=vlm.safety_settings
210
-         )
211
-         answer = resp.text
212
-
213
-     except Exception as e:
214
-         answer = f"Error during generation: {e}"
215
-         print(answer)
216
-     return answer, context_block
217
 
218
  # --- 3. CREATE THE GRADIO INTERFACE ---
219
  with gr.Blocks(theme=gr.themes.Soft(), title="Persian Culinary RAG") as demo:
220
-     gr.Markdown("# 🍲 Persian Culinary RAG Demo")
221
-     gr.Markdown("Ask a question about Iranian food, with or without an image, to see the RAG system in action.")
222
-     with gr.Row():
223
-         with gr.Column(scale=1):
224
-             image_input = gr.Image(type="pil", label="Upload an Image (Optional)")
225
-             text_input = gr.Textbox(label="Ask your question in Persian", placeholder="...مثلا: در مورد قورمه سبزی توضیح بده")
226
-             submit_button = gr.Button("Submit", variant="primary")
227
-         with gr.Column(scale=2):
228
-             output_answer = gr.Textbox(label="Answer from Model", lines=8, interactive=False)
229
-             output_context = gr.Textbox(label="Retrieved Context (What the model used to answer)", lines=12, interactive=False)
230
-
231
-     gr.Examples(
232
-         examples=[
233
-             ["در مورد حلوا توضیح بده", None],
234
-             ["مواد لازم برای تهیه آش رشته چیست؟", None],
235
-         ],
236
-         inputs=[text_input, image_input]
237
-     )
238
-
239
-     submit_button.click(
240
-         fn=run_rag_query,
241
-         inputs=[text_input, image_input],
242
-         outputs=[output_answer, output_context]
243
-     )
244
 
245
  # Launch the web server
246
- demo.launch()
 
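A note for readers tracing the retrieval path in this revision: every encoder L2-normalizes its embeddings (normalize_embeddings=True in Glot500Encoder.encode, the explicit x / x.norm(...) in FaTextEncoder, v / v.norm(...) in FaVisionEncoder), and BaseRetriever.search then runs a plain FAISS search over the prebuilt indexes. The sketch below, independent of this Space's index files, shows why that pairing amounts to cosine-similarity ranking when the index is inner-product based; that the shipped I_glot_text_fa.index and I_clip_text_fa.index were built that way is an assumption here, and the random vectors are purely illustrative.

    # Minimal sketch (not part of the commit): unit-normalized vectors plus an
    # inner-product FAISS index give cosine-similarity ranking.
    import faiss
    import numpy as np

    rng = np.random.default_rng(0)
    dim = 64

    # Stand-in for the encoders' output: unit-normalized float32 vectors.
    docs = rng.standard_normal((100, dim)).astype(np.float32)
    docs /= np.linalg.norm(docs, axis=1, keepdims=True)

    # On unit vectors, inner product equals cosine similarity.
    index = faiss.IndexFlatIP(dim)
    index.add(docs)

    # A query close to doc 3, normalized the same way as the documents.
    query = docs[3] + 0.05 * rng.standard_normal(dim).astype(np.float32)
    query /= np.linalg.norm(query)

    # Same call shape as BaseRetriever.search: a (1, dim) query, top-k results.
    D, I = index.search(query[None, :], 5)
    print(list(zip(I[0].tolist(), D[0].tolist())))  # doc 3 should rank first

If the indexes were instead built with IndexFlatL2, the ranking over unit vectors would be identical, since for unit vectors the squared L2 distance is 2 - 2·cos, a monotone function of cosine similarity; search would behave the same either way.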