sadegh803211 commited on
Commit
043aa33
·
verified ·
1 Parent(s): be3b975

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +196 -138
app.py CHANGED
@@ -1,47 +1,82 @@
 
 
 
 
1
  import os
2
  import re
3
- import json
4
  import hashlib
5
- import unicodedata
6
- from glob import glob
7
- from typing import List, Dict, Any, Iterable
8
 
 
9
  import pandas as pd
10
- import faiss
11
  import torch
12
- import shutil
 
 
 
 
 
 
 
 
 
 
13
 
14
- # --- Important: Make sure to install the required libraries ---
15
- # pip install pandas pyarrow transformers sentence-transformers faiss-cpu
 
16
 
17
- # --- All necessary classes are included here for a self-contained script ---
 
18
 
 
 
19
class Config:
    """Paths and model identifiers for the offline index-building pipeline."""

    # Output parquet file holding the deduplicated passages (id/passage_text/title).
    docstore_path: str = "indexes/docstore.parquet"
    # Fine-tuned text encoder checkpoints on the Hugging Face Hub.
    glot_model_hf: str = "Arshiaizd/Glot500-FineTuned"
    mclip_text_model_hf: str = "Arshiaizd/MCLIP_FA_FineTuned"
    # Destination FAISS index files for each encoder.
    glot_index_out: str = "indexes/I_glot_text_fa.index"
    clip_index_out: str = "indexes/I_clip_text_fa.index"
    # Root folder scanned recursively for passage JSON files.
    food_dataset_root: str = "./data/food_passages"
    # Tokenizer truncation length for the CLIP text encoder.
    max_text_len: int = 512
27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
class Glot500Encoder:
    """Sentence encoder backed by a fine-tuned Glot-500 SentenceTransformer."""

    def __init__(self, model_id: str):
        # Imported lazily so the script fails only when this encoder is used.
        from sentence_transformers import SentenceTransformer
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.st_model = SentenceTransformer(model_id, device=str(self.device))

    def encode(self, texts: List[str], batch_size: int = 32) -> 'np.ndarray':
        """Encode *texts* into L2-normalised float32 embeddings."""
        import numpy as np
        embeddings = self.st_model.encode(
            texts,
            batch_size=batch_size,
            show_progress_bar=True,
            convert_to_numpy=True,
            normalize_embeddings=True,
        )
        return embeddings.astype(np.float32)
 
36
 
37
  class FaTextEncoder:
38
  def __init__(self, model_id: str, device: torch.device, max_len: int):
39
- from transformers import AutoTokenizer, AutoModel
40
  self.device, self.max_len = device, max_len
41
  self.tok = AutoTokenizer.from_pretrained(model_id)
42
  self.model = AutoModel.from_pretrained(model_id).to(device).eval()
43
- def encode_numpy(self, texts: List[str], batch_size: int = 128) -> 'np.ndarray':
44
- import numpy as np
 
45
  vecs = []
46
  for i in range(0, len(texts), batch_size):
47
  toks = self.tok(texts[i:i+batch_size], padding=True, truncation=True, max_length=self.max_len, return_tensors="pt").to(self.device)
@@ -51,129 +86,152 @@ class FaTextEncoder:
51
  vecs.append(x_norm.detach().cpu().numpy())
52
  return np.vstack(vecs).astype(np.float32)
53
 
