ZedLow committed on
Commit
f2863bc
·
verified ·
1 Parent(s): 97ef6d2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +162 -218
app.py CHANGED
@@ -1,240 +1,184 @@
1
- import os
2
- import sys
3
- import json
4
- import time
5
- import logging
6
- import torch
7
  import spaces
 
8
  import gradio as gr
 
9
  import torch.nn.functional as F
10
  from PIL import Image
11
  from transformers import AutoModel, AutoTokenizer, Qwen2VLForConditionalGeneration, AutoProcessor, AutoModelForSequenceClassification
12
  from qwen_vl_utils import process_vision_info
13
 
14
- # --- LOGGING ---
15
- logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
16
- logger = logging.getLogger(__name__)
17
-
18
- # --- CONFIGURATION (MODE STABLE) ---
19
- CONFIG = {
20
- # On reste sur le 1.5B : C'est le SEUL qui ne fait pas crasher ZeroGPU au démarrage
21
- "embedding_model": "Alibaba-NLP/gte-Qwen2-1.5B-instruct",
22
- "rerank_model": "BAAI/bge-reranker-v2-m3",
23
- "vision_model": "Qwen/Qwen2-VL-2B-Instruct",
24
- "data_path": "data/dataset.json",
25
- "allowed_image_dir": "data",
26
- "max_embed_len": 2048,
27
- "max_rerank_len": 512
28
- }
29
-
30
- # --- PATCH ---
31
- def apply_patches():
32
- import transformers
33
- if not hasattr(transformers.PreTrainedModel, "all_tied_weights_keys"):
34
- setattr(transformers.PreTrainedModel, "all_tied_weights_keys", {})
35
-
36
- # --- ENGINE CLASS ---
37
- class FinancialAnalystEngine:
38
- def __init__(self):
39
- logger.info("🏗️ Initializing Engine...")
40
- apply_patches()
41
-
42
- self.dataset = []
43
- self.doc_embeddings = None
44
- self.load_data()
45
-
46
- logger.info("🔹 Loading Models (CPU Mode for Stability)...")
47
-
48
- # 1. Chargement CPU (Vital pour ne pas avoir l'erreur "No CUDA")
49
- self.embed_tokenizer = AutoTokenizer.from_pretrained(CONFIG["embedding_model"], trust_remote_code=False)
50
- self.embed_model = AutoModel.from_pretrained(CONFIG["embedding_model"], trust_remote_code=False, torch_dtype=torch.float16).eval()
51
-
52
- self.rerank_tokenizer = AutoTokenizer.from_pretrained(CONFIG["rerank_model"])
53
- self.rerank_model = AutoModelForSequenceClassification.from_pretrained(CONFIG["rerank_model"], torch_dtype=torch.float16).eval()
54
-
55
- self.vision_processor = AutoProcessor.from_pretrained(CONFIG["vision_model"])
56
- # Pas de flash_attention_2 ici, c'est ça qui causait ton autre crash
57
- self.vision_model = Qwen2VLForConditionalGeneration.from_pretrained(CONFIG["vision_model"], torch_dtype=torch.float16).eval()
58
-
59
- # 2. Indexation immédiate
60
- self.index_documents()
61
-
62
- logger.info("🚀 Engine Ready.")
63
-
64
- def load_data(self):
65
- try:
66
- with open(CONFIG["data_path"], "r", encoding="utf-8") as f:
67
- self.dataset = json.load(f)
68
- logger.info(f"📂 Dataset loaded: {len(self.dataset)} documents.")
69
- except Exception as e:
70
- logger.error(f"❌ Failed to load dataset: {e}")
71
- self.dataset = []
72
-
73
- def validate_image_path(self, path):
74
- clean_path = os.path.abspath(path)
75
- allowed_path = os.path.abspath(CONFIG["allowed_image_dir"])
76
- if not clean_path.startswith(allowed_path):
77
- return None
78
- return clean_path
79
-
80
- def last_token_pool(self, last_hidden_states, attention_mask):
81
- left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
82
- if left_padding:
83
- return last_hidden_states[:, -1]
84
- else:
85
- sequence_lengths = attention_mask.sum(dim=1) - 1
86
- batch_size = last_hidden_states.shape[0]
87
- # Sécurité pour éviter les erreurs de device
88
- sequence_lengths = sequence_lengths.to(last_hidden_states.device)
89
- return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]
90
-
91
- def index_documents(self):
92
- if not self.dataset: return
93
- logger.info("⚙️ Indexing documents...")
94
- texts = [d.get('text', '') for d in self.dataset]
95
- embeddings = []
96
- batch_size = 4
97
-
98
- with torch.no_grad():
99
- for i in range(0, len(texts), batch_size):
100
- batch = texts[i : i + batch_size]
101
- inputs = self.embed_tokenizer(
102
- batch, max_length=CONFIG["max_embed_len"], padding=True, truncation=True, return_tensors="pt"
103
- )
104
- outputs = self.embed_model(**inputs)
105
- emb = self.last_token_pool(outputs.last_hidden_state, inputs['attention_mask'])
106
- emb = F.normalize(emb, p=2, dim=1)
107
- embeddings.append(emb)
108
-
109
- if embeddings:
110
- self.doc_embeddings = torch.cat(embeddings, dim=0)
111
- logger.info(f"✅ Indexing complete. Shape: {self.doc_embeddings.shape}")
112
-
113
- def pipeline(self, query):
114
- start_time = time.time()
115
-
116
- # ZeroGPU active le GPU ici. On vérifie s'il est là.
117
- device = "cuda" if torch.cuda.is_available() else "cpu"
118
-
119
- # Transfert des modèles vers le GPU (Just-in-Time)
120
- self.embed_model.to(device)
121
- self.rerank_model.to(device)
122
- self.vision_model.to(device)
123
- if self.doc_embeddings is not None:
124
- self.doc_embeddings = self.doc_embeddings.to(device)
125
-
126
- if not self.dataset or self.doc_embeddings is None:
127
- return [], "System not initialized.", ""
128
-
129
- # === 1. RETRIEVAL ===
130
- query_prompt = f"Instruct: Given a user query, retrieve relevant passages that answer the query.\nQuery: {query}"
131
-
132
- with torch.no_grad():
133
- q_inputs = self.embed_tokenizer([query_prompt], max_length=CONFIG["max_embed_len"], truncation=True, return_tensors="pt").to(device)
134
- q_out = self.embed_model(**q_inputs)
135
- q_emb = self.last_token_pool(q_out.last_hidden_state, q_inputs['attention_mask'])
136
- q_emb = F.normalize(q_emb, p=2, dim=1)
137
 
