File size: 6,751 Bytes
3dc8757
 
 
 
 
 
 
71ddbc0
3dc8757
 
 
 
 
71ddbc0
 
 
3dc8757
 
 
 
 
 
 
 
 
 
71ddbc0
3dc8757
 
 
71ddbc0
3dc8757
 
71ddbc0
 
 
3dc8757
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71ddbc0
 
3dc8757
 
 
 
 
 
71ddbc0
3dc8757
 
 
 
 
 
 
 
 
 
 
 
 
71ddbc0
 
 
3dc8757
 
71ddbc0
3dc8757
 
71ddbc0
3dc8757
71ddbc0
 
3dc8757
71ddbc0
 
3dc8757
 
 
71ddbc0
3dc8757
71ddbc0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3dc8757
71ddbc0
 
3dc8757
 
71ddbc0
 
 
 
 
3dc8757
 
 
 
 
 
 
 
 
 
 
 
 
71ddbc0
 
3dc8757
 
 
 
71ddbc0
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
# app.py
# ==============================================================================
# 1. IMPORTS
# ==============================================================================
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import gradio as gr
import os # Utile pour gérer les chemins si nécessaire.

# ==============================================================================
# 2. DÉFINITION DES CONSTANTES ET CHARGEMENT (Optimisé pour un Space)
# ==============================================================================
MODEL_NAME = "Clem27AI/Melta27"
# Détermination du DEVICE : Priorité au GPU si disponible (pour un Space), sinon CPU.
# Note : Pour les petits modèles sur Hugging Face Spaces, le CPU est souvent la seule option si pas de GPU spécifié.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

print(f"Chargement du modèle {MODEL_NAME} sur {DEVICE}...")

# Chargement du tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# Assurez-vous d'avoir un token de padding/fin de séquence
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Chargement du modèle. On utilise device_map="auto" pour laisser Transformers
# gérer l'emplacement optimal et le déchargement/chargement si nécessaire.
try:
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch.float32, # float32 est souvent suffisant pour le CPU
        device_map="auto" # Laisse Hugging Face gérer le device pour l'hébergement
    )
    # Si device_map n'a pas tout géré (ex: model.to(DEVICE) pour un CPU explicite)
    # Ligne commentée car device_map="auto" est la meilleure pratique.
    # model.to(DEVICE)
    print("Modèle chargé avec succès.")

except Exception as e:
    print(f"Erreur lors du chargement du modèle : {e}")
    exit()

# ==============================================================================
# 3. FONCTION D'INFÉRENCE POUR GRADIO CHAT
# ==============================================================================

def format_prompt(history, message):
    """
    Formate la conversation complète pour le modèle SLM dans le format :
    ### Instruction:
    [HISTORIQUE DE LA CONVERSATION]
    [NOUVELLE QUESTION]
    
    ### Response:
    """
    
    # 1. Construire l'historique complet pour le contexte
    full_history = ""
    # history est une liste de paires [utilisateur, bot]
    for user_msg, bot_msg in history:
        # On assume un format simple de Questions/Réponses dans l'historique
        # Note: L'ajout de "\n" avant le bot_msg est important si le modèle
        # a été entraîné avec ce type de saut de ligne.
        full_history += f"### Instruction: {user_msg}\n\n### Response: {bot_msg}\n"
    
    # 2. Ajouter la nouvelle question de l'utilisateur
    full_prompt = (
        f"{full_history}" # L'historique des tours précédents
        f"### Instruction: {message}\n\n" # La nouvelle question
        f"### Response:" # Le modèle doit continuer à partir d'ici
    )
    
    return full_prompt.strip()


def generate_response(message, history):
    """
    Fonction principale appelée par l'interface Gradio Chat.
    """
    
    # 1. Formatage du prompt complet avec l'historique
    prompt = format_prompt(history, message)
    
    # 2. Tokenization et placement sur le device
    # On ajoute la ligne 'prompt' pour que le modèle commence à générer après.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device) # Utiliser model.device pour la cohérence

    # 3. Génération
    with torch.no_grad():
        output_tokens = model.generate(
            **inputs,
            max_new_tokens=100, # Augmenté la longueur pour des réponses plus complètes
            do_sample=True,
            temperature=0.7, # Augmenté légèrement pour plus de créativité
            top_k=50,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
            use_cache=True,
        )

    # 4. Décodage et nettoyage
    # Le tenseur 'output_tokens' contient [prompt_tokens | generated_tokens]
    
    # On décode tout le résultat
    full_generated_text = tokenizer.decode(output_tokens[0], skip_special_tokens=True)
    
    # Pour s'assurer de n'avoir que la réponse du bot, on retire le prompt initial
    # qui a servi d'entrée à la fonction 'generate'.
    
    # On doit s'assurer que l'objet 'prompt' n'a pas de caractères spéciaux 
    # ou de multiples espaces qui empêcheraient la correspondance exacte.
    
    # Mieux vaut chercher l'indice de la dernière occurrence du marqueur '### Response:'
    
    assistant_prefix = "### Response:"
    
    # On cherche le point de départ de la réponse générée (le prompt se termine par assistant_prefix)
    try:
        # On ajoute la longueur de l'assistant_prefix pour commencer JUSTE après
        start_index = full_generated_text.rindex(assistant_prefix) + len(assistant_prefix)
        clean_response = full_generated_text[start_index:].strip()
        
        # Le modèle peut parfois continuer et générer le prochain marqueur de prompt.
        # On arrête la réponse avant cela.
        if '### Instruction:' in clean_response:
             clean_response = clean_response.split('### Instruction:')[0].strip()
        
    except ValueError:
        # Fallback si le format n'est pas trouvé (ce qui est rare)
        print("Avertissement: Format '### Response:' non trouvé dans la sortie.")
        clean_response = full_generated_text.replace(prompt, "").strip()
        
    return clean_response

# ==============================================================================
# 4. CRÉATION DE L'INTERFACE GRADIO
# ==============================================================================

# Le composant `gr.ChatInterface` est le plus simple et le plus adapté.
chat_interface = gr.ChatInterface(
    fn=generate_response,
    chatbot=gr.Chatbot(height=500), # Taille du champ de chat
    textbox=gr.Textbox(placeholder="Posez votre question à Melta...", container=False, scale=7),
    title=f"Chat avec {MODEL_NAME} (SLM)",
    description="Interface de discussion pour le SLM Melta. L'inférence est effectuée sur CPU ou GPU si disponible.",
    # On ne met PLUS 'theme' ici !
)

# Lancement de l'application
if __name__ == "__main__":
    # --- LA CORRECTION EST ICI ---
    # L'argument 'theme' doit être passé à .launch() pour personnaliser le style
    # L'argument 'server_name' est souvent utile pour les déploiements (ex: '0.0.0.0' pour HF Spaces)
    chat_interface.launch(
        theme="soft", # <--- CORRIGÉ
        server_name="0.0.0.0",
        share=False
    )