54
class Utils:
    """Static helpers for turning raw food-passage JSON files into a docstore."""

    @staticmethod
    def _normalize_title(s: str) -> str:
        """Normalize a Persian title: unify Arabic/Persian letter forms,
        collapse whitespace, strip punctuation, and lowercase.

        Returns "" for None input.
        """
        if s is None:
            return ""
        # Map Arabic yeh/kaf to their Persian equivalents so titles compare equal.
        s = str(s).strip().replace("ي", "ی").replace("ك", "ک")
        s = re.sub(r"\s+", " ", s)
        # Keep word characters, the Arabic/Persian block, whitespace and hyphens.
        s = re.sub(r"[^\w\u0600-\u06FF\s-]", "", s)
        return s.lower()

    @staticmethod
    def _iter_json_records(json_path: str) -> Iterable[Dict[str, Any]]:
        """Yield dict records from *json_path*.

        Handles three layouts: a single JSON object, a JSON array of objects,
        and line-delimited JSON (one object per line). Non-dict entries and
        unparseable lines are silently skipped.
        """
        with open(json_path, "r", encoding="utf-8") as f:
            txt = f.read().strip()
        if not txt:
            return
        try:
            obj = json.loads(txt)
        except json.JSONDecodeError:
            # Whole-file parse failed: fall back to line-delimited JSON.
            for line in txt.splitlines():
                line = line.strip()
                if not line:
                    continue
                try:
                    rec = json.loads(line)
                except json.JSONDecodeError:
                    continue
                if isinstance(rec, dict):
                    yield rec
            return
        if isinstance(obj, dict):
            yield obj
        elif isinstance(obj, list):
            for item in obj:
                if isinstance(item, dict):
                    yield item

    @staticmethod
    def _collect_pairs(root: str) -> pd.DataFrame:
        """Recursively scan *root* for JSON files and collect (title, text)
        rows — keeping only records whose referenced image exists on disk."""
        rows = []
        json_files = glob(os.path.join(root, "**/*.json"), recursive=True)
        if not json_files:
            print(f"Warning: No JSON files found in {root}. Please check the path.")
            return pd.DataFrame(rows)

        for jp in json_files:
            base_dir = os.path.dirname(jp)
            for rec in Utils._iter_json_records(jp):
                title, resp, img_rel = rec.get("title"), rec.get("response"), rec.get("image_path")
                if not all([title, resp, img_rel]):
                    continue
                # Image paths are stored relative to the JSON file's directory.
                img_abs = os.path.normpath(os.path.join(base_dir, img_rel))
                if not os.path.isfile(img_abs):
                    continue
                rows.append({"title": str(title), "text": str(resp)})
        return pd.DataFrame(rows)

    @staticmethod
    def _build_docstore(df: pd.DataFrame) -> pd.DataFrame:
        """Attach stable ids and rename 'text' -> 'passage_text'.

        Operates on a copy so the caller's DataFrame is never mutated
        (the previous implementation added the 'id' column in place).
        """
        if 'text' not in df.columns:
            # Empty/columnless input: return an empty docstore with the schema.
            return pd.DataFrame(columns=['id', 'passage_text', 'title'])
        df = df.copy()
        # Short sha1 of the passage text gives a deterministic, dedup-friendly id.
        df['id'] = df['text'].apply(
            lambda row_text: hashlib.sha1(row_text.encode("utf-8")).hexdigest()[:16]
        )
        return df.rename(columns={'text': 'passage_text'})

    @staticmethod
    def prep_dataset(root: str, out_docstore: str):
        """Build the deduplicated docstore parquet from the JSONs under *root*.

        Writes the parquet to *out_docstore* and returns the DataFrame.
        """
        print("Building docstore from source JSONs...")
        os.makedirs(os.path.dirname(out_docstore), exist_ok=True)
        df = Utils._collect_pairs(root)
        print(f"Found {len(df)} total passages.")

        if df.empty:
            print("Warning: No valid data found to process. The docstore will be empty.")
        else:
            # Identical passage text means an identical id, so dedupe on text.
            df = df.drop_duplicates(subset=['text'], keep='first')
            print(f"Found {len(df)} unique passages after deduplication.")
        doc = Utils._build_docstore(df)

        doc.to_parquet(out_docstore, index=False)
        print(f"Docstore saved to {out_docstore}.")
        return doc
134
-
135
def build_faiss_index(encoder, docstore, index_path, text_col="passage_text"):
    """Encode every passage in *docstore* and persist an inner-product FAISS index."""
    print(f"Building FAISS index: {os.path.basename(index_path)}")
    if docstore.empty:
        # Nothing to index — bail out before touching the encoder or FAISS.
        print("Docstore is empty. Skipping FAISS index creation.")
        return

    passages = docstore[text_col].astype(str).tolist()
    # Some encoders expose encode_numpy (transformers-style); prefer it when present.
    encode_fn = encoder.encode_numpy if hasattr(encoder, 'encode_numpy') else encoder.encode
    embeddings = encode_fn(passages)

    # Embeddings are unit-normalised upstream, so inner product == cosine similarity.
    ip_index = faiss.IndexFlatIP(embeddings.shape[1])
    ip_index.add(embeddings.astype('float32'))
    faiss.write_index(ip_index, index_path)
    print("Index built and saved successfully.")
 
 
 
 