138
- scores = (q_emb @ self.doc_embeddings.T).squeeze(0)
139
- top_k_indices = torch.topk(scores, k=min(10, len(scores))).indices.tolist()
140
-
141
- # === 2. RERANKING ===
142
- pairs = [[query, self.dataset[idx]['text']] for idx in top_k_indices]
 
 
143
 
144
- with torch.no_grad():
145
- r_inputs = self.rerank_tokenizer(pairs, padding=True, truncation=True, max_length=CONFIG["max_rerank_len"], return_tensors="pt").to(device)
146
- r_scores = self.rerank_model(**r_inputs, return_dict=True).logits.view(-1).float()
147
- top_3_indices_local = torch.topk(r_scores, k=min(3, len(r_scores))).indices.tolist()
148
-
149
- # === 3. CONTEXT & IMAGES (LE FIX ANTI-HALLUCINATION) ===
150
- images_content = []
151
- gallery_data = []
152
- sources_md = "### 📚 Verified Sources\n\n"
153
 
154
- for rank, idx_local in enumerate(top_3_indices_local):
155
- global_idx = top_k_indices[idx_local]
156
- doc = self.dataset[global_idx]
157
- score = r_scores[idx_local].item()
158
-
159
- valid_path = self.validate_image_path(doc['image_path'])
160
- if not valid_path: continue
161
 
