# app.py # ============================================================================== # 1. IMPORTS # ============================================================================== import torch from transformers import AutoModelForCausalLM, AutoTokenizer import gradio as gr import os # Utile pour gérer les chemins si nécessaire. # ============================================================================== # 2. DÉFINITION DES CONSTANTES ET CHARGEMENT (Optimisé pour un Space) # ============================================================================== MODEL_NAME = "Clem27AI/Melta27" # Détermination du DEVICE : Priorité au GPU si disponible (pour un Space), sinon CPU. # Note : Pour les petits modèles sur Hugging Face Spaces, le CPU est souvent la seule option si pas de GPU spécifié. DEVICE = "cuda" if torch.cuda.is_available() else "cpu" print(f"Chargement du modèle {MODEL_NAME} sur {DEVICE}...") # Chargement du tokenizer tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) # Assurez-vous d'avoir un token de padding/fin de séquence if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token # Chargement du modèle. On utilise device_map="auto" pour laisser Transformers # gérer l'emplacement optimal et le déchargement/chargement si nécessaire. try: model = AutoModelForCausalLM.from_pretrained( MODEL_NAME, torch_dtype=torch.float32, # float32 est souvent suffisant pour le CPU device_map="auto" # Laisse Hugging Face gérer le device pour l'hébergement ) # Si device_map n'a pas tout géré (ex: model.to(DEVICE) pour un CPU explicite) # Ligne commentée car device_map="auto" est la meilleure pratique. # model.to(DEVICE) print("Modèle chargé avec succès.") except Exception as e: print(f"Erreur lors du chargement du modèle : {e}") exit() # ============================================================================== # 3. FONCTION D'INFÉRENCE POUR GRADIO CHAT # ============================================================================== def format_prompt(history, message): """ Formate la conversation complète pour le modèle SLM dans le format : ### Instruction: [HISTORIQUE DE LA CONVERSATION] [NOUVELLE QUESTION] ### Response: """ # 1. Construire l'historique complet pour le contexte full_history = "" # history est une liste de paires [utilisateur, bot] for user_msg, bot_msg in history: # On assume un format simple de Questions/Réponses dans l'historique # Note: L'ajout de "\n" avant le bot_msg est important si le modèle # a été entraîné avec ce type de saut de ligne. full_history += f"### Instruction: {user_msg}\n\n### Response: {bot_msg}\n" # 2. Ajouter la nouvelle question de l'utilisateur full_prompt = ( f"{full_history}" # L'historique des tours précédents f"### Instruction: {message}\n\n" # La nouvelle question f"### Response:" # Le modèle doit continuer à partir d'ici ) return full_prompt.strip() def generate_response(message, history): """ Fonction principale appelée par l'interface Gradio Chat. """ # 1. Formatage du prompt complet avec l'historique prompt = format_prompt(history, message) # 2. Tokenization et placement sur le device # On ajoute la ligne 'prompt' pour que le modèle commence à générer après. inputs = tokenizer(prompt, return_tensors="pt").to(model.device) # Utiliser model.device pour la cohérence # 3. Génération with torch.no_grad(): output_tokens = model.generate( **inputs, max_new_tokens=100, # Augmenté la longueur pour des réponses plus complètes do_sample=True, temperature=0.7, # Augmenté légèrement pour plus de créativité top_k=50, eos_token_id=tokenizer.eos_token_id, pad_token_id=tokenizer.pad_token_id, use_cache=True, ) # 4. Décodage et nettoyage # Le tenseur 'output_tokens' contient [prompt_tokens | generated_tokens] # On décode tout le résultat full_generated_text = tokenizer.decode(output_tokens[0], skip_special_tokens=True) # Pour s'assurer de n'avoir que la réponse du bot, on retire le prompt initial # qui a servi d'entrée à la fonction 'generate'. # On doit s'assurer que l'objet 'prompt' n'a pas de caractères spéciaux # ou de multiples espaces qui empêcheraient la correspondance exacte. # Mieux vaut chercher l'indice de la dernière occurrence du marqueur '### Response:' assistant_prefix = "### Response:" # On cherche le point de départ de la réponse générée (le prompt se termine par assistant_prefix) try: # On ajoute la longueur de l'assistant_prefix pour commencer JUSTE après start_index = full_generated_text.rindex(assistant_prefix) + len(assistant_prefix) clean_response = full_generated_text[start_index:].strip() # Le modèle peut parfois continuer et générer le prochain marqueur de prompt. # On arrête la réponse avant cela. if '### Instruction:' in clean_response: clean_response = clean_response.split('### Instruction:')[0].strip() except ValueError: # Fallback si le format n'est pas trouvé (ce qui est rare) print("Avertissement: Format '### Response:' non trouvé dans la sortie.") clean_response = full_generated_text.replace(prompt, "").strip() return clean_response # ============================================================================== # 4. CRÉATION DE L'INTERFACE GRADIO # ============================================================================== # Le composant `gr.ChatInterface` est le plus simple et le plus adapté. chat_interface = gr.ChatInterface( fn=generate_response, chatbot=gr.Chatbot(height=500), # Taille du champ de chat textbox=gr.Textbox(placeholder="Posez votre question à Melta...", container=False, scale=7), title=f"Chat avec {MODEL_NAME} (SLM)", description="Interface de discussion pour le SLM Melta. L'inférence est effectuée sur CPU ou GPU si disponible.", # On ne met PLUS 'theme' ici ! ) # Lancement de l'application if __name__ == "__main__": # --- LA CORRECTION EST ICI --- # L'argument 'theme' doit être passé à .launch() pour personnaliser le style # L'argument 'server_name' est souvent utile pour les déploiements (ex: '0.0.0.0' pour HF Spaces) chat_interface.launch( theme="soft", # <--- CORRIGÉ server_name="0.0.0.0", share=False )