152
 
153
def main():
    """Rebuild the docstore and both FAISS text indexes from scratch."""
    cfg = Config()

    # Start from a clean slate so stale index files never survive a rebuild.
    if os.path.isdir("indexes"):
        print("Removing old 'indexes' directory...")
        shutil.rmtree("indexes")

    # Step 1: collect passages and write the deduplicated docstore.
    docstore = Utils.prep_dataset(root=cfg.food_dataset_root, out_docstore=cfg.docstore_path)

    # Step 2: Glot-500 sentence-embedding index.
    print("\n--- Building Glot Index ---")
    glot_encoder = Glot500Encoder(cfg.glot_model_hf)
    build_faiss_index(glot_encoder, docstore, cfg.glot_index_out)

    # Step 3: multilingual-CLIP text index.
    print("\n--- Building CLIP Text Index ---")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    clip_text_encoder = FaTextEncoder(cfg.mclip_text_model_hf, device, cfg.max_text_len)
    build_faiss_index(clip_text_encoder, docstore, cfg.clip_index_out)

    print("\nAll new indexes have been created successfully!")


if __name__ == "__main__":
    main()
 
 
 
 
 
 
 
 
 
 
 
 
179
 
 
1
+ # ==============================================================================
2
+ # Part 1: Core Classes from the Original Script
3
+ # All the necessary helper classes for the RAG system are defined here.
4
+ # ==============================================================================
5
  import os
6
  import re
 
7
  import hashlib
8
+ from typing import List, Tuple, Optional
 
 
9
 
10
+ import numpy as np
11
  import pandas as pd
 
12
  import torch
13
+ import faiss
14
+ from PIL import Image, ImageOps
15
+
16
+ # Hugging Face Transformers & Sentence-Transformers
17
+ from transformers import (
18
+ CLIPVisionModel,
19
+ CLIPImageProcessor,
20
+ AutoTokenizer,
21
+ AutoModel,
22
+ )
23
+ from sentence_transformers import SentenceTransformer
24
 
25
+ # Google Generative AI
26
+ import google.generativeai as genai
27
+ from google.generativeai.types import GenerationConfig
28
 
29
+ # Gradio for Web UI
30
+ import gradio as gr
31
 
32
+
33
+ # --- CONFIGURATION CLASS ---
34
class Config:
    """Runtime settings for the Gradio RAG app: retrieval depth, model ids,
    and the pre-built index/docstore artifacts shipped with the Space."""

    # How many retrieved passages to feed the generator per query.
    per_option_ctx: int = 5
    # Tokenizer truncation length for the CLIP text encoder.
    max_text_len: int = 512
    # Pre-built passage store (id/passage_text columns).
    docstore_path: str = "indexes/docstore.parquet"
    # Fine-tuned encoder checkpoints on the Hugging Face Hub.
    glot_model_hf: str = "Arshiaizd/Glot500-FineTuned"
    mclip_text_model_hf: str = "Arshiaizd/MCLIP_FA_FineTuned"
    clip_vision_model: str = "SajjadAyoubi/clip-fa-vision"
    # Pre-built FAISS indexes matching the two text encoders.
    glot_index_out: str = "indexes/I_glot_text_fa.index"
    clip_index_out: str = "indexes/I_clip_text_fa.index"
 
 
43
 
