ZedLow committed on
Commit
a24954e
·
verified ·
1 Parent(s): f2863bc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -25
app.py CHANGED
@@ -8,7 +8,7 @@ from transformers import AutoModel, AutoTokenizer, Qwen2VLForConditionalGenerati
8
  from qwen_vl_utils import process_vision_info
9
 
10
  # --- CONFIGURATION ---
11
- print(f"🚀 Démarrage RAG Finance (Mode Multi-View : 3 Images)...")
12
 
13
  # --- 1. DONNÉES ---
14
  try:
@@ -16,10 +16,11 @@ try:
16
  dataset = json.load(f)
17
  except:
18
  dataset = []
19
- print("⚠️ Index vide.")
20
 
21
- # --- 2. MODÈLES ---
22
- # A. EMBEDDING : GTE-Qwen2-7B (Le modèle LOURD qui causait les crashs mémoire)
 
23
  EMBED_MODEL_ID = "Alibaba-NLP/gte-Qwen2-7B-instruct"
24
  print(f"🔹 Chargement Embedder : {EMBED_MODEL_ID}")
25
 
@@ -28,8 +29,8 @@ embed_model = AutoModel.from_pretrained(
28
  EMBED_MODEL_ID,
29
  trust_remote_code=False,
30
  torch_dtype=torch.bfloat16,
31
- # C'est cette ligne qui fait planter si pas de GPU détecté immédiatement
32
- attn_implementation="flash_attention_2",
33
  device_map="auto"
34
  )
35
 
@@ -49,12 +50,14 @@ print(f"👁️ Chargement Vision : {GEN_MODEL_ID}")
49
  gen_model = Qwen2VLForConditionalGeneration.from_pretrained(
50
  GEN_MODEL_ID,
51
  torch_dtype=torch.bfloat16,
52
- attn_implementation="flash_attention_2",
 
53
  device_map="auto"
54
  )
55
  gen_processor = AutoProcessor.from_pretrained(GEN_MODEL_ID)
56
 
57
- # --- 3. FONCTIONS ---
 
58
  def last_token_pool(last_hidden_states, attention_mask):
59
  left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
60
  if left_padding:
@@ -64,14 +67,15 @@ def last_token_pool(last_hidden_states, attention_mask):
64
  batch_size = last_hidden_states.shape[0]
65
  return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]
66
 
67
- # --- 4. PIPELINE ---
 
68
  @spaces.GPU
69
  def retrieve_and_answer(query):
70
  print(f"⚡ Question : {query}")
71
 
72
  if not dataset: return None, "Base vide", "Pas de document"
73
 
74
- # 1. RETRIEVAL (Recalculé à chaque fois -> Lent)
75
  valid_docs = []
76
  for i, doc in enumerate(dataset):
77
  text = doc.get('text', '').strip()
@@ -89,6 +93,7 @@ def retrieve_and_answer(query):
89
  d_embeddings_list = []
90
  doc_texts = [d['text'] for d in valid_docs]
91
 
 
92
  for i in range(0, len(doc_texts), 1):
93
  d_inputs = embed_tokenizer(doc_texts[i:i+1], max_length=8192, padding=True, truncation=True, return_tensors='pt').to(embed_model.device)
94
  d_outputs = embed_model(**d_inputs)
@@ -100,7 +105,7 @@ def retrieve_and_answer(query):
100
  scores = (q_emb @ d_emb_final.T).squeeze(0)
101
  top_k_indices = torch.topk(scores, k=min(10, len(scores))).indices.tolist()
102
 
103
- # 2. RERANKING
104
  pairs = []
105
  for idx in top_k_indices:
106
  pairs.append([query, valid_docs[idx]['text']])
@@ -110,7 +115,7 @@ def retrieve_and_answer(query):
110
  r_scores = rerank_model(**r_inputs, return_dict=True).logits.view(-1).float()
111
  top_3_indices_local = torch.topk(r_scores, k=min(3, len(r_scores))).indices.tolist()
112
 