162
- try:
163
- img = Image.open(valid_path)
164
-
165
- # --- LE FIX EST ICI ---
166
- # On écrit en GROS le nom du document pour l'IA
167
- doc_name = doc.get('doc_name', 'Unknown Document')
168
- doc_section = doc.get('section', 'Unknown Section')
169
-
170
- context_header = (
171
- f"\n--- DOCUMENT {rank+1} METADATA ---\n"
172
- f"FILE NAME: {doc_name}\n" # Ex: Microsoft_2023_Report
173
- f"SECTION: {doc_section}\n"
174
- f"RELEVANCE: {score:.2f}\n"
175
- "---------------------------\n"
176
- )
177
-
178
- images_content.append({"type": "text", "text": context_header})
179
- images_content.append({"type": "image", "image": img})
180
-
181
- gallery_data.append((img, f"{doc_name}"))
182
- sources_md += f"**{rank+1}. {doc_name}** - *{doc_section}* (Score: {score:.2f})\n"
183
- except Exception as e:
184
- logger.error(f"Image load error: {e}")
185
- continue
186
-
187
- # === 4. GENERATION ===
188
- # Prompt Strict pour forcer la lecture du header
189
- system_prompt = (
190
- "You are a strict financial data extraction engine. "
191
- "Analyze the provided images to answer the user query.\n"
192
- "CRITICAL RULES:\n"
193
- "1. Read the 'DOCUMENT METADATA' provided before each image.\n"
194
- "2. If the user asks about 'Microsoft', ONLY use images labeled as Microsoft/MSFT.\n"
195
- "3. If the user asks about 'Apple', ONLY use images labeled as Apple/AAPL.\n"
196
- "4. Do not mix data between companies.\n"
197
- "Output format:\n- **Answer**: [Direct Answer]\n- **Evidence**: [Quote]\n- **Context**: [Year/Company]"
198
- )
199
 
200
- messages = [
201
- {"role": "system", "content": [{"type": "text", "text": system_prompt}]},
202
- {"role": "user", "content": images_content + [{"type": "text", "text": f"Query: {query}"}]}
203
- ]
 
 
 
 
 
 
 
 
 
 
204
 
205
- text_input = self.vision_processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
206
- inputs = self.vision_processor(text=[text_input], images=process_vision_info(messages)[0], padding=True, return_tensors="pt").to(device)
 
207
 
208
- generated_ids = self.vision_model.generate(**inputs, max_new_tokens=512, temperature=0.1)
209
- generated_ids_trimmed = [out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]
210
- response = self.vision_processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
211
-
212
- logger.info(f"⏱️ Total Latency: {time.time() - start_time:.2f}s")
213
- return gallery_data, sources_md, response
214
-
215
- # --- INSTANTIATION ---
216
- engine = FinancialAnalystEngine()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
217
 
218
- # --- UI ---
219
- @spaces.GPU(duration=60)
220
- def run_query(query):
221
- return engine.pipeline(query)
222
 
223
- with gr.Blocks(theme=gr.themes.Soft()) as demo:
224
- gr.Markdown("# ⚡ AI Financial Analyst (Safe Mode)")
 
225
 
226
  with gr.Row():
227
- inp = gr.Textbox(label="Question", placeholder="Ex: What is the Operating Income for Microsoft?", scale=4)
228
- btn = gr.Button("Analyze", variant="primary", scale=1)
229
-
230
  with gr.Row():
231
- with gr.Column(scale=2):
232
- out_gallery = gr.Gallery(label="Documents", columns=3, height=400)
233
- with gr.Column(scale=1):
234
- out_meta = gr.Markdown(label="Sources")
235
- out_resp = gr.Markdown(label="Answer")
236
-
237
- btn.click(run_query, inp, [out_gallery, out_meta, out_resp])
238
 
239
  if __name__ == "__main__":
240
  demo.launch()
 
 
 
 
 
 
 
1
  import spaces
2
+ import torch
3
  import gradio as gr
4
+ import json
5
  import torch.nn.functional as F
6
  from PIL import Image
7
  from transformers import AutoModel, AutoTokenizer, Qwen2VLForConditionalGeneration, AutoProcessor, AutoModelForSequenceClassification
8
  from qwen_vl_utils import process_vision_info
9
 
10
# --- CONFIGURATION ---
print("🚀 Démarrage RAG Finance (Mode Multi-View : 3 Images)...")

# --- 1. DATA ---
# Load the pre-built page index. An unreadable or missing file degrades to an
# empty dataset instead of crashing the Space at import time.
try:
    with open("data/dataset.json", "r", encoding="utf-8") as f:
        dataset = json.load(f)
except (OSError, json.JSONDecodeError):  # was a bare `except:` — too broad
    dataset = []
    print("⚠️ Index vide.")

