teszenofficial commited on
Commit
df4b494
·
verified ·
1 Parent(s): 5b237c4

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +154 -0
app.py ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import pickle
4
+ import torch
5
+ import gradio as gr
6
+ from huggingface_hub import snapshot_download
7
+
8
+ # ======================
9
+ # CONFIGURACIÓN REPO HF
10
+ # ======================
11
+ REPO_ID = "teszenofficial/MTP7"
12
+ MODEL_FILE = "mtp_mini.pkl" # Asegúrate de que se llame así en tu repo
13
+ TOKENIZER_FILE = "mtp_tokenizer.model" # Asegúrate de que se llame así en tu repo
14
+ LOCAL_DIR = "mtptz_repo" # Nombre de la carpeta local donde se descarga
15
+
16
+ # ======================
17
+ # DESCARGA Y CARGA DEL MODELO
18
+ # ======================
19
+
20
+ def load_resources():
21
+ print(f"📦 Descargando modelo desde {REPO_ID}...")
22
+
23
+ # 1. Descargar el repositorio a una carpeta local
24
+ repo_path = snapshot_download(
25
+ repo_id=REPO_ID,
26
+ local_dir=LOCAL_DIR
27
+ )
28
+ print(f"✅ Modelo descargado en: {repo_path}")
29
+
30
+ # 2. Añadir la ruta al sys.path para poder importar model.py y tokenizer.py desde el repo
31
+ sys.path.insert(0, repo_path)
32
+
33
+ try:
34
+ # Intentamos importar las clases desde los archivos descargados en el repo
35
+ from model import MTPMiniModel
36
+ from tokenizer import MTPTokenizer
37
+ except ImportError as e:
38
+ print(f"❌ ERROR: No se pudieron importar 'model' o 'tokenizer'.")
39
+ print(f" Asegúrate de que subiste 'model.py' y 'tokenizer.py' al repo '{REPO_ID}'.")
40
+ raise e
41
+
42
+ # 3. Definir rutas completas
43
+ model_path = os.path.join(repo_path, MODEL_FILE)
44
+ tokenizer_path = os.path.join(repo_path, TOKENIZER_FILE)
45
+
46
+ # Verificar si existen
47
+ if not os.path.exists(model_path):
48
+ raise FileNotFoundError(f"No se encontró {MODEL_FILE} en el repo.")
49
+ if not os.path.exists(tokenizer_path):
50
+ raise FileNotFoundError(f"No se encontró {TOKENIZER_FILE} en el repo.")
51
+
52
+ # 4. Cargar Tokenizer
53
+ tokenizer = MTPTokenizer(tokenizer_path)
54
+ print(f"✅ Tokenizer cargado. Vocab size: {tokenizer.vocab_size()}")
55
+
56
+ # 5. Cargar Modelo
57
+ print(f"🧠 Cargando tensores...")
58
+ with open(model_path, 'rb') as f:
59
+ model_data = pickle.load(f)
60
+
61
+ config = model_data['config']
62
+ state_dict = model_data['model_state_dict']
63
+ vocab_size = model_data['vocab_size']
64
+
65
+ # Reconstruir el Modelo
66
+ use_swiglu = config['model'].get('use_swiglu', False)
67
+ model = MTPMiniModel(
68
+ vocab_size=vocab_size,
69
+ d_model=config['model']['d_model'],
70
+ n_layers=config['model']['n_layers'],
71
+ n_heads=config['model']['n_heads'],
72
+ d_ff=config['model']['d_ff'],
73
+ max_seq_len=config['model']['max_seq_len'],
74
+ dropout=0.0,
75
+ use_swiglu=use_swiglu
76
+ )
77
+
78
+ model.load_state_dict(state_dict)
79
+ model.eval()
80
+
81
+ DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
82
+ model.to(DEVICE)
83
+ print(f"✅ Modelo cargado en {DEVICE}")
84
+
85
+ return model, tokenizer, DEVICE
86
+
87
+ # Cargar al inicio
88
+ model, tokenizer, DEVICE = load_resources()
89
+
90
+ # ======================
91
+ # FUNCIÓN DE GENERACIÓN
92
+ # ======================
93
+ def generate_response(message, history, temperature, max_tokens, top_p):
94
+ # Construir el prompt
95
+ # Formato: ### Instrucción:\n{input}\n\n### Respuesta:\n
96
+ prompt = f"### Instrucción:\n{message}\n\n### Respuesta:\n"
97
+
98
+ # Tokenizar
99
+ tokens = [tokenizer.bos_id()] + tokenizer.encode(prompt)
100
+ input_ids = torch.tensor([tokens], device=DEVICE)
101
+
102
+ # Generar usando el método del modelo
103
+ with torch.no_grad():
104
+ output_ids = model.generate(
105
+ input_ids,
106
+ max_new_tokens=int(max_tokens),
107
+ temperature=float(temperature),
108
+ top_k=40,
109
+ top_p=float(top_p),
110
+ repetition_penalty=1.15,
111
+ min_length=10,
112
+ eos_token_id=tokenizer.eos_id()
113
+ )
114
+
115
+ # Decodificar
116
+ gen_tokens = output_ids[0, len(tokens):].tolist()
117
+ safe_tokens = []
118
+ for t in gen_tokens:
119
+ if 0 <= t < tokenizer.vocab_size() and t != tokenizer.eos_id():
120
+ safe_tokens.append(t)
121
+ elif t == tokenizer.eos_id():
122
+ break
123
+
124
+ response = tokenizer.decode(safe_tokens).strip()
125
+
126
+ # Limpieza básica
127
+ if "### Instrucción:" in response:
128
+ response = response.split("### Instrucción:")[0].strip()
129
+
130
+ return response
131
+
132
+ # ======================
133
+ # INTERFAZ GRADIO
134
+ # ======================
135
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
136
+ gr.Markdown("# 🤖 MTP-7 Chat (Demo)")
137
+ gr.Markdown(f"Modelo cargado desde `teszenofficial/MTP7` en **{DEVICE}**.")
138
+
139
+ chat_interface = gr.ChatInterface(
140
+ fn=generate_response,
141
+ additional_inputs=[
142
+ gr.Slider(0.1, 2.0, value=0.7, label="Temperatura (Creatividad)"),
143
+ gr.Slider(50, 300, value=150, label="Máximos Tokens"),
144
+ gr.Slider(0.1, 1.0, value=0.92, label="Top-p (Nucleus)"),
145
+ ],
146
+ examples=[
147
+ ["¿Cuál es la capital de Francia?", 0.7, 150, 0.92],
148
+ ["Explica qué es la relatividad.", 0.7, 150, 0.92]
149
+ ],
150
+ cache_examples=False
151
+ )
152
+
153
+ if __name__ == "__main__":
154
+ demo.launch()