Spaces: Running on Zero

Upload step03_chatbot.py with huggingface_hub

step03_chatbot.py  CHANGED  (+43 -40)
@@ -571,7 +571,18 @@ class GenericRAGChatbot:
         try:
             from sentence_transformers import SentenceTransformer
 
-            if self.use_flash_attention and self.is_cuda:
+            if os.getenv("SPACE_ID"):
+                print(" - Optimized ZeroGPU configuration")
+                # On ZeroGPU, use float16 and automatic device placement for performance
+                self.embedding_model = SentenceTransformer(
+                    self.config.embedding_model,
+                    model_kwargs={
+                        "torch_dtype": torch.float16,
+                        "device_map": "auto"
+                    },
+                    tokenizer_kwargs={"padding_side": "left"}
+                )
+            elif self.use_flash_attention and self.is_cuda:
                 print(" - Configuration with Flash Attention 2 enabled (CUDA)")
                 try:
                     self.embedding_model = SentenceTransformer(
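The new ZeroGPU branch keys off the SPACE_ID environment variable that the Spaces runtime sets. Below is a minimal, self-contained sketch of the same loading pattern; the model name is only a placeholder (not the one configured in this repo), and whether float16 actually pays off depends on the hardware the Space is granted.

import os
import torch
from sentence_transformers import SentenceTransformer

def load_embedding_model(model_name: str = "BAAI/bge-m3") -> SentenceTransformer:
    # SPACE_ID is set by the Hugging Face Spaces runtime
    if os.getenv("SPACE_ID"):
        # ZeroGPU-style loading: half precision, automatic device placement,
        # left padding, as in the diff above
        return SentenceTransformer(
            model_name,
            model_kwargs={"torch_dtype": torch.float16, "device_map": "auto"},
            tokenizer_kwargs={"padding_side": "left"},
        )
    # Default loading elsewhere (CPU or a regular GPU)
    return SentenceTransformer(model_name)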
@@ -711,6 +722,7 @@ class GenericRAGChatbot:
         except:
             return 0.0
 
+    @spaces.GPU(duration=120)  # ZeroGPU: a GPU is needed for the embedding step
     def search_documents(self, query: str, final_k: int = None, use_reranking: bool = None) -> List[Dict]:
         """
         Advanced search with two-stage reranking
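On ZeroGPU Spaces, the spaces.GPU decorator is what actually attaches a GPU to a call: the decorated function runs with CUDA available for at most `duration` seconds and the device is released afterwards. A stand-alone sketch of the pattern (the function and its arguments are illustrative, not from this file):

import spaces
import torch

@spaces.GPU(duration=120)  # GPU is allocated only while this call runs
def embed_batch(model, texts):
    # Inside the decorated call, torch.cuda.is_available() is True on ZeroGPU hardware
    with torch.no_grad():
        return model.encode(texts, show_progress_bar=False)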
@@ -724,10 +736,13 @@
         # Embedding models work fine on CPU on ZeroGPU
 
         # Step 1: embedding search with FAISS
+        print(" 🎯 Computing the query embedding...")
         if hasattr(self.embedding_model, 'prompts') and 'query' in self.embedding_model.prompts:
-            query_embedding = self.embedding_model.encode([query], prompt_name="query")[0]
+            query_embedding = self.embedding_model.encode([query], prompt_name="query", show_progress_bar=False)[0]
         else:
-            query_embedding = self.embedding_model.encode([query])[0]
+            query_embedding = self.embedding_model.encode([query], show_progress_bar=False)[0]
+
+        print(f" 📐 Embedding computed: shape={query_embedding.shape}, norm={np.linalg.norm(query_embedding):.3f}")
 
         # Search the FAISS index
         query_vector = query_embedding.reshape(1, -1).astype('float32')
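For reference, the FAISS lookup that consumes query_vector a few lines further down typically looks like the sketch below; `index` is assumed to be a prebuilt FAISS index over the document embeddings (e.g. IndexFlatIP, so scores are inner products), which is an assumption rather than something shown in this diff.

import faiss
import numpy as np

def faiss_top_k(index: faiss.Index, query_embedding: np.ndarray, top_k: int = 20):
    # FAISS expects a 2-D float32 array of shape (n_queries, dim)
    query_vector = query_embedding.reshape(1, -1).astype("float32")
    scores, ids = index.search(query_vector, top_k)  # returns (scores, indices)
    return list(zip(ids[0].tolist(), scores[0].tolist()))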
@@ -842,27 +857,19 @@ Important instructions:
         messages.append({"role": "user", "content": user_message})
 
         try:
-            # Manually format the messages with the <|im_start|> tokens
-            formatted_messages = []
-            for msg in messages:
-                if msg["role"] == "system":
-                    formatted_messages.append(f"<|im_start|>system\n{msg['content']}<|im_end|>")
-                elif msg["role"] == "user":
-                    formatted_messages.append(f"<|im_start|>user\n{msg['content']}<|im_end|>")
-                elif msg["role"] == "assistant":
-                    formatted_messages.append(f"<|im_start|>assistant\n{msg['content']}<|im_end|>")
-
-            # Add the generation prompt
-            formatted_messages.append("<|im_start|>assistant\n")
-            formatted_prompt = "\n".join(formatted_messages)
+            # Use the official Qwen3 chat template (per the official documentation)
+            formatted_prompt = self.generation_tokenizer.apply_chat_template(
+                messages,
+                tokenize=False,
+                add_generation_prompt=True
+            )
 
             # Tokenization
             inputs = self.generation_tokenizer(
                 formatted_prompt,
                 return_tensors="pt",
                 truncation=True,
-                max_length=4096,
-                padding=True
+                max_length=4096
             )
 
             # Move to the device
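The replacement relies on the tokenizer's built-in chat template instead of hand-rolled <|im_start|> strings. Roughly, the call behaves like the sketch below; the checkpoint name is illustrative, and the template shipped with the Qwen3 tokenizer is what produces the ChatML-style turns and the trailing open assistant turn.

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-4B")  # illustrative checkpoint
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello!"},
]
prompt = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,  # leaves an open "<|im_start|>assistant" turn
)
print(prompt)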
@@ -883,8 +890,10 @@ Important instructions:
                 "input_ids": inputs["input_ids"],
                 "attention_mask": inputs["attention_mask"],
                 "streamer": streamer,
-                "max_new_tokens":
-                "temperature": 0.7,
+                "max_new_tokens": 1024,  # Official recommendation
+                "temperature": 0.7,  # Official recommendation
+                "top_p": 0.8,  # Official recommendation
+                "top_k": 20,  # Official recommendation
                 "do_sample": True,
                 "pad_token_id": self.generation_tokenizer.pad_token_id,
                 "eos_token_id": self.generation_tokenizer.eos_token_id,
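The added sampling values (temperature 0.7, top_p 0.8, top_k 20) match what the Qwen3 model card suggests for non-thinking-mode generation. One way to keep the streaming path here and the non-streaming path below in sync would be a shared dict, sketched below; the dict name is made up for illustration and is not part of this commit.

# Shared sampling settings, mirroring the values added in this diff; they can
# be unpacked into both generate() calls with **QWEN3_SAMPLING.
QWEN3_SAMPLING = {
    "max_new_tokens": 1024,
    "temperature": 0.7,
    "top_p": 0.8,
    "top_k": 20,
    "do_sample": True,
}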
@@ -943,39 +952,33 @@ Answer this question based on the provided context."""
 
         # Format for the model
         try:
-            # Manually format the messages with the <|im_start|> tokens
-            formatted_messages = []
-            for msg in messages:
-                if msg["role"] == "system":
-                    formatted_messages.append(f"<|im_start|>system\n{msg['content']}<|im_end|>")
-                elif msg["role"] == "user":
-                    formatted_messages.append(f"<|im_start|>user\n{msg['content']}<|im_end|>")
-                elif msg["role"] == "assistant":
-                    formatted_messages.append(f"<|im_start|>assistant\n{msg['content']}<|im_end|>")
-
-            # Add the generation prompt
-            formatted_messages.append("<|im_start|>assistant\n")
-            formatted_prompt = "\n".join(formatted_messages)
-
-            # Tokenization with appropriate padding and attention mask
+            # Use the official Qwen3 chat template (per the official documentation)
+            formatted_prompt = self.generation_tokenizer.apply_chat_template(
+                messages,
+                tokenize=False,
+                add_generation_prompt=True
+            )
+
+            # Tokenization with the right options
             inputs = self.generation_tokenizer(
                 formatted_prompt,
                 return_tensors="pt",
                 truncation=True,
-                max_length=4096,
-                padding=True
+                max_length=4096
             )
 
             # Move to the device
             inputs = {k: v.to(self.generation_device) for k, v in inputs.items()}
 
-            # Generation with parameters
+            # Generation with the official Qwen3 parameters
             with torch.no_grad():
                 outputs = self.generation_model.generate(
                     input_ids=inputs["input_ids"],
                     attention_mask=inputs["attention_mask"],
-                    max_new_tokens=
-                    temperature=0.7,
+                    max_new_tokens=1024,  # Official recommendation
+                    temperature=0.7,  # Official recommendation
+                    top_p=0.8,  # Official recommendation
+                    top_k=20,  # Official recommendation
                     do_sample=True,
                     pad_token_id=self.generation_tokenizer.pad_token_id,
                     eos_token_id=self.generation_tokenizer.eos_token_id,
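One step this hunk stops short of showing is turning `outputs` back into text; in the usual pattern the prompt tokens are sliced off before decoding, roughly as below. Variable names follow the diff, and the decode call itself is an assumption about the surrounding code, not part of this commit.

# Keep only the newly generated tokens, then decode them to a string
prompt_length = inputs["input_ids"].shape[1]
new_token_ids = outputs[0][prompt_length:]
answer = self.generation_tokenizer.decode(new_token_ids, skip_special_tokens=True)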