113
- # 3. PREPARATION IMAGES (C'est ICI que l'hallucination se crée)
114
  images_content = []
115
  gallery_preview = []
116
  meta_info = ""
@@ -126,20 +131,28 @@ def retrieve_and_answer(query):
126
 
127
  try:
128
  img = Image.open(image_path)
129
- # PROBLÈME ICI : On ne dit pas au modèle "Ceci est Microsoft" ou "Ceci est Apple"
130
- # Il voit juste "Image 1", "Image 2"...
131
- images_content.append({"type": "text", "text": f"Image {rank+1} (Pertinence: {score:.2f}):\n"})
 
 
 
 
 
 
132
  images_content.append({"type": "image", "image": img})
133
 
134
- gallery_preview.append((img, f"Page {rank+1} - Score {score:.2f}"))
135
- meta_info += f"- **Image {rank+1}:** {doc['doc_name']} (Score: {score:.2f})\n"
136
  except:
137
  continue
138
 
139
- # 4. GENERATION
 
140
  system_prompt = (
141
- "You are an expert financial analyst examining 3 pages of a report. "
142
- "Your goal is to answer the user question using ONLY the provided images."
 
143
  )
144
 
145
  user_content = images_content + [{"type": "text", "text": f"\nUser Question: {query}"}]
@@ -167,16 +180,16 @@ def retrieve_and_answer(query):
167
 
168
  # --- 5. UI ---
169
  with gr.Blocks(title="RAG Finance") as demo:
170
- gr.Markdown("# 🚀 RAG Finance (Version Originale Instable)")
171
 
172
  with gr.Row():
173
- query_input = gr.Textbox(label="Question")
174
  submit_btn = gr.Button("Analyser", variant="primary")
175
 
176
  with gr.Row():
177
- output_gallery = gr.Gallery(label="Pages")
178
- output_meta = gr.Markdown(label="Sources")
179
- output_text = gr.Markdown(label="Réponse")
180
 
181
  submit_btn.click(retrieve_and_answer, inputs=query_input, outputs=[output_gallery, output_meta, output_text])
182
 
 
8
  from qwen_vl_utils import process_vision_info
9
 
10
  # --- CONFIGURATION ---
11
+ print(f"🚀 Démarrage RAG Finance (Version Originale + Fix Hallucination)...")
12
 
13
  # --- 1. DONNÉES ---
14
  try:
 
16
  dataset = json.load(f)
17
  except:
18
  dataset = []
19
+ print("⚠️ Index vide ou fichier non trouvé.")
20
 
21
+ # --- 2. MODÈLES (INCHANGÉS) ---
22
+
23
+ # A. EMBEDDING : GTE-Qwen2-7B (Le modèle LOURD original)
24
  EMBED_MODEL_ID = "Alibaba-NLP/gte-Qwen2-7B-instruct"
25
  print(f"🔹 Chargement Embedder : {EMBED_MODEL_ID}")
26
 
 
29
  EMBED_MODEL_ID,
30
  trust_remote_code=False,
31
  torch_dtype=torch.bfloat16,
32
+ # J'ai mis en commentaire la ligne qui fait planter le démarrage sur CPU (ZeroGPU)
33
+ # attn_implementation="flash_attention_2",
34
  device_map="auto"
35
  )
36
 
 
50
  gen_model = Qwen2VLForConditionalGeneration.from_pretrained(
51
  GEN_MODEL_ID,
52
  torch_dtype=torch.bfloat16,
53
+ # Idem, désactivé pour éviter le crash "No CUDA" au boot
54
+ # attn_implementation="flash_attention_2",
55
  device_map="auto"
56
  )
57
  gen_processor = AutoProcessor.from_pretrained(GEN_MODEL_ID)
58
 
59
+ # --- 3. FONCTIONS UTILITAIRES ---
60
+
61
  def last_token_pool(last_hidden_states, attention_mask):
62
  left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
63
  if left_padding:
 
67
  batch_size = last_hidden_states.shape[0]
68
  return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]
69
 
70
+ # --- 4. LOGIQUE RAG MULTI-VIEW ---
71
+
72
  @spaces.GPU