# --- 2. MODELS ---
# A. EMBEDDING: GTE-Qwen2-7B (the heavy model that previously caused memory crashes)
EMBED_MODEL_ID = "Alibaba-NLP/gte-Qwen2-7B-instruct"
print(f"🔹 Chargement Embedder : {EMBED_MODEL_ID}")

embed_tokenizer = AutoTokenizer.from_pretrained(EMBED_MODEL_ID, trust_remote_code=False)
embed_model = AutoModel.from_pretrained(
    EMBED_MODEL_ID,
    trust_remote_code=False,
    torch_dtype=torch.bfloat16,
    # NOTE(review): flash_attention_2 fails at load time when no GPU is
    # visible yet (ZeroGPU attaches the GPU lazily) — confirm before keeping.
    attn_implementation="flash_attention_2",
    device_map="auto",
).eval()  # inference only: disable dropout / training-mode layers

# B. RERANKER
RERANK_MODEL_ID = "BAAI/bge-reranker-v2-m3"
print(f"⚖️ Chargement Reranker : {RERANK_MODEL_ID}")
rerank_tokenizer = AutoTokenizer.from_pretrained(RERANK_MODEL_ID)
rerank_model = AutoModelForSequenceClassification.from_pretrained(
    RERANK_MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map="auto",
).eval()

# C. VISION
GEN_MODEL_ID = "Qwen/Qwen2-VL-2B-Instruct"
print(f"👁️ Chargement Vision : {GEN_MODEL_ID}")
gen_model = Qwen2VLForConditionalGeneration.from_pretrained(
    GEN_MODEL_ID,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    device_map="auto",
).eval()
gen_processor = AutoProcessor.from_pretrained(GEN_MODEL_ID)
56
+
57
# --- 3. HELPERS ---
def last_token_pool(last_hidden_states, attention_mask):
    """Reduce a batch of hidden states to one vector per sequence.

    Returns the hidden state of each sequence's last *attended* token — the
    pooling scheme the GTE-Qwen2 embedding models expect.

    last_hidden_states: float tensor of shape (batch, seq, hidden).
    attention_mask: 0/1 tensor of shape (batch, seq).
    """
    # With left padding, every sequence's final position is attended, so the
    # mask's last column sums to the batch size and we can slice directly.
    if attention_mask[:, -1].sum() == attention_mask.shape[0]:
        return last_hidden_states[:, -1]
    # Right padding: gather each row at its own last attended index.
    lengths = attention_mask.sum(dim=1) - 1
    rows = torch.arange(last_hidden_states.shape[0], device=last_hidden_states.device)
    return last_hidden_states[rows, lengths]
