VincentGOURBIN commited on
Commit
3306e7f
·
verified ·
1 Parent(s): 0563c4d

Upload step03_chatbot.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. step03_chatbot.py +27 -23
step03_chatbot.py CHANGED
@@ -9,6 +9,18 @@ import json
9
  import numpy as np
10
  import gradio as gr
11
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  # Désactiver le warning tokenizers sur ZeroGPU
13
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
14
  from gradio import ChatMessage
@@ -156,7 +168,7 @@ class Qwen3Reranker:
156
  Reranker utilisant Qwen3-Reranker-4B pour améliorer la pertinence des résultats de recherche
157
  """
158
 
159
- def __init__(self, model_name: str = "Qwen/Qwen3-Reranker-4B", use_flash_attention: bool = True):
160
  """
161
  Initialise le reranker Qwen3
162
 
@@ -165,7 +177,7 @@ class Qwen3Reranker:
165
  use_flash_attention: Utiliser Flash Attention 2 si disponible (auto-désactivé sur Mac)
166
  """
167
  self.model_name = model_name
168
- self.use_flash_attention = use_flash_attention
169
 
170
  # Détection de l'environnement
171
  self.is_mps = torch.backends.mps.is_available()
@@ -258,15 +270,15 @@ class Qwen3Reranker:
258
  if self.is_mps:
259
  self.device = torch.device("mps")
260
  self.model = self.model.to(self.device)
261
- elif self.is_cuda and not os.getenv("SPACE_ID"):
262
- # Utiliser CUDA seulement si pas sur ZeroGPU Spaces
263
  if hasattr(self.model, 'device'):
264
  self.device = next(self.model.parameters()).device
265
  else:
266
  self.device = torch.device("cuda")
267
  self.model = self.model.to(self.device)
268
  else:
269
- # Forcer CPU sur ZeroGPU pour éviter l'erreur CUDA init
270
  self.device = torch.device("cpu")
271
  self.model = self.model.to(self.device)
272
 
@@ -407,7 +419,7 @@ class GenericRAGChatbot:
407
  generation_model: str = "Qwen/Qwen3-4B-Instruct-2507",
408
  initial_k: int = 20,
409
  final_k: int = 3,
410
- use_flash_attention: bool = True,
411
  use_reranker: bool = True):
412
  """
413
  Initialise le système RAG générique
@@ -422,7 +434,7 @@ class GenericRAGChatbot:
422
  self.generation_model_name = generation_model
423
  self.initial_k = initial_k
424
  self.final_k = final_k
425
- self.use_flash_attention = use_flash_attention
426
  self.use_reranker = use_reranker
427
 
428
  # Détection de l'environnement (local + ZeroGPU)
@@ -571,18 +583,7 @@ class GenericRAGChatbot:
571
  try:
572
  from sentence_transformers import SentenceTransformer
573
 
574
- if os.getenv("SPACE_ID"):
575
- print(" - Configuration ZeroGPU optimisée")
576
- # Sur ZeroGPU, utiliser float16 et device auto pour les performances
577
- self.embedding_model = SentenceTransformer(
578
- self.config.embedding_model,
579
- model_kwargs={
580
- "torch_dtype": torch.float16,
581
- "device_map": "auto"
582
- },
583
- tokenizer_kwargs={"padding_side": "left"}
584
- )
585
- elif self.use_flash_attention and self.is_cuda:
586
  print(" - Configuration avec Flash Attention 2 activée (CUDA)")
587
  try:
