ZedLow committed on
Commit
b4d9d33
·
verified ·
1 Parent(s): bdca465

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +57 -31
app.py CHANGED
@@ -8,7 +8,7 @@ from transformers import AutoModel, AutoTokenizer, Qwen2VLForConditionalGenerati
8
  from qwen_vl_utils import process_vision_info
9
 
10
  # --- CONFIGURATION ---
11
- print(f"🚀 Démarrage RAG Finance (Version Originale + Fix Hallucination)...")
12
 
13
  # --- 1. DONNÉES ---
14
  try:
@@ -18,9 +18,9 @@ except:
18
  dataset = []
19
  print("⚠️ Index vide ou fichier non trouvé.")
20
 
21
- # --- 2. MODÈLES (INCHANGÉS) ---
22
 
23
- # A. EMBEDDING : GTE-Qwen2-7B (Le modèle LOURD original)
24
  EMBED_MODEL_ID = "Alibaba-NLP/gte-Qwen2-7B-instruct"
25
  print(f"🔹 Chargement Embedder : {EMBED_MODEL_ID}")
26
 
@@ -29,8 +29,6 @@ embed_model = AutoModel.from_pretrained(
29
  EMBED_MODEL_ID,
30
  trust_remote_code=False,
31
  torch_dtype=torch.bfloat16,
32
- # J'ai mis en commentaire la ligne qui fait planter le démarrage sur CPU (ZeroGPU)
33
- # attn_implementation="flash_attention_2",
34
  device_map="auto"
35
  )
36
 
@@ -50,13 +48,11 @@ print(f"👁️ Chargement Vision : {GEN_MODEL_ID}")
50
  gen_model = Qwen2VLForConditionalGeneration.from_pretrained(
51
  GEN_MODEL_ID,
52
  torch_dtype=torch.bfloat16,
53
- # Idem, désactivé pour éviter le crash "No CUDA" au boot
54
- # attn_implementation="flash_attention_2",
55
  device_map="auto"
56
  )
57
  gen_processor = AutoProcessor.from_pretrained(GEN_MODEL_ID)
58
 
59
- # --- 3. FONCTIONS UTILITAIRES ---
60
 
61
  def last_token_pool(last_hidden_states, attention_mask):
62
  left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
@@ -67,7 +63,7 @@ def last_token_pool(last_hidden_states, attention_mask):
67
  batch_size = last_hidden_states.shape[0]
68
  return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]
69
 
70
- # --- 4. LOGIQUE RAG MULTI-VIEW ---
71
 
72
  @spaces.GPU
73
  def retrieve_and_answer(query):
@@ -75,13 +71,38 @@ def retrieve_and_answer(query):
75
 
76
  if not dataset: return None, "Base vide", "Pas de document"
77
 
78
- # === ÉTAPE 1 : RETRIEVAL (Embedding) ===
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
  valid_docs = []
80
  for i, doc in enumerate(dataset):
 
 
 
 
 
 
81
  text = doc.get('text', '').strip()
82
  if text:
83
- valid_docs.append({'text': text, 'original_index': i})
84
 
 
 
 
 
85
  query_text = f"Instruct: Given a user query, retrieve relevant passages that answer the query.\nQuery: {query}"
86
 
87
  with torch.no_grad():
@@ -93,7 +114,7 @@ def retrieve_and_answer(query):
93
  d_embeddings_list = []
94
  doc_texts = [d['text'] for d in valid_docs]
95
 
96
- # Batch size de 1 pour économiser la mémoire avec le gros modèle 7B
97
  for i in range(0, len(doc_texts), 1):
98
  d_inputs = embed_tokenizer(doc_texts[i:i+1], max_length=8192, padding=True, truncation=True, return_tensors='pt').to(embed_model.device)
99
  d_outputs = embed_model(**d_inputs)
@@ -103,7 +124,9 @@ def retrieve_and_answer(query):
103
 
104
  d_emb_final = torch.cat(d_embeddings_list, dim=0)
105
  scores = (q_emb @ d_emb_final.T).squeeze(0)
106
- top_k_indices = torch.topk(scores, k=min(10, len(scores))).indices.tolist()
 
 
107
 
108
  # === ÉTAPE 2 : RERANKING ===
109
  pairs = []
@@ -113,9 +136,11 @@ def retrieve_and_answer(query):
113
  with torch.no_grad():
114
  r_inputs = rerank_tokenizer(pairs, padding=True, truncation=True, return_tensors='pt', max_length=8192).to(rerank_model.device)
115
  r_scores = rerank_model(**r_inputs, return_dict=True).logits.view(-1).float()
116
- top_3_indices_local = torch.topk(r_scores, k=min(3, len(r_scores))).indices.tolist()
 
 
117
 
118
- # === ÉTAPE 3 : PRÉPARATION IMAGES (ICI ON CORRIGE L'HALLUCINATION) ===
119
  images_content = []
120
  gallery_preview = []
