Spaces:
Sleeping
Sleeping
File size: 6,751 Bytes
3dc8757 71ddbc0 3dc8757 71ddbc0 3dc8757 71ddbc0 3dc8757 71ddbc0 3dc8757 71ddbc0 3dc8757 71ddbc0 3dc8757 71ddbc0 3dc8757 71ddbc0 3dc8757 71ddbc0 3dc8757 71ddbc0 3dc8757 71ddbc0 3dc8757 71ddbc0 3dc8757 71ddbc0 3dc8757 71ddbc0 3dc8757 71ddbc0 3dc8757 71ddbc0 3dc8757 71ddbc0 3dc8757 71ddbc0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 |
# app.py
# ==============================================================================
# 1. IMPORTS
# ==============================================================================
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import gradio as gr
import os # Utile pour gérer les chemins si nécessaire.
# ==============================================================================
# 2. DÉFINITION DES CONSTANTES ET CHARGEMENT (Optimisé pour un Space)
# ==============================================================================
MODEL_NAME = "Clem27AI/Melta27"
# Détermination du DEVICE : Priorité au GPU si disponible (pour un Space), sinon CPU.
# Note : Pour les petits modèles sur Hugging Face Spaces, le CPU est souvent la seule option si pas de GPU spécifié.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Chargement du modèle {MODEL_NAME} sur {DEVICE}...")
# Chargement du tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# Assurez-vous d'avoir un token de padding/fin de séquence
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
# Chargement du modèle. On utilise device_map="auto" pour laisser Transformers
# gérer l'emplacement optimal et le déchargement/chargement si nécessaire.
try:
model = AutoModelForCausalLM.from_pretrained(
MODEL_NAME,
torch_dtype=torch.float32, # float32 est souvent suffisant pour le CPU
device_map="auto" # Laisse Hugging Face gérer le device pour l'hébergement
)
# Si device_map n'a pas tout géré (ex: model.to(DEVICE) pour un CPU explicite)
# Ligne commentée car device_map="auto" est la meilleure pratique.
# model.to(DEVICE)
print("Modèle chargé avec succès.")
except Exception as e:
print(f"Erreur lors du chargement du modèle : {e}")
exit()
# ==============================================================================
# 3. FONCTION D'INFÉRENCE POUR GRADIO CHAT
# ==============================================================================
def format_prompt(history, message):
"""
Formate la conversation complète pour le modèle SLM dans le format :
### Instruction:
[HISTORIQUE DE LA CONVERSATION]
[NOUVELLE QUESTION]
### Response:
"""
# 1. Construire l'historique complet pour le contexte
full_history = ""
# history est une liste de paires [utilisateur, bot]
for user_msg, bot_msg in history:
# On assume un format simple de Questions/Réponses dans l'historique
# Note: L'ajout de "\n" avant le bot_msg est important si le modèle
# a été entraîné avec ce type de saut de ligne.
full_history += f"### Instruction: {user_msg}\n\n### Response: {bot_msg}\n"
# 2. Ajouter la nouvelle question de l'utilisateur
full_prompt = (
f"{full_history}" # L'historique des tours précédents
f"### Instruction: {message}\n\n" # La nouvelle question
f"### Response:" # Le modèle doit continuer à partir d'ici
)
return full_prompt.strip()
def generate_response(message, history):
"""
Fonction principale appelée par l'interface Gradio Chat.
"""
# 1. Formatage du prompt complet avec l'historique
prompt = format_prompt(history, message)
# 2. Tokenization et placement sur le device
# On ajoute la ligne 'prompt' pour que le modèle commence à générer après.
inputs = tokenizer(prompt, return_tensors="pt").to(model.device) # Utiliser model.device pour la cohérence
# 3. Génération
with torch.no_grad():
output_tokens = model.generate(
**inputs,
max_new_tokens=100, # Augmenté la longueur pour des réponses plus complètes
do_sample=True,
temperature=0.7, # Augmenté légèrement pour plus de créativité
top_k=50,
eos_token_id=tokenizer.eos_token_id,
pad_token_id=tokenizer.pad_token_id,
use_cache=True,
)
# 4. Décodage et nettoyage
# Le tenseur 'output_tokens' contient [prompt_tokens | generated_tokens]
# On décode tout le résultat
full_generated_text = tokenizer.decode(output_tokens[0], skip_special_tokens=True)
# Pour s'assurer de n'avoir que la réponse du bot, on retire le prompt initial
# qui a servi d'entrée à la fonction 'generate'.
# On doit s'assurer que l'objet 'prompt' n'a pas de caractères spéciaux
# ou de multiples espaces qui empêcheraient la correspondance exacte.
# Mieux vaut chercher l'indice de la dernière occurrence du marqueur '### Response:'
assistant_prefix = "### Response:"
# On cherche le point de départ de la réponse générée (le prompt se termine par assistant_prefix)
try:
# On ajoute la longueur de l'assistant_prefix pour commencer JUSTE après
start_index = full_generated_text.rindex(assistant_prefix) + len(assistant_prefix)
clean_response = full_generated_text[start_index:].strip()
# Le modèle peut parfois continuer et générer le prochain marqueur de prompt.
# On arrête la réponse avant cela.
if '### Instruction:' in clean_response:
clean_response = clean_response.split('### Instruction:')[0].strip()
except ValueError:
# Fallback si le format n'est pas trouvé (ce qui est rare)
print("Avertissement: Format '### Response:' non trouvé dans la sortie.")
clean_response = full_generated_text.replace(prompt, "").strip()
return clean_response
# ==============================================================================
# 4. CRÉATION DE L'INTERFACE GRADIO
# ==============================================================================
# Le composant `gr.ChatInterface` est le plus simple et le plus adapté.
chat_interface = gr.ChatInterface(
fn=generate_response,
chatbot=gr.Chatbot(height=500), # Taille du champ de chat
textbox=gr.Textbox(placeholder="Posez votre question à Melta...", container=False, scale=7),
title=f"Chat avec {MODEL_NAME} (SLM)",
description="Interface de discussion pour le SLM Melta. L'inférence est effectuée sur CPU ou GPU si disponible.",
# On ne met PLUS 'theme' ici !
)
# Lancement de l'application
if __name__ == "__main__":
# --- LA CORRECTION EST ICI ---
# L'argument 'theme' doit être passé à .launch() pour personnaliser le style
# L'argument 'server_name' est souvent utile pour les déploiements (ex: '0.0.0.0' pour HF Spaces)
chat_interface.launch(
theme="soft", # <--- CORRIGÉ
server_name="0.0.0.0",
share=False
) |