File size: 5,069 Bytes
df4b494 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 | import os
import sys
import pickle
import torch
import gradio as gr
from huggingface_hub import snapshot_download
# ======================
# CONFIGURACIÓN REPO HF
# ======================
REPO_ID = "teszenofficial/MTP7"
MODEL_FILE = "mtp_mini.pkl" # Asegúrate de que se llame así en tu repo
TOKENIZER_FILE = "mtp_tokenizer.model" # Asegúrate de que se llame así en tu repo
LOCAL_DIR = "mtptz_repo" # Nombre de la carpeta local donde se descarga
# ======================
# DESCARGA Y CARGA DEL MODELO
# ======================
def load_resources():
print(f"📦 Descargando modelo desde {REPO_ID}...")
# 1. Descargar el repositorio a una carpeta local
repo_path = snapshot_download(
repo_id=REPO_ID,
local_dir=LOCAL_DIR
)
print(f"✅ Modelo descargado en: {repo_path}")
# 2. Añadir la ruta al sys.path para poder importar model.py y tokenizer.py desde el repo
sys.path.insert(0, repo_path)
try:
# Intentamos importar las clases desde los archivos descargados en el repo
from model import MTPMiniModel
from tokenizer import MTPTokenizer
except ImportError as e:
print(f"❌ ERROR: No se pudieron importar 'model' o 'tokenizer'.")
print(f" Asegúrate de que subiste 'model.py' y 'tokenizer.py' al repo '{REPO_ID}'.")
raise e
# 3. Definir rutas completas
model_path = os.path.join(repo_path, MODEL_FILE)
tokenizer_path = os.path.join(repo_path, TOKENIZER_FILE)
# Verificar si existen
if not os.path.exists(model_path):
raise FileNotFoundError(f"No se encontró {MODEL_FILE} en el repo.")
if not os.path.exists(tokenizer_path):
raise FileNotFoundError(f"No se encontró {TOKENIZER_FILE} en el repo.")
# 4. Cargar Tokenizer
tokenizer = MTPTokenizer(tokenizer_path)
print(f"✅ Tokenizer cargado. Vocab size: {tokenizer.vocab_size()}")
# 5. Cargar Modelo
print(f"🧠 Cargando tensores...")
with open(model_path, 'rb') as f:
model_data = pickle.load(f)
config = model_data['config']
state_dict = model_data['model_state_dict']
vocab_size = model_data['vocab_size']
# Reconstruir el Modelo
use_swiglu = config['model'].get('use_swiglu', False)
model = MTPMiniModel(
vocab_size=vocab_size,
d_model=config['model']['d_model'],
n_layers=config['model']['n_layers'],
n_heads=config['model']['n_heads'],
d_ff=config['model']['d_ff'],
max_seq_len=config['model']['max_seq_len'],
dropout=0.0,
use_swiglu=use_swiglu
)
model.load_state_dict(state_dict)
model.eval()
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(DEVICE)
print(f"✅ Modelo cargado en {DEVICE}")
return model, tokenizer, DEVICE
# Cargar al inicio
model, tokenizer, DEVICE = load_resources()
# ======================
# FUNCIÓN DE GENERACIÓN
# ======================
def generate_response(message, history, temperature, max_tokens, top_p):
# Construir el prompt
# Formato: ### Instrucción:\n{input}\n\n### Respuesta:\n
prompt = f"### Instrucción:\n{message}\n\n### Respuesta:\n"
# Tokenizar
tokens = [tokenizer.bos_id()] + tokenizer.encode(prompt)
input_ids = torch.tensor([tokens], device=DEVICE)
# Generar usando el método del modelo
with torch.no_grad():
output_ids = model.generate(
input_ids,
max_new_tokens=int(max_tokens),
temperature=float(temperature),
top_k=40,
top_p=float(top_p),
repetition_penalty=1.15,
min_length=10,
eos_token_id=tokenizer.eos_id()
)
# Decodificar
gen_tokens = output_ids[0, len(tokens):].tolist()
safe_tokens = []
for t in gen_tokens:
if 0 <= t < tokenizer.vocab_size() and t != tokenizer.eos_id():
safe_tokens.append(t)
elif t == tokenizer.eos_id():
break
response = tokenizer.decode(safe_tokens).strip()
# Limpieza básica
if "### Instrucción:" in response:
response = response.split("### Instrucción:")[0].strip()
return response
# ======================
# INTERFAZ GRADIO
# ======================
with gr.Blocks(theme=gr.themes.Soft()) as demo:
gr.Markdown("# 🤖 MTP-7 Chat (Demo)")
gr.Markdown(f"Modelo cargado desde `teszenofficial/MTP7` en **{DEVICE}**.")
chat_interface = gr.ChatInterface(
fn=generate_response,
additional_inputs=[
gr.Slider(0.1, 2.0, value=0.7, label="Temperatura (Creatividad)"),
gr.Slider(50, 300, value=150, label="Máximos Tokens"),
gr.Slider(0.1, 1.0, value=0.92, label="Top-p (Nucleus)"),
],
examples=[
["¿Cuál es la capital de Francia?", 0.7, 150, 0.92],
["Explica qué es la relatividad.", 0.7, 150, 0.92]
],
cache_examples=False
)
if __name__ == "__main__":
demo.launch() |