121
  meta_info = ""
@@ -128,31 +153,32 @@ def retrieve_and_answer(query):
128
  doc = dataset[final_doc_idx]
129
  image_path = doc['image_path']
130
  score = r_scores[idx_local].item()
 
131
 
132
  try:
133
  img = Image.open(image_path)
134
 
135
- # --- FIX HALLUCINATION ---
136
- # On récupère le nom du document (ex: "Microsoft 2023 Report")
137
- doc_name = doc.get('doc_name', 'Unknown Document')
138
-
139
- # On l'injecte explicitement dans le texte que voit l'IA
140
- prompt_header = f"DOCUMENT SOURCE: {doc_name} (Relevance: {score:.2f})\n"
141
 
142
- images_content.append({"type": "text", "text": prompt_header})
143
  images_content.append({"type": "image", "image": img})
144
 
145
- gallery_preview.append((img, f"{doc_name} (Rank {rank+1})"))
146
  meta_info += f"- **{doc_name}** (Score: {score:.2f})\n"
147
  except:
148
  continue
149
 
150
- # === ÉTAPE 4 : GÉNÉRATION ===
151
- # On renforce le prompt système pour qu'il fasse attention au nom du document
 
 
152
  system_prompt = (
153
- "You are an expert financial analyst. Answer the user question using ONLY the provided images.\n"
154
- "IMPORTANT: Before reading a table, check the 'DOCUMENT SOURCE' name above the image.\n"
155
- "If the user asks about Microsoft, do not use data from an Apple document (and vice versa)."
 
 
156
  )
157
 
158
  user_content = images_content + [{"type": "text", "text": f"\nUser Question: {query}"}]
@@ -172,7 +198,7 @@ def retrieve_and_answer(query):
172
  return_tensors="pt",
173
  ).to(gen_model.device)
174
 
175
- generated_ids = gen_model.generate(**inputs, max_new_tokens=768)
176
  generated_ids_trimmed = [out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]
177
  response = gen_processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
178
 
@@ -180,10 +206,10 @@ def retrieve_and_answer(query):
180
 
181
  # --- 5. UI ---
182
  with gr.Blocks(title="RAG Finance") as demo:
183
- gr.Markdown("# 🚀 RAG Finance (Moteurs Originaux + Sécurité Hallucination)")
184
 
185
  with gr.Row():
186
- query_input = gr.Textbox(label="Question", placeholder="Ex: What is the revenue of Microsoft?")
187
  submit_btn = gr.Button("Analyser", variant="primary")
188
 
189
  with gr.Row():
 
8
  from qwen_vl_utils import process_vision_info
9
 
10
  # --- CONFIGURATION ---
11
+ print(f"🚀 Démarrage RAG Finance (Version Originale + FILTRAGE STRICT)...")
12
 
13
  # --- 1. DONNÉES ---
14
  try:
 
18
  dataset = []
19
  print("⚠️ Index vide ou fichier non trouvé.")
20
 
21
+ # --- 2. MODÈLES (INCHANGÉS - ON GARDE LES GROS) ---
22
 
23
+ # A. EMBEDDING
24
  EMBED_MODEL_ID = "Alibaba-NLP/gte-Qwen2-7B-instruct"
25
  print(f"🔹 Chargement Embedder : {EMBED_MODEL_ID}")
26
 
 
29
  EMBED_MODEL_ID,
30
  trust_remote_code=False,
31
  torch_dtype=torch.bfloat16,
 
 
32
  device_map="auto"
33
  )
34
 
 
48
  gen_model = Qwen2VLForConditionalGeneration.from_pretrained(
49
  GEN_MODEL_ID,
50
  torch_dtype=torch.bfloat16,
 
 
51
  device_map="auto"
52
  )
53
  gen_processor = AutoProcessor.from_pretrained(GEN_MODEL_ID)
54
 
55
+ # --- 3. FONCTIONS ---
56
 
57
  def last_token_pool(last_hidden_states, attention_mask):
58
  left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
 
63
  batch_size = last_hidden_states.shape[0]
64
  return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]
65
 
66
+ # --- 4. PIPELINE ---
67
 
68
  @spaces.GPU
69
  def retrieve_and_answer(query):
 
71
 
72
  if not dataset: return None, "Base vide", "Pas de document"
73
 
74
+ # === ÉTAPE 0 : FILTRAGE STRICT (LA SÉCURITÉ) ===
75
+ # On regarde si l'utilisateur parle d'une entreprise spécifique
76
+ # Si oui, on retire TOUTES les autres pages de la liste.
77
+
78
+ query_lower = query.lower()
79
+ target_company = None
80
+
81
+ if "microsoft" in query_lower or "msft" in query_lower:
82
+ target_company = "Microsoft"
83
+ elif "apple" in query_lower or "aapl" in query_lower:
84
+ target_company = "Apple"
85
+ elif "tesla" in query_lower:
86
+ # Cas piège Tesla : On sait qu'on n'a pas les docs, on coupe tout de suite.
87
+ return [], "", "Data not found: No documents available for Tesla in the database."
88
+
89
+ # === ÉTAPE 1 : RETRIEVAL ===
90
  valid_docs = []