66
+
67
# --- 4. PIPELINE ---
@spaces.GPU
def retrieve_and_answer(query):
    """Full RAG pass: embed query -> retrieve top-10 -> rerank top-3 -> VLM answer.

    Returns (gallery_preview, meta_info_markdown, answer_text) matching the
    three Gradio outputs. Runs on GPU via the ZeroGPU decorator.
    """
    print(f"⚡ Question : {query}")

    if not dataset:
        return None, "Base vide", "Pas de document"

    # 1. RETRIEVAL — keep only documents with non-empty text, remembering
    # their position in `dataset` so we can recover image metadata later.
    # NOTE(review): document embeddings are recomputed on EVERY query; caching
    # them at startup is the single biggest latency win available here.
    valid_docs = [
        {'text': text, 'original_index': i}
        for i, doc in enumerate(dataset)
        if (text := doc.get('text', '').strip())
    ]

    query_text = f"Instruct: Given a user query, retrieve relevant passages that answer the query.\nQuery: {query}"

    with torch.no_grad():
        q_inputs = embed_tokenizer([query_text], max_length=8192, padding=True,
                                   truncation=True, return_tensors='pt').to(embed_model.device)
        q_emb = last_token_pool(embed_model(**q_inputs).last_hidden_state,
                                q_inputs['attention_mask'])
        q_emb = F.normalize(q_emb, p=2, dim=1)

        # Embed documents one at a time (batch size 1) to bound peak memory on
        # the 7B embedder. Explicitly inside no_grad so activations are freed.
        doc_texts = [d['text'] for d in valid_docs]
        d_embeddings_list = []
        for i in range(0, len(doc_texts), 1):
            d_inputs = embed_tokenizer(doc_texts[i:i + 1], max_length=8192, padding=True,
                                       truncation=True, return_tensors='pt').to(embed_model.device)
            batch_emb = last_token_pool(embed_model(**d_inputs).last_hidden_state,
                                        d_inputs['attention_mask'])
            d_embeddings_list.append(F.normalize(batch_emb, p=2, dim=1))

    d_emb_final = torch.cat(d_embeddings_list, dim=0)
    # Cosine similarity (both sides are L2-normalized), then top-10 candidates.
    scores = (q_emb @ d_emb_final.T).squeeze(0)
    top_k_indices = torch.topk(scores, k=min(10, len(scores))).indices.tolist()

    # 2. RERANKING — cross-encoder rescoring of the retrieval candidates.
    pairs = [[query, valid_docs[idx]['text']] for idx in top_k_indices]

    with torch.no_grad():
        r_inputs = rerank_tokenizer(pairs, padding=True, truncation=True,
                                    return_tensors='pt', max_length=8192).to(rerank_model.device)
        r_scores = rerank_model(**r_inputs, return_dict=True).logits.view(-1).float()
    top_3_indices_local = torch.topk(r_scores, k=min(3, len(r_scores))).indices.tolist()

    # 3. IMAGE CONTEXT
    images_content = []
    gallery_preview = []
    meta_info = ""

    for rank, idx_local in enumerate(top_3_indices_local):
        # idx_local indexes the reranked top-10; map back through top_k_indices
        # (position in valid_docs) to the document's index in `dataset`.
        doc = dataset[valid_docs[top_k_indices[idx_local]]['original_index']]
        score = r_scores[idx_local].item()

        try:
            img = Image.open(doc['image_path'])
        except (OSError, KeyError) as e:
            # Was a bare `except: continue` that hid all errors; the broad
            # catch also masked a missing 'doc_name' AFTER the image had been
            # appended, leaving images_content/gallery in inconsistent state.
            print(f"⚠️ Image skipped: {e}")
            continue

        # KNOWN LIMITATION (hallucination source): the VLM is only told
        # "Image N", never which company/document the page belongs to.
        images_content.append({"type": "text", "text": f"Image {rank+1} (Pertinence: {score:.2f}):\n"})
        images_content.append({"type": "image", "image": img})
        gallery_preview.append((img, f"Page {rank+1} - Score {score:.2f}"))
        meta_info += f"- **Image {rank+1}:** {doc.get('doc_name', 'Unknown Document')} (Score: {score:.2f})\n"

    # 4. GENERATION
    system_prompt = (
        "You are an expert financial analyst examining 3 pages of a report. "
        "Your goal is to answer the user question using ONLY the provided images."
    )

    user_content = images_content + [{"type": "text", "text": f"\nUser Question: {query}"}]

    messages = [
        {"role": "system", "content": [{"type": "text", "text": system_prompt}]},
        {"role": "user", "content": user_content}
    ]

    text_input = gen_processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs = process_vision_info(messages)

    inputs = gen_processor(
        text=[text_input],
        images=image_inputs,
        padding=True,
        return_tensors="pt",
    ).to(gen_model.device)

    with torch.no_grad():  # generation needs no gradients either
        generated_ids = gen_model.generate(**inputs, max_new_tokens=768)
    # Strip the prompt tokens so only the newly generated answer is decoded.
    generated_ids_trimmed = [out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]
    response = gen_processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )[0]

    return gallery_preview, meta_info, response
 
 
 
167
 
168
# --- 5. UI ---
# Single-page Gradio app: one question box, one button, three result panes.
with gr.Blocks(title="RAG Finance") as demo:
    gr.Markdown("# 🚀 RAG Finance (Version Originale Instable)")

    # Input row: free-text question plus the trigger button.
    with gr.Row():
        question_box = gr.Textbox(label="Question")
        analyze_btn = gr.Button("Analyser", variant="primary")

    # Output row: retrieved page images, their source list, and the answer.
    with gr.Row():
        pages_gallery = gr.Gallery(label="Pages")
        sources_panel = gr.Markdown(label="Sources")
        answer_panel = gr.Markdown(label="Réponse")

    analyze_btn.click(
        retrieve_and_answer,
        inputs=question_box,
        outputs=[pages_gallery, sources_panel, answer_panel],
    )


if __name__ == "__main__":
    demo.launch()