"""MTP 1.1 inference server.

Downloads the model snapshot from the Hugging Face Hub, loads the pickled
checkpoint, and serves text generation over FastAPI: a one-shot endpoint
(/generate), a token-by-token streaming endpoint (/generate_stream), and a
minimal frontend at /.
"""

import os
import pickle
import sys
import time

import torch
import uvicorn
from fastapi import FastAPI
from fastapi.responses import HTMLResponse, StreamingResponse
from huggingface_hub import snapshot_download
from pydantic import BaseModel

# ======================
# CPU OPTIMIZATION
# ======================
# Use half the cores for torch so the web server keeps some headroom.
torch.set_num_threads(max(1, os.cpu_count() // 2))
# Inference only — never build autograd graphs anywhere in this process.
torch.set_grad_enabled(False)

# ======================
# DEVICE
# ======================
if torch.cuda.is_available():
    DEVICE = "cuda"
    print("✅ GPU detectada. Usando CUDA.")
else:
    DEVICE = "cpu"
    print("⚠️ GPU no detectada. Usando CPU.")

MODEL_REPO = "teszenofficial/mtp1"

# ======================
# MODEL DOWNLOAD
# ======================
print("--- SISTEMA MTP 1.1 ---")
repo_path = snapshot_download(
    repo_id=MODEL_REPO,
    repo_type="model",
    local_dir="mtptz_repo",
)
# The repo ships its own model/tokenizer modules; make them importable.
sys.path.insert(0, repo_path)

from model import MTPMiniModel
from tokenizer import MTPTokenizer

# ======================
# MODEL LOADING
# ======================
print("Cargando modelo...")

# Locate the checkpoint; fail with a clear error instead of an opaque
# StopIteration if the snapshot is missing its .pkl file.
try:
    pkl_file = next(f for f in os.listdir(repo_path) if f.endswith(".pkl"))
except StopIteration:
    raise FileNotFoundError(f"No .pkl checkpoint found in {repo_path}") from None

# SECURITY NOTE: pickle.load executes arbitrary code embedded in the file.
# This is only acceptable because MODEL_REPO is a trusted source; prefer a
# safetensors checkpoint if the upstream repo ever provides one.
with open(os.path.join(repo_path, pkl_file), "rb") as f:
    model_data = pickle.load(f)

tokenizer = MTPTokenizer(os.path.join(repo_path, "mtp_tokenizer.model"))

config = model_data["config"]
# Context window of the model — generation must never exceed this length.
MAX_SEQ_LEN = config["model"]["max_seq_len"]

model = MTPMiniModel(
    vocab_size=model_data["vocab_size"],
    d_model=config["model"]["d_model"],
    n_layers=config["model"]["n_layers"],
    n_heads=config["model"]["n_heads"],
    d_ff=config["model"]["d_ff"],
    max_seq_len=MAX_SEQ_LEN,
    dropout=0.0,  # no dropout at inference time
)
model.load_state_dict(model_data["model_state_dict"])
model.to(DEVICE)
model.eval()

# The tokenizer vocabulary may be smaller than the checkpoint's (possibly
# padded) embedding table; clamp generation to the decodable id range.
VOCAB_SIZE = tokenizer.sp.get_piece_size()
model.vocab_size = VOCAB_SIZE

print(f"🚀 MTP 1.1 listo en {DEVICE.upper()}")

# ======================
# FASTAPI
# ======================
app = FastAPI(title="MTP 1.1")


class Prompt(BaseModel):
    # Raw user text; the endpoints wrap it in the instruction template.
    text: str


def _encode_prompt(text: str) -> list:
    """Wrap *text* in the instruction template and tokenize with a BOS id.

    Shared by both generation endpoints so the template stays consistent.
    """
    full_prompt = f"### Instrucción:\n{text}\n\n### Respuesta:\n"
    return [tokenizer.bos_id()] + tokenizer.encode(full_prompt)


# ======================
# ONE-SHOT GENERATION (NO STREAM)
# ======================
@app.post("/generate")
def generate(prompt: Prompt):
    """Generate a full reply in one call.

    Returns ``{"reply": str}``; an empty prompt yields an empty reply and
    internal failures yield a generic error message instead of a traceback.
    """
    try:
        text = prompt.text.strip()
        if not text:
            return {"reply": ""}

        tokens = _encode_prompt(text)
        input_ids = torch.tensor([tokens], device=DEVICE)
        # Gradients are disabled globally, so no no_grad() context needed.
        output = model.generate(
            input_ids,
            max_new_tokens=80,
            temperature=0.7,
            top_k=50,
            top_p=0.9,
        )

        # Keep only freshly generated ids the tokenizer can decode.
        gen = output[0, len(tokens):].tolist()
        safe = [t for t in gen if 0 <= t < VOCAB_SIZE and t != tokenizer.eos_id()]
        reply = tokenizer.decode(safe).strip()
        return {"reply": reply}
    except Exception as e:
        # Top-level service boundary: log and hide details from the client.
        print("❌ ERROR:", e)
        return {"reply": "Error interno."}


# ======================
# STREAMING GENERATION (CHATGPT-STYLE)
# ======================
@app.post("/generate_stream")
def generate_stream(prompt: Prompt):
    """Stream the reply token by token as plain text."""

    def stream():
        try:
            text = prompt.text.strip()
            if not text:
                # Mirror /generate: empty prompt produces no output.
                return

            tokens = _encode_prompt(text)
            input_ids = torch.tensor([tokens], device=DEVICE)
            eos_id = tokenizer.eos_id()

            for _ in range(80):
                # Never feed the model more than its context window —
                # the original looped blindly past max_seq_len.
                if input_ids.size(1) >= MAX_SEQ_LEN:
                    break

                logits = model(input_ids)[:, -1, :]
                # Restrict to ids the tokenizer can decode; every sampled
                # id is therefore in range (no extra bounds check needed).
                logits = logits[:, :VOCAB_SIZE]

                # BUG FIX: the original scaled by temperature and then took
                # argmax, which ignores the temperature entirely (pure
                # greedy decoding). Sample from the distribution instead,
                # matching /generate's stochastic decoding.
                probs = torch.softmax(logits / 0.7, dim=-1)
                next_id = torch.multinomial(probs, num_samples=1).item()

                if next_id == eos_id:
                    break

                yield tokenizer.decode([next_id])

                input_ids = torch.cat(
                    [input_ids, torch.tensor([[next_id]], device=DEVICE)],
                    dim=1,
                )
                time.sleep(0.015)  # pace the output for a typing effect
        except Exception as e:
            print("❌ STREAM ERROR:", e)
            yield "\n[error]"

    return StreamingResponse(stream(), media_type="text/plain")


# ======================
# FRONTEND
# ======================
@app.get("/", response_class=HTMLResponse)
def ui():
    # NOTE(review): the section header promised a full HTML frontend, but
    # only this plain-text placeholder remains — looks truncated; confirm
    # against the original deployment before shipping.
    return """ MTP 1.1
Hola, soy MTP 1.1.
"""


# ======================
# ENTRYPOINT
# ======================
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)