44
+ # --- UTILITY CLASS ---
45
class Utils:
    """Static helper utilities for the RAG pipeline."""

    @staticmethod
    def build_context_block(hits: List[Tuple[int, float]], docstore: pd.DataFrame, count: int, max_chars=350) -> str:
        """Format the top retrieval hits as a bulleted context block.

        Args:
            hits: (row_index, score) pairs as returned by a FAISS search.
            docstore: DataFrame with 'id' and 'passage_text' columns.
            count: maximum number of passages to include.
            max_chars: per-passage truncation limit.

        Returns:
            Newline-joined "- [doc:<id>] <text>" lines, or a fallback
            message when there is nothing to show.
        """
        if not hits or docstore.empty:
            return "No relevant documents found."
        lines = []
        # FAISS pads missing neighbours with index -1; the previous
        # upper-bound-only check let -1 through, and iloc then silently
        # wrapped to the LAST row. Reject negative indices too.
        valid_hits = [h for h in hits if 0 <= h[0] < len(docstore)]
        for i, score in valid_hits[:count]:
            row = docstore.iloc[i]
            # .get() keeps us safe if 'passage_text'/'id' columns are missing.
            txt = str(row.get("passage_text", "Text not available"))
            doc_id = row.get("id", "N/A")
            txt = (txt[:max_chars] + "…") if len(txt) > max_chars else txt
            lines.append(f"- [doc:{doc_id}] {txt}")
        return "\n".join(lines)
61
+
62
+ # --- ENCODER CLASSES ---
63
class Glot500Encoder:
    """Sentence encoder wrapping a fine-tuned Glot-500 SentenceTransformer."""

    def __init__(self, model_id: str):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.st_model = SentenceTransformer(model_id, device=str(self.device))
        print(f"Glot-500 model '{model_id}' loaded successfully.")

    @torch.no_grad()
    def encode(self, texts: List[str], batch_size: int = 64) -> np.ndarray:
        """Encode *texts* into L2-normalised float32 embeddings."""
        embeddings = self.st_model.encode(
            texts,
            batch_size=batch_size,
            show_progress_bar=False,
            convert_to_numpy=True,
            normalize_embeddings=True,
        )
        return embeddings.astype(np.float32)
71
 
72
  class FaTextEncoder:
73
  def __init__(self, model_id: str, device: torch.device, max_len: int):
 
74
  self.device, self.max_len = device, max_len
75
  self.tok = AutoTokenizer.from_pretrained(model_id)
76
  self.model = AutoModel.from_pretrained(model_id).to(device).eval()
77
+ print(f"FaCLIP text model '{model_id}' loaded successfully.")
78
+ @torch.no_grad()
79
+ def encode_numpy(self, texts: List[str], batch_size: int = 128) -> np.ndarray:
80
  vecs = []
81
  for i in range(0, len(texts), batch_size):
82
  toks = self.tok(texts[i:i+batch_size], padding=True, truncation=True, max_length=self.max_len, return_tensors="pt").to(self.device)
 
86
  vecs.append(x_norm.detach().cpu().numpy())
87
  return np.vstack(vecs).astype(np.float32)
88
 
89
class FaVisionEncoder:
    """CLIP vision tower producing a single L2-normalised image embedding."""

    def __init__(self, model_id: str, device: torch.device):
        self.device = device
        self.model = CLIPVisionModel.from_pretrained(model_id).to(device).eval()
        self.proc = CLIPImageProcessor.from_pretrained(model_id)

    @torch.no_grad()
    def encode(self, img: Image.Image) -> np.ndarray:
        """Embed one PIL image as a float32 unit vector."""
        # Respect EXIF orientation and force RGB before preprocessing.
        rgb = ImageOps.exif_transpose(img).convert("RGB")
        inputs = self.proc(images=rgb, return_tensors="pt").to(self.device)
        outputs = self.model(**inputs)
        # Prefer the pooled representation when the model exposes one,
        # otherwise fall back to the first ([CLS]) token.
        pooled = getattr(outputs, "pooler_output", None)
        if pooled is None:
            pooled = outputs.last_hidden_state[:, 0]
        unit = pooled / pooled.norm(p=2, dim=1, keepdim=True)
        return unit[0].detach().cpu().numpy().astype(np.float32)
102
+
103
+ # --- RETRIEVER CLASSES ---
104
class BaseRetriever:
    """Shared FAISS-backed retrieval: load a prebuilt index, run k-NN search."""

    def __init__(self, docstore: pd.DataFrame, index_path: str):
        # Reset the index so FAISS row ids line up with iloc positions.
        self.docstore = docstore.reset_index(drop=True)
        self.index_path = index_path
        if not os.path.isfile(self.index_path):
            raise FileNotFoundError(f"Index file not found at {self.index_path}. Make sure it's uploaded to your Space.")
        self.index = faiss.read_index(self.index_path)

    def search(self, query_vec: np.ndarray, k: int) -> List[Tuple[int, float]]:
        """Return the top-*k* (row_index, score) pairs for one query vector."""
        scores, ids = self.index.search(query_vec[None, :].astype(np.float32), k)
        return list(zip(ids[0].tolist(), scores[0].tolist()))
