Melta-tchat / app.py
Clem27AI's picture
Update app.py
71ddbc0 verified
# app.py
# ==============================================================================
# 1. IMPORTS
# ==============================================================================
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import gradio as gr
import os # Utile pour gérer les chemins si nécessaire.
# ==============================================================================
# 2. DÉFINITION DES CONSTANTES ET CHARGEMENT (Optimisé pour un Space)
# ==============================================================================
MODEL_NAME = "Clem27AI/Melta27"
# Détermination du DEVICE : Priorité au GPU si disponible (pour un Space), sinon CPU.
# Note : Pour les petits modèles sur Hugging Face Spaces, le CPU est souvent la seule option si pas de GPU spécifié.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Chargement du modèle {MODEL_NAME} sur {DEVICE}...")
# Chargement du tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# Assurez-vous d'avoir un token de padding/fin de séquence
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
# Chargement du modèle. On utilise device_map="auto" pour laisser Transformers
# gérer l'emplacement optimal et le déchargement/chargement si nécessaire.
try:
model = AutoModelForCausalLM.from_pretrained(
MODEL_NAME,
torch_dtype=torch.float32, # float32 est souvent suffisant pour le CPU
device_map="auto" # Laisse Hugging Face gérer le device pour l'hébergement
)
# Si device_map n'a pas tout géré (ex: model.to(DEVICE) pour un CPU explicite)
# Ligne commentée car device_map="auto" est la meilleure pratique.
# model.to(DEVICE)
print("Modèle chargé avec succès.")
except Exception as e:
print(f"Erreur lors du chargement du modèle : {e}")
exit()
# ==============================================================================
# 3. FONCTION D'INFÉRENCE POUR GRADIO CHAT
# ==============================================================================
def format_prompt(history, message):
"""
Formate la conversation complète pour le modèle SLM dans le format :
### Instruction:
[HISTORIQUE DE LA CONVERSATION]
[NOUVELLE QUESTION]
### Response:
"""
# 1. Construire l'historique complet pour le contexte
full_history = ""
# history est une liste de paires [utilisateur, bot]
for user_msg, bot_msg in history:
# On assume un format simple de Questions/Réponses dans l'historique
# Note: L'ajout de "\n" avant le bot_msg est important si le modèle
# a été entraîné avec ce type de saut de ligne.
full_history += f"### Instruction: {user_msg}\n\n### Response: {bot_msg}\n"
# 2. Ajouter la nouvelle question de l'utilisateur
full_prompt = (
f"{full_history}" # L'historique des tours précédents
f"### Instruction: {message}\n\n" # La nouvelle question
f"### Response:" # Le modèle doit continuer à partir d'ici
)
return full_prompt.strip()
def generate_response(message, history):
"""
Fonction principale appelée par l'interface Gradio Chat.
"""
# 1. Formatage du prompt complet avec l'historique
prompt = format_prompt(history, message)
# 2. Tokenization et placement sur le device
# On ajoute la ligne 'prompt' pour que le modèle commence à générer après.
inputs = tokenizer(prompt, return_tensors="pt").to(model.device) # Utiliser model.device pour la cohérence
# 3. Génération
with torch.no_grad():
output_tokens = model.generate(
**inputs,
max_new_tokens=100, # Augmenté la longueur pour des réponses plus complètes
do_sample=True,
temperature=0.7, # Augmenté légèrement pour plus de créativité
top_k=50,
eos_token_id=tokenizer.eos_token_id,
pad_token_id=tokenizer.pad_token_id,
use_cache=True,
)
# 4. Décodage et nettoyage
# Le tenseur 'output_tokens' contient [prompt_tokens | generated_tokens]
# On décode tout le résultat
full_generated_text = tokenizer.decode(output_tokens[0], skip_special_tokens=True)
# Pour s'assurer de n'avoir que la réponse du bot, on retire le prompt initial
# qui a servi d'entrée à la fonction 'generate'.
# On doit s'assurer que l'objet 'prompt' n'a pas de caractères spéciaux
# ou de multiples espaces qui empêcheraient la correspondance exacte.
# Mieux vaut chercher l'indice de la dernière occurrence du marqueur '### Response:'
assistant_prefix = "### Response:"
# On cherche le point de départ de la réponse générée (le prompt se termine par assistant_prefix)
try:
# On ajoute la longueur de l'assistant_prefix pour commencer JUSTE après
start_index = full_generated_text.rindex(assistant_prefix) + len(assistant_prefix)
clean_response = full_generated_text[start_index:].strip()
# Le modèle peut parfois continuer et générer le prochain marqueur de prompt.
# On arrête la réponse avant cela.
if '### Instruction:' in clean_response:
clean_response = clean_response.split('### Instruction:')[0].strip()
except ValueError:
# Fallback si le format n'est pas trouvé (ce qui est rare)
print("Avertissement: Format '### Response:' non trouvé dans la sortie.")
clean_response = full_generated_text.replace(prompt, "").strip()
return clean_response
# ==============================================================================
# 4. CRÉATION DE L'INTERFACE GRADIO
# ==============================================================================
# Le composant `gr.ChatInterface` est le plus simple et le plus adapté.
chat_interface = gr.ChatInterface(
fn=generate_response,
chatbot=gr.Chatbot(height=500), # Taille du champ de chat
textbox=gr.Textbox(placeholder="Posez votre question à Melta...", container=False, scale=7),
title=f"Chat avec {MODEL_NAME} (SLM)",
description="Interface de discussion pour le SLM Melta. L'inférence est effectuée sur CPU ou GPU si disponible.",
# On ne met PLUS 'theme' ici !
)
# Lancement de l'application
if __name__ == "__main__":
# --- LA CORRECTION EST ICI ---
# L'argument 'theme' doit être passé à .launch() pour personnaliser le style
# L'argument 'server_name' est souvent utile pour les déploiements (ex: '0.0.0.0' pour HF Spaces)
chat_interface.launch(
theme="soft", # <--- CORRIGÉ
server_name="0.0.0.0",
share=False
)