73
  def retrieve_and_answer(query):
74
  print(f"⚡ Question : {query}")
75
 
76
  if not dataset: return None, "Base vide", "Pas de document"
77
 
78
+ # === ÉTAPE 1 : RETRIEVAL (Embedding) ===
79
  valid_docs = []
80
  for i, doc in enumerate(dataset):
81
  text = doc.get('text', '').strip()
 
93
  d_embeddings_list = []
94
  doc_texts = [d['text'] for d in valid_docs]
95
 
96
+ # Batch size de 1 pour économiser la mémoire avec le gros modèle 7B
97
  for i in range(0, len(doc_texts), 1):
98
  d_inputs = embed_tokenizer(doc_texts[i:i+1], max_length=8192, padding=True, truncation=True, return_tensors='pt').to(embed_model.device)
99
  d_outputs = embed_model(**d_inputs)
 
105
  scores = (q_emb @ d_emb_final.T).squeeze(0)
106
  top_k_indices = torch.topk(scores, k=min(10, len(scores))).indices.tolist()
107
 
108
+ # === ÉTAPE 2 : RERANKING ===
109
  pairs = []
110
  for idx in top_k_indices:
111
  pairs.append([query, valid_docs[idx]['text']])
 
115
  r_scores = rerank_model(**r_inputs, return_dict=True).logits.view(-1).float()
116
  top_3_indices_local = torch.topk(r_scores, k=min(3, len(r_scores))).indices.tolist()
117
 
118
+ # === ÉTAPE 3 : PRÉPARATION IMAGES (ICI ON CORRIGE L'HALLUCINATION) ===
119
  images_content = []
120
  gallery_preview = []
121
  meta_info = ""
 
131
 
132
  try:
133
  img = Image.open(image_path)
134
+
135
+ # --- FIX HALLUCINATION ---
136
+ # On récupère le nom du document (ex: "Microsoft 2023 Report")
137
+ doc_name = doc.get('doc_name', 'Unknown Document')
138
+
139
+ # On l'injecte explicitement dans le texte que voit l'IA
140
+ prompt_header = f"DOCUMENT SOURCE: {doc_name} (Relevance: {score:.2f})\n"
141
+
142
+ images_content.append({"type": "text", "text": prompt_header})
143
  images_content.append({"type": "image", "image": img})
144
 
145
+ gallery_preview.append((img, f"{doc_name} (Rank {rank+1})"))
146
+ meta_info += f"- **{doc_name}** (Score: {score:.2f})\n"
147
  except:
148
  continue
149
 
150
+ # === ÉTAPE 4 : GÉNÉRATION ===
151
+ # On renforce le prompt système pour qu'il fasse attention au nom du document
152
  system_prompt = (
153
+ "You are an expert financial analyst. Answer the user question using ONLY the provided images.\n"
154
+ "IMPORTANT: Before reading a table, check the 'DOCUMENT SOURCE' name above the image.\n"
155
+ "If the user asks about Microsoft, do not use data from an Apple document (and vice versa)."
156
  )
157
 
158
  user_content = images_content + [{"type": "text", "text": f"\nUser Question: {query}"}]
 
180
 
181
  # --- 5. UI ---
182
  with gr.Blocks(title="RAG Finance") as demo:
183
+ gr.Markdown("# 🚀 RAG Finance (Moteurs Originaux + Sécurité Hallucination)")
184
 
185
  with gr.Row():
186
+ query_input = gr.Textbox(label="Question", placeholder="Ex: What is the revenue of Microsoft?")
187
  submit_btn = gr.Button("Analyser", variant="primary")
188
 
189
  with gr.Row():
190
+ output_gallery = gr.Gallery(label="Pages Analysées", columns=3, height=300)
191
+ output_meta = gr.Markdown(label="Sources Identifiées")
192
+ output_text = gr.Markdown(label="Réponse IA")
193
 
194
  submit_btn.click(retrieve_and_answer, inputs=query_input, outputs=[output_gallery, output_meta, output_text])
195