114
+
115
class Glot500Retriever(BaseRetriever):
    """Text retriever: embeds the query with Glot-500, then k-NN searches."""

    def __init__(self, encoder: Glot500Encoder, docstore: pd.DataFrame, index_path: str):
        super().__init__(docstore, index_path)
        self.encoder = encoder

    def topk(self, query: str, k: int) -> List[Tuple[int, float]]:
        """Return the top-*k* (row_index, score) hits for a text query."""
        query_vec = self.encoder.encode([query], batch_size=1)[0]
        return self.search(query_vec, k)
122
+
123
class TextIndexRetriever(BaseRetriever):
    """Retriever over the CLIP *text* index.

    The app queries it with image embeddings via the inherited `search`
    (presumably the CLIP text/vision spaces are aligned — confirm against
    the fine-tuned checkpoints).
    """
    def __init__(self, text_encoder: FaTextEncoder, docstore: pd.DataFrame, index_path: str):
        super().__init__(docstore, index_path)
        # Stored for parity with Glot500Retriever; search() takes vectors directly.
        self.encoder = text_encoder
127
+
128
+ # --- GENERATION AND SYSTEM CLASSES ---
129
class VLM_GenAI:
    """Thin wrapper around a Google Gemini generative model.

    Holds the configured model plus the generation and safety settings that
    callers pass to `self.model.generate_content(...)`.
    """

    def __init__(self, api_key: str, model_name: str, temperature: float = 0.1, max_output_tokens: int = 1024):
        # Fail fast on a missing or template key rather than erroring per-request.
        if not api_key or "YOUR" in api_key:
            raise ValueError("Gemini API Key is missing or is a placeholder. Please add it to your Hugging Face Space secrets.")
        genai.configure(api_key=api_key)
        self.model = genai.GenerativeModel(model_name)
        # Low temperature keeps answers close to the retrieved context.
        self.generation_config = GenerationConfig(temperature=temperature, max_output_tokens=max_output_tokens)
        # Disable all safety blocking categories for this app's content.
        self.safety_settings = {
            "HARM_CATEGORY_HARASSMENT": "BLOCK_NONE", "HARM_CATEGORY_HATE_SPEECH": "BLOCK_NONE",
            "HARM_CATEGORY_SEXUALLY_EXPLICIT": "BLOCK_NONE", "HARM_CATEGORY_DANGEROUS_CONTENT": "BLOCK_NONE",
        }
140
+
141
class RAGSystem:
    """Loads the docstore plus every encoder/retriever the app needs, once."""

    def __init__(self, cfg: Config):
        self.docstore = pd.read_parquet(cfg.docstore_path)
        compute_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Text-only retrieval path (Glot-500 sentence embeddings).
        self.glot_enc = Glot500Encoder(cfg.glot_model_hf)
        self.glot_ret = Glot500Retriever(self.glot_enc, self.docstore, cfg.glot_index_out)

        # Image-driven retrieval path (CLIP text index + vision tower).
        text_encoder = FaTextEncoder(cfg.mclip_text_model_hf, compute_device, cfg.max_text_len)
        self.mclip_ret = TextIndexRetriever(text_encoder, self.docstore, cfg.clip_index_out)
        self.vision = FaVisionEncoder(cfg.clip_vision_model, compute_device)
152
+
153
# ==============================================================================
# Part 2: Gradio Web Application
# ==============================================================================

