Spaces:
Sleeping
Sleeping
| # app.py | |
| # ============================================================================== | |
| # 1. IMPORTS | |
| # ============================================================================== | |
| import torch | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| import gradio as gr | |
| import os # Utile pour gérer les chemins si nécessaire. | |
| # ============================================================================== | |
| # 2. DÉFINITION DES CONSTANTES ET CHARGEMENT (Optimisé pour un Space) | |
| # ============================================================================== | |
| MODEL_NAME = "Clem27AI/Melta27" | |
| # Détermination du DEVICE : Priorité au GPU si disponible (pour un Space), sinon CPU. | |
| # Note : Pour les petits modèles sur Hugging Face Spaces, le CPU est souvent la seule option si pas de GPU spécifié. | |
| DEVICE = "cuda" if torch.cuda.is_available() else "cpu" | |
| print(f"Chargement du modèle {MODEL_NAME} sur {DEVICE}...") | |
| # Chargement du tokenizer | |
| tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) | |
| # Assurez-vous d'avoir un token de padding/fin de séquence | |
| if tokenizer.pad_token is None: | |
| tokenizer.pad_token = tokenizer.eos_token | |
| # Chargement du modèle. On utilise device_map="auto" pour laisser Transformers | |
| # gérer l'emplacement optimal et le déchargement/chargement si nécessaire. | |
| try: | |
| model = AutoModelForCausalLM.from_pretrained( | |
| MODEL_NAME, | |
| torch_dtype=torch.float32, # float32 est souvent suffisant pour le CPU | |
| device_map="auto" # Laisse Hugging Face gérer le device pour l'hébergement | |
| ) | |
| # Si device_map n'a pas tout géré (ex: model.to(DEVICE) pour un CPU explicite) | |
| # Ligne commentée car device_map="auto" est la meilleure pratique. | |
| # model.to(DEVICE) | |
| print("Modèle chargé avec succès.") | |
| except Exception as e: | |
| print(f"Erreur lors du chargement du modèle : {e}") | |
| exit() | |
| # ============================================================================== | |
| # 3. FONCTION D'INFÉRENCE POUR GRADIO CHAT | |
| # ============================================================================== | |
| def format_prompt(history, message): | |
| """ | |
| Formate la conversation complète pour le modèle SLM dans le format : | |
| ### Instruction: | |
| [HISTORIQUE DE LA CONVERSATION] | |
| [NOUVELLE QUESTION] | |
| ### Response: | |
| """ | |
| # 1. Construire l'historique complet pour le contexte | |
| full_history = "" | |
| # history est une liste de paires [utilisateur, bot] | |
| for user_msg, bot_msg in history: | |
| # On assume un format simple de Questions/Réponses dans l'historique | |
| # Note: L'ajout de "\n" avant le bot_msg est important si le modèle | |
| # a été entraîné avec ce type de saut de ligne. | |
| full_history += f"### Instruction: {user_msg}\n\n### Response: {bot_msg}\n" | |
| # 2. Ajouter la nouvelle question de l'utilisateur | |
| full_prompt = ( | |
| f"{full_history}" # L'historique des tours précédents | |
| f"### Instruction: {message}\n\n" # La nouvelle question | |
| f"### Response:" # Le modèle doit continuer à partir d'ici | |
| ) | |
| return full_prompt.strip() | |
| def generate_response(message, history): | |
| """ | |
| Fonction principale appelée par l'interface Gradio Chat. | |
| """ | |
| # 1. Formatage du prompt complet avec l'historique | |
| prompt = format_prompt(history, message) | |
| # 2. Tokenization et placement sur le device | |
| # On ajoute la ligne 'prompt' pour que le modèle commence à générer après. | |
| inputs = tokenizer(prompt, return_tensors="pt").to(model.device) # Utiliser model.device pour la cohérence | |
| # 3. Génération | |
| with torch.no_grad(): | |
| output_tokens = model.generate( | |
| **inputs, | |
| max_new_tokens=100, # Augmenté la longueur pour des réponses plus complètes | |
| do_sample=True, | |
| temperature=0.7, # Augmenté légèrement pour plus de créativité | |
| top_k=50, | |
| eos_token_id=tokenizer.eos_token_id, | |
| pad_token_id=tokenizer.pad_token_id, | |
| use_cache=True, | |
| ) | |
| # 4. Décodage et nettoyage | |
| # Le tenseur 'output_tokens' contient [prompt_tokens | generated_tokens] | |
| # On décode tout le résultat | |
| full_generated_text = tokenizer.decode(output_tokens[0], skip_special_tokens=True) | |
| # Pour s'assurer de n'avoir que la réponse du bot, on retire le prompt initial | |
| # qui a servi d'entrée à la fonction 'generate'. | |
| # On doit s'assurer que l'objet 'prompt' n'a pas de caractères spéciaux | |
| # ou de multiples espaces qui empêcheraient la correspondance exacte. | |
| # Mieux vaut chercher l'indice de la dernière occurrence du marqueur '### Response:' | |
| assistant_prefix = "### Response:" | |
| # On cherche le point de départ de la réponse générée (le prompt se termine par assistant_prefix) | |
| try: | |
| # On ajoute la longueur de l'assistant_prefix pour commencer JUSTE après | |
| start_index = full_generated_text.rindex(assistant_prefix) + len(assistant_prefix) | |
| clean_response = full_generated_text[start_index:].strip() | |
| # Le modèle peut parfois continuer et générer le prochain marqueur de prompt. | |
| # On arrête la réponse avant cela. | |
| if '### Instruction:' in clean_response: | |
| clean_response = clean_response.split('### Instruction:')[0].strip() | |
| except ValueError: | |
| # Fallback si le format n'est pas trouvé (ce qui est rare) | |
| print("Avertissement: Format '### Response:' non trouvé dans la sortie.") | |
| clean_response = full_generated_text.replace(prompt, "").strip() | |
| return clean_response | |
| # ============================================================================== | |
| # 4. CRÉATION DE L'INTERFACE GRADIO | |
| # ============================================================================== | |
| # Le composant `gr.ChatInterface` est le plus simple et le plus adapté. | |
| chat_interface = gr.ChatInterface( | |
| fn=generate_response, | |
| chatbot=gr.Chatbot(height=500), # Taille du champ de chat | |
| textbox=gr.Textbox(placeholder="Posez votre question à Melta...", container=False, scale=7), | |
| title=f"Chat avec {MODEL_NAME} (SLM)", | |
| description="Interface de discussion pour le SLM Melta. L'inférence est effectuée sur CPU ou GPU si disponible.", | |
| # On ne met PLUS 'theme' ici ! | |
| ) | |
| # Lancement de l'application | |
| if __name__ == "__main__": | |
| # --- LA CORRECTION EST ICI --- | |
| # L'argument 'theme' doit être passé à .launch() pour personnaliser le style | |
| # L'argument 'server_name' est souvent utile pour les déploiements (ex: '0.0.0.0' pour HF Spaces) | |
| chat_interface.launch( | |
| theme="soft", # <--- CORRIGÉ | |
| server_name="0.0.0.0", | |
| share=False | |
| ) |