Spaces:
Running
on
Zero
Running
on
Zero
Upload step03_chatbot.py with huggingface_hub
Browse files- step03_chatbot.py +27 -23
step03_chatbot.py
CHANGED
|
@@ -9,6 +9,18 @@ import json
|
|
| 9 |
import numpy as np
|
| 10 |
import gradio as gr
|
| 11 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
# Désactiver le warning tokenizers sur ZeroGPU
|
| 13 |
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
| 14 |
from gradio import ChatMessage
|
|
@@ -156,7 +168,7 @@ class Qwen3Reranker:
|
|
| 156 |
Reranker utilisant Qwen3-Reranker-4B pour améliorer la pertinence des résultats de recherche
|
| 157 |
"""
|
| 158 |
|
| 159 |
-
def __init__(self, model_name: str = "Qwen/Qwen3-Reranker-4B", use_flash_attention: bool =
|
| 160 |
"""
|
| 161 |
Initialise le reranker Qwen3
|
| 162 |
|
|
@@ -165,7 +177,7 @@ class Qwen3Reranker:
|
|
| 165 |
use_flash_attention: Utiliser Flash Attention 2 si disponible (auto-désactivé sur Mac)
|
| 166 |
"""
|
| 167 |
self.model_name = model_name
|
| 168 |
-
self.use_flash_attention =
|
| 169 |
|
| 170 |
# Détection de l'environnement
|
| 171 |
self.is_mps = torch.backends.mps.is_available()
|
|
@@ -258,15 +270,15 @@ class Qwen3Reranker:
|
|
| 258 |
if self.is_mps:
|
| 259 |
self.device = torch.device("mps")
|
| 260 |
self.model = self.model.to(self.device)
|
| 261 |
-
elif self.is_cuda
|
| 262 |
-
# Utiliser CUDA
|
| 263 |
if hasattr(self.model, 'device'):
|
| 264 |
self.device = next(self.model.parameters()).device
|
| 265 |
else:
|
| 266 |
self.device = torch.device("cuda")
|
| 267 |
self.model = self.model.to(self.device)
|
| 268 |
else:
|
| 269 |
-
#
|
| 270 |
self.device = torch.device("cpu")
|
| 271 |
self.model = self.model.to(self.device)
|
| 272 |
|
|
@@ -407,7 +419,7 @@ class GenericRAGChatbot:
|
|
| 407 |
generation_model: str = "Qwen/Qwen3-4B-Instruct-2507",
|
| 408 |
initial_k: int = 20,
|
| 409 |
final_k: int = 3,
|
| 410 |
-
use_flash_attention: bool =
|
| 411 |
use_reranker: bool = True):
|
| 412 |
"""
|
| 413 |
Initialise le système RAG générique
|
|
@@ -422,7 +434,7 @@ class GenericRAGChatbot:
|
|
| 422 |
self.generation_model_name = generation_model
|
| 423 |
self.initial_k = initial_k
|
| 424 |
self.final_k = final_k
|
| 425 |
-
self.use_flash_attention =
|
| 426 |
self.use_reranker = use_reranker
|
| 427 |
|
| 428 |
# Détection de l'environnement (local + ZeroGPU)
|
|
@@ -571,18 +583,7 @@ class GenericRAGChatbot:
|
|
| 571 |
try:
|
| 572 |
from sentence_transformers import SentenceTransformer
|
| 573 |
|
| 574 |
-
if
|
| 575 |
-
print(" - Configuration ZeroGPU optimisée")
|
| 576 |
-
# Sur ZeroGPU, utiliser float16 et device auto pour les performances
|
| 577 |
-
self.embedding_model = SentenceTransformer(
|
| 578 |
-
self.config.embedding_model,
|
| 579 |
-
model_kwargs={
|
| 580 |
-
"torch_dtype": torch.float16,
|
| 581 |
-
"device_map": "auto"
|
| 582 |
-
},
|
| 583 |
-
tokenizer_kwargs={"padding_side": "left"}
|
| 584 |
-
)
|
| 585 |
-
elif self.use_flash_attention and self.is_cuda:
|
| 586 |
print(" - Configuration avec Flash Attention 2 activée (CUDA)")
|
| 587 |
try:
|
| 588 |
self.embedding_model = SentenceTransformer(
|
|
@@ -722,7 +723,7 @@ class GenericRAGChatbot:
|
|
| 722 |
except:
|
| 723 |
return 0.0
|
| 724 |
|
| 725 |
-
@spaces.GPU(duration=120)
|
| 726 |
def search_documents(self, query: str, final_k: int = None, use_reranking: bool = None) -> List[Dict]:
|
| 727 |
"""
|
| 728 |
Recherche avancée avec reranking en deux étapes
|
|
@@ -819,7 +820,7 @@ class GenericRAGChatbot:
|
|
| 819 |
|
| 820 |
return final_results
|
| 821 |
|
| 822 |
-
@spaces.GPU(duration=
|
| 823 |
def generate_response_stream(self, query: str, context: str, history: List = None):
|
| 824 |
"""
|
| 825 |
Génère une réponse streamée basée sur le contexte et l'historique
|
|
@@ -913,6 +914,7 @@ Instructions importantes:
|
|
| 913 |
except Exception as e:
|
| 914 |
yield f"❌ Erreur lors de la génération: {str(e)}"
|
| 915 |
|
|
|
|
| 916 |
def generate_response(self, query: str, context: str, history: List = None) -> str:
|
| 917 |
"""
|
| 918 |
Génère une réponse basée sur le contexte et l'historique
|
|
@@ -997,6 +999,7 @@ Réponds à cette question en te basant sur le contexte fourni."""
|
|
| 997 |
print(f"❌ Erreur lors de la génération: {e}")
|
| 998 |
return f"❌ Erreur lors de la génération de la réponse: {str(e)}"
|
| 999 |
|
|
|
|
| 1000 |
def stream_response_with_tools(self, query: str, history, top_k: int = None, use_reranking: bool = None):
|
| 1001 |
"""
|
| 1002 |
Génère une réponse streamée avec affichage visuel des tools et reranking Qwen3
|
|
@@ -1156,7 +1159,7 @@ def _create_rag_system():
|
|
| 1156 |
if is_zerogpu:
|
| 1157 |
default_config = {
|
| 1158 |
'generation_model': "Qwen/Qwen3-4B-Instruct-2507", # Modèle qui fonctionne sur ZeroGPU
|
| 1159 |
-
'use_flash_attention':
|
| 1160 |
'use_reranker': True, # GPU puissant, reranking activé
|
| 1161 |
'initial_k': 20, # Même config que local
|
| 1162 |
'final_k': 5 # Plus de documents finaux
|
|
@@ -1170,7 +1173,7 @@ def _create_rag_system():
|
|
| 1170 |
}
|
| 1171 |
else:
|
| 1172 |
default_config = {
|
| 1173 |
-
'use_flash_attention':
|
| 1174 |
'use_reranker': True, # Reranking par défaut
|
| 1175 |
'initial_k': 20, # Candidats pour la première étape
|
| 1176 |
'final_k': 3 # Documents finaux par défaut
|
|
@@ -1205,6 +1208,7 @@ def _ensure_chatmessages(history):
|
|
| 1205 |
return result
|
| 1206 |
|
| 1207 |
|
|
|
|
| 1208 |
def chat_with_generic_rag(message, history, top_k, use_reranking):
|
| 1209 |
"""
|
| 1210 |
Interface entre Gradio et le système RAG générique avec contrôles avancés.
|
|
|
|
| 9 |
import numpy as np
|
| 10 |
import gradio as gr
|
| 11 |
|
| 12 |
+
# Import spaces pour ZeroGPU compatibility
|
| 13 |
+
try:
|
| 14 |
+
import spaces
|
| 15 |
+
except ImportError:
|
| 16 |
+
# Fallback pour environnements non-ZeroGPU
|
| 17 |
+
class spaces:
|
| 18 |
+
@staticmethod
|
| 19 |
+
def GPU(duration=60):
|
| 20 |
+
def decorator(func):
|
| 21 |
+
return func
|
| 22 |
+
return decorator
|
| 23 |
+
|
| 24 |
# Désactiver le warning tokenizers sur ZeroGPU
|
| 25 |
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
| 26 |
from gradio import ChatMessage
|
|
|
|
| 168 |
Reranker utilisant Qwen3-Reranker-4B pour améliorer la pertinence des résultats de recherche
|
| 169 |
"""
|
| 170 |
|
| 171 |
+
def __init__(self, model_name: str = "Qwen/Qwen3-Reranker-4B", use_flash_attention: bool = False):
|
| 172 |
"""
|
| 173 |
Initialise le reranker Qwen3
|
| 174 |
|
|
|
|
| 177 |
use_flash_attention: Utiliser Flash Attention 2 si disponible (auto-désactivé sur Mac)
|
| 178 |
"""
|
| 179 |
self.model_name = model_name
|
| 180 |
+
self.use_flash_attention = False # Désactivé pour éviter les problèmes
|
| 181 |
|
| 182 |
# Détection de l'environnement
|
| 183 |
self.is_mps = torch.backends.mps.is_available()
|
|
|
|
| 270 |
if self.is_mps:
|
| 271 |
self.device = torch.device("mps")
|
| 272 |
self.model = self.model.to(self.device)
|
| 273 |
+
elif self.is_cuda:
|
| 274 |
+
# Utiliser CUDA si disponible
|
| 275 |
if hasattr(self.model, 'device'):
|
| 276 |
self.device = next(self.model.parameters()).device
|
| 277 |
else:
|
| 278 |
self.device = torch.device("cuda")
|
| 279 |
self.model = self.model.to(self.device)
|
| 280 |
else:
|
| 281 |
+
# Fallback CPU
|
| 282 |
self.device = torch.device("cpu")
|
| 283 |
self.model = self.model.to(self.device)
|
| 284 |
|
|
|
|
| 419 |
generation_model: str = "Qwen/Qwen3-4B-Instruct-2507",
|
| 420 |
initial_k: int = 20,
|
| 421 |
final_k: int = 3,
|
| 422 |
+
use_flash_attention: bool = False,
|
| 423 |
use_reranker: bool = True):
|
| 424 |
"""
|
| 425 |
Initialise le système RAG générique
|
|
|
|
| 434 |
self.generation_model_name = generation_model
|
| 435 |
self.initial_k = initial_k
|
| 436 |
self.final_k = final_k
|
| 437 |
+
self.use_flash_attention = False # Désactivé pour éviter les problèmes
|
| 438 |
self.use_reranker = use_reranker
|
| 439 |
|
| 440 |
# Détection de l'environnement (local + ZeroGPU)
|
|
|
|
| 583 |
try:
|
| 584 |
from sentence_transformers import SentenceTransformer
|
| 585 |
|
| 586 |
+
if self.use_flash_attention and self.is_cuda:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 587 |
print(" - Configuration avec Flash Attention 2 activée (CUDA)")
|
| 588 |
try:
|
| 589 |
self.embedding_model = SentenceTransformer(
|
|
|
|
| 723 |
except:
|
| 724 |
return 0.0
|
| 725 |
|
| 726 |
+
@spaces.GPU(duration=120)
|
| 727 |
def search_documents(self, query: str, final_k: int = None, use_reranking: bool = None) -> List[Dict]:
|
| 728 |
"""
|
| 729 |
Recherche avancée avec reranking en deux étapes
|
|
|
|
| 820 |
|
| 821 |
return final_results
|
| 822 |
|
| 823 |
+
@spaces.GPU(duration=180)
|
| 824 |
def generate_response_stream(self, query: str, context: str, history: List = None):
|
| 825 |
"""
|
| 826 |
Génère une réponse streamée basée sur le contexte et l'historique
|
|
|
|
| 914 |
except Exception as e:
|
| 915 |
yield f"❌ Erreur lors de la génération: {str(e)}"
|
| 916 |
|
| 917 |
+
@spaces.GPU(duration=180)
|
| 918 |
def generate_response(self, query: str, context: str, history: List = None) -> str:
|
| 919 |
"""
|
| 920 |
Génère une réponse basée sur le contexte et l'historique
|
|
|
|
| 999 |
print(f"❌ Erreur lors de la génération: {e}")
|
| 1000 |
return f"❌ Erreur lors de la génération de la réponse: {str(e)}"
|
| 1001 |
|
| 1002 |
+
@spaces.GPU(duration=300) # Durée plus longue car combine search + generation
|
| 1003 |
def stream_response_with_tools(self, query: str, history, top_k: int = None, use_reranking: bool = None):
|
| 1004 |
"""
|
| 1005 |
Génère une réponse streamée avec affichage visuel des tools et reranking Qwen3
|
|
|
|
| 1159 |
if is_zerogpu:
|
| 1160 |
default_config = {
|
| 1161 |
'generation_model': "Qwen/Qwen3-4B-Instruct-2507", # Modèle qui fonctionne sur ZeroGPU
|
| 1162 |
+
'use_flash_attention': False, # Désactivé pour stabilité
|
| 1163 |
'use_reranker': True, # GPU puissant, reranking activé
|
| 1164 |
'initial_k': 20, # Même config que local
|
| 1165 |
'final_k': 5 # Plus de documents finaux
|
|
|
|
| 1173 |
}
|
| 1174 |
else:
|
| 1175 |
default_config = {
|
| 1176 |
+
'use_flash_attention': False, # Désactivé pour stabilité
|
| 1177 |
'use_reranker': True, # Reranking par défaut
|
| 1178 |
'initial_k': 20, # Candidats pour la première étape
|
| 1179 |
'final_k': 3 # Documents finaux par défaut
|
|
|
|
| 1208 |
return result
|
| 1209 |
|
| 1210 |
|
| 1211 |
+
@spaces.GPU(duration=300) # Fonction principale de chat
|
| 1212 |
def chat_with_generic_rag(message, history, top_k, use_reranking):
|
| 1213 |
"""
|
| 1214 |
Interface entre Gradio et le système RAG générique avec contrôles avancés.
|