588
  self.embedding_model = SentenceTransformer(
@@ -722,7 +723,7 @@ class GenericRAGChatbot:
722
  except:
723
  return 0.0
724
 
725
- @spaces.GPU(duration=120) # ZeroGPU: GPU nécessaire pour embedding
726
  def search_documents(self, query: str, final_k: int = None, use_reranking: bool = None) -> List[Dict]:
727
  """
728
  Recherche avancée avec reranking en deux étapes
@@ -819,7 +820,7 @@ class GenericRAGChatbot:
819
 
820
  return final_results
821
 
822
- @spaces.GPU(duration=120) # ZeroGPU: GPU seulement pour la génération
823
  def generate_response_stream(self, query: str, context: str, history: List = None):
824
  """
825
  Génère une réponse streamée basée sur le contexte et l'historique
@@ -913,6 +914,7 @@ Instructions importantes:
913
  except Exception as e:
914
  yield f"❌ Erreur lors de la génération: {str(e)}"
915
 
 
916
  def generate_response(self, query: str, context: str, history: List = None) -> str:
917
  """
918
  Génère une réponse basée sur le contexte et l'historique
@@ -997,6 +999,7 @@ Réponds à cette question en te basant sur le contexte fourni."""
997
  print(f"❌ Erreur lors de la génération: {e}")
998
  return f"❌ Erreur lors de la génération de la réponse: {str(e)}"
999
 
 
1000
  def stream_response_with_tools(self, query: str, history, top_k: int = None, use_reranking: bool = None):
1001
  """
1002
  Génère une réponse streamée avec affichage visuel des tools et reranking Qwen3
@@ -1156,7 +1159,7 @@ def _create_rag_system():
1156
  if is_zerogpu:
1157
  default_config = {
1158
  'generation_model': "Qwen/Qwen3-4B-Instruct-2507", # Modèle qui fonctionne sur ZeroGPU
1159
- 'use_flash_attention': True, # ZeroGPU supporte Flash Attention
1160
  'use_reranker': True, # GPU puissant, reranking activé
1161
  'initial_k': 20, # Même config que local
1162
  'final_k': 5 # Plus de documents finaux
@@ -1170,7 +1173,7 @@ def _create_rag_system():
1170
  }
1171
  else:
1172
  default_config = {
1173
- 'use_flash_attention': is_cuda, # Flash Attention seulement sur CUDA
1174
  'use_reranker': True, # Reranking par défaut
1175
  'initial_k': 20, # Candidats pour la première étape
1176
  'final_k': 3 # Documents finaux par défaut
@@ -1205,6 +1208,7 @@ def _ensure_chatmessages(history):
1205
  return result
1206
 
1207
 
 
1208
  def chat_with_generic_rag(message, history, top_k, use_reranking):
1209
  """
1210
  Interface entre Gradio et le système RAG générique avec contrôles avancés.
 
9
  import numpy as np
10
  import gradio as gr
11
 
12
+ # Import spaces pour ZeroGPU compatibility
13
+ try:
14
+ import spaces
15
+ except ImportError:
16
+ # Fallback pour environnements non-ZeroGPU
17
+ class spaces:
18
+ @staticmethod
19
+ def GPU(duration=60):
20
+ def decorator(func):
21
+ return func
22
+ return decorator
23
+
24
  # Désactiver le warning tokenizers sur ZeroGPU
25
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
26
  from gradio import ChatMessage
 
168
  Reranker utilisant Qwen3-Reranker-4B pour améliorer la pertinence des résultats de recherche
169
  """
170
 
171
+ def __init__(self, model_name: str = "Qwen/Qwen3-Reranker-4B", use_flash_attention: bool = False):
172
  """
173
  Initialise le reranker Qwen3
174
 
 
177
  use_flash_attention: Utiliser Flash Attention 2 si disponible (auto-désactivé sur Mac)
178
  """
179
  self.model_name = model_name
180
+ self.use_flash_attention = False # Désactivé pour éviter les problèmes
181
 
182
  # Détection de l'environnement
183
  self.is_mps = torch.backends.mps.is_available()
 
270
  if self.is_mps:
271
  self.device = torch.device("mps")
272
  self.model = self.model.to(self.device)
273
+ elif self.is_cuda:
274
+ # Utiliser CUDA si disponible
275
  if hasattr(self.model, 'device'):
276
  self.device = next(self.model.parameters()).device
277
  else:
278
  self.device = torch.device("cuda")
279
  self.model = self.model.to(self.device)
280
  else:
281
+ # Fallback CPU
282
  self.device = torch.device("cpu")
283
  self.model = self.model.to(self.device)
284
 
 
419
  generation_model: str = "Qwen/Qwen3-4B-Instruct-2507",
420
  initial_k: int = 20,
421
  final_k: int = 3,
422
+ use_flash_attention: bool = False,
423
  use_reranker: bool = True):
424
  """
425
  Initialise le système RAG générique
 
434
  self.generation_model_name = generation_model
435
  self.initial_k = initial_k
436
  self.final_k = final_k
437
+ self.use_flash_attention = False # Désactivé pour éviter les problèmes
438
  self.use_reranker = use_reranker
439
 
440
  # Détection de l'environnement (local + ZeroGPU)
 
583
  try:
584
  from sentence_transformers import SentenceTransformer
585
 
586
+ if self.use_flash_attention and self.is_cuda:
 
 
 
 
 
 
 
 
 
 
 
587
  print(" - Configuration avec Flash Attention 2 activée (CUDA)")
588
  try:
589
  self.embedding_model = SentenceTransformer(
 
723
  except:
724
  return 0.0
725
 
726
+ @spaces.GPU(duration=120)
727
  def search_documents(self, query: str, final_k: int = None, use_reranking: bool = None) -> List[Dict]:
728
  """
729
  Recherche avancée avec reranking en deux étapes
 
820
 
821
  return final_results
822
 
823
+ @spaces.GPU(duration=180)
824
  def generate_response_stream(self, query: str, context: str, history: List = None):
825
  """
826
  Génère une réponse streamée basée sur le contexte et l'historique
 
914
  except Exception as e:
915
  yield f"❌ Erreur lors de la génération: {str(e)}"
916
 
917
+ @spaces.GPU(duration=180)
918
  def generate_response(self, query: str, context: str, history: List = None) -> str:
919
  """
920
  Génère une réponse basée sur le contexte et l'historique
 
999
  print(f"❌ Erreur lors de la génération: {e}")
1000
  return f"❌ Erreur lors de la génération de la réponse: {str(e)}"
1001
 
1002
+ @spaces.GPU(duration=300) # Durée plus longue car combine search + generation
1003
  def stream_response_with_tools(self, query: str, history, top_k: int = None, use_reranking: bool = None):
1004
  """
1005
  Génère une réponse streamée avec affichage visuel des tools et reranking Qwen3
 
1159
  if is_zerogpu:
1160
  default_config = {
1161
  'generation_model': "Qwen/Qwen3-4B-Instruct-2507", # Modèle qui fonctionne sur ZeroGPU
1162
+ 'use_flash_attention': False, # Désactivé pour stabilité
1163
  'use_reranker': True, # GPU puissant, reranking activé
1164
  'initial_k': 20, # Même config que local
1165
  'final_k': 5 # Plus de documents finaux
 
1173
  }
1174
  else:
1175
  default_config = {
1176
+ 'use_flash_attention': False, # Désactivé pour stabilité
1177
  'use_reranker': True, # Reranking par défaut
1178
  'initial_k': 20, # Candidats pour la première étape
1179
  'final_k': 3 # Documents finaux par défaut
 
1208
  return result
1209
 
1210
 
1211
+ @spaces.GPU(duration=300) # Fonction principale de chat
1212
  def chat_with_generic_rag(message, history, top_k, use_reranking):
1213
  """
1214
  Interface entre Gradio et le système RAG générique avec contrôles avancés.