# --- 1. LOAD MODELS AND INDEXES (This runs only once when the app starts) ---
# Module-level initialisation: models and indexes are loaded at import time,
# so every request reuses the already-loaded objects.
print("Initializing configuration...")
cfg = Config()
print("Loading RAG system (models, encoders, and retrievers)...")
rag_system = RAGSystem(cfg)
print("Initializing Gemini model...")
# Securely get the API key from Hugging Face secrets
api_key = os.environ.get("GEMINI_API_KEY")
# VLM_GenAI raises ValueError at startup if the key is missing — fail fast.
vlm = VLM_GenAI(api_key, model_name="models/gemini-1.5-flash")
print("System ready.")
167
+
168
# --- 2. DEFINE THE FUNCTION TO HANDLE USER INPUT ---
def run_rag_query(question_text: str, question_image: Optional[Image.Image]) -> Tuple[str, str]:
    """Answer a user question with retrieval-augmented generation.

    Args:
        question_text: the user's question (Persian). May be None or empty
            when the Gradio textbox is cleared.
        question_image: optional PIL image; when present, retrieval is
            driven by the image instead of the text.

    Returns:
        (answer, context_block): the model's answer and the retrieved
        passages that were shown to it.
    """
    # Gradio passes None for a cleared textbox; guard before calling .strip(),
    # which would otherwise raise AttributeError.
    if not question_text or not question_text.strip():
        return "Please ask a question.", ""

    has_image = question_image is not None

    # --- Retrieve: an image query searches the CLIP index, a text query the Glot index ---
    if has_image:
        print("Performing multimodal retrieval...")
        img_vec = rag_system.vision.encode(question_image)
        hits = rag_system.mclip_ret.search(img_vec, k=cfg.per_option_ctx)
    else:
        print("Performing text retrieval...")
        hits = rag_system.glot_ret.topk(question_text, k=cfg.per_option_ctx)
    # Context formatting is identical for both paths — build it once.
    context_block = Utils.build_context_block(hits, rag_system.docstore, cfg.per_option_ctx)

    # --- Augment and Generate ---
    print("Generating response...")
    if has_image:
        prompt = f"با توجه به تصویر و اسناد زیر، به سوال پاسخ دهید.\n\nاسناد:\n{context_block}\n\nسوال: {question_text}"
    else:
        prompt = f"با توجه به اسناد زیر، به سوال پاسخ دهید.\n\nاسناد:\n{context_block}\n\nسوال: {question_text}"

    content_parts = [question_image, prompt] if has_image else [prompt]

    try:
        resp = vlm.model.generate_content(
            content_parts,
            generation_config=vlm.generation_config,
            safety_settings=vlm.safety_settings
        )
        answer = resp.text
    except Exception as e:
        # Surface generation failures (quota, safety blocks, network) in the
        # UI instead of crashing the request.
        answer = f"Error during generation: {e}"
        print(answer)

    return answer, context_block
206
+
207
# --- 3. CREATE THE GRADIO INTERFACE ---
# Two-column layout: inputs (image + question + submit) on the left,
# the model answer and the retrieved context on the right.
with gr.Blocks(theme=gr.themes.Soft(), title="Persian Culinary RAG") as demo:
    gr.Markdown("# 🍲 Persian Culinary RAG Demo")
    gr.Markdown("Ask a question about Iranian food, with or without an image, to see the RAG system in action.")

    with gr.Row():
        with gr.Column(scale=1):
            # Optional image: when supplied, retrieval switches to the CLIP path.
            image_input = gr.Image(type="pil", label="Upload an Image (Optional)")
            text_input = gr.Textbox(label="Ask your question in Persian", placeholder="...مثلا: در مورد قورمه سبزی توضیح بده")
            submit_button = gr.Button("Submit", variant="primary")
        with gr.Column(scale=2):
            # Read-only outputs filled in by run_rag_query.
            output_answer = gr.Textbox(label="Answer from Model", lines=8, interactive=False)
            output_context = gr.Textbox(label="Retrieved Context (What the model used to answer)", lines=12, interactive=False)

    # Text-only example queries (no image) to demo the Glot retrieval path.
    gr.Examples(
        examples=[
            ["در مورد حلوا توضیح بده", None],
            ["مواد لازم برای تهیه آش رشته چیست؟", None],
        ],
        inputs=[text_input, image_input]
    )

    submit_button.click(
        fn=run_rag_query,
        inputs=[text_input, image_input],
        outputs=[output_answer, output_context]
    )

# Launch the web server
demo.launch()
237