91
  for i, doc in enumerate(dataset):
92
+ # Le Filtrage Strict s'applique ici
93
+ if target_company:
94
+ # Si on cherche Microsoft, on ignore tout ce qui ne contient pas "Microsoft" dans le nom du doc
95
+ if target_company not in doc.get('doc_name', ''):
96
+ continue
97
+
98
  text = doc.get('text', '').strip()
99
  if text:
100
+ valid_docs.append({'text': text, 'original_index': i, 'doc_name': doc.get('doc_name', 'Doc')})
101
 
102
+ # Si après filtrage on a plus rien (ex: question sur Tesla mal gérée avant), on arrête
103
+ if not valid_docs:
104
+ return [], "", "No relevant documents found for this company."
105
+
106
  query_text = f"Instruct: Given a user query, retrieve relevant passages that answer the query.\nQuery: {query}"
107
 
108
  with torch.no_grad():
 
114
  d_embeddings_list = []
115
  doc_texts = [d['text'] for d in valid_docs]
116
 
117
+ # Batch size 1 pour le gros modèle 7B
118
  for i in range(0, len(doc_texts), 1):
119
  d_inputs = embed_tokenizer(doc_texts[i:i+1], max_length=8192, padding=True, truncation=True, return_tensors='pt').to(embed_model.device)
120
  d_outputs = embed_model(**d_inputs)
 
124
 
125
  d_emb_final = torch.cat(d_embeddings_list, dim=0)
126
  scores = (q_emb @ d_emb_final.T).squeeze(0)
127
+ # On prend max 10 ou moins si on a filtré
128
+ k_val = min(10, len(scores))
129
+ top_k_indices = torch.topk(scores, k=k_val).indices.tolist()
130
 
131
  # === ÉTAPE 2 : RERANKING ===
132
  pairs = []
 
136
  with torch.no_grad():
137
  r_inputs = rerank_tokenizer(pairs, padding=True, truncation=True, return_tensors='pt', max_length=8192).to(rerank_model.device)
138
  r_scores = rerank_model(**r_inputs, return_dict=True).logits.view(-1).float()
139
+
140
+ k_rerank = min(3, len(r_scores))
141
+ top_3_indices_local = torch.topk(r_scores, k=k_rerank).indices.tolist()
142
 
143
+ # === ÉTAPE 3 : PRÉPARATION IMAGES ===
144
  images_content = []
145
  gallery_preview = []
146
  meta_info = ""
 
153
  doc = dataset[final_doc_idx]
154
  image_path = doc['image_path']
155
  score = r_scores[idx_local].item()
156
+ doc_name = doc.get('doc_name', 'Unknown')
157
 
158
  try:
159
  img = Image.open(image_path)
160
 
161
+ # Injection du nom pour aider encore plus
162
+ header_text = f"SOURCE DOCUMENT: {doc_name} (Confidence: {score:.2f})\n"
 
 
 
 
163
 
164
+ images_content.append({"type": "text", "text": header_text})
165
  images_content.append({"type": "image", "image": img})
166
 
167
+ gallery_preview.append((img, f"{doc_name}"))
168
  meta_info += f"- **{doc_name}** (Score: {score:.2f})\n"
169
  except:
170
  continue
171
 
172
+ if not images_content:
173
+ return [], "", "No images found."
174
+
175
+ # === 4. GENERATION ===
176
  system_prompt = (
177
+ "You are a strict financial analyst. Answer the user question using ONLY the provided images.\n"
178
+ "RULES:\n"
179
+ "1. If the user asks for 'Microsoft', ONLY use the image labeled 'Microsoft'. IGNORE Apple.\n"
180
+ "2. If the user asks for 'Apple', ONLY use the image labeled 'Apple'. IGNORE Microsoft.\n"
181
+ "3. Copy the exact number from the image. Do not calculate."
182
  )
183
 
184
  user_content = images_content + [{"type": "text", "text": f"\nUser Question: {query}"}]
 
198
  return_tensors="pt",
199
  ).to(gen_model.device)
200
 
201
+ generated_ids = gen_model.generate(**inputs, max_new_tokens=512)
202
  generated_ids_trimmed = [out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]
203
  response = gen_processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
204
 
 
206
 
207
  # --- 5. UI ---
208
  with gr.Blocks(title="RAG Finance") as demo:
209
+ gr.Markdown("# 🚀 RAG Finance (Version Sécurisée)")
210
 
211
  with gr.Row():
212
+ query_input = gr.Textbox(label="Question", placeholder="Ex: What is the Operating Income for Microsoft?")
213
  submit_btn = gr.Button("Analyser", variant="primary")
214
 
215
  with gr.Row():