# Hugging Face Spaces page residue (space status banner: "Sleeping") — not part of the app source.
import os
import sys
import torch
import pickle
import time
from fastapi import FastAPI
from fastapi.responses import HTMLResponse, StreamingResponse
from pydantic import BaseModel
from huggingface_hub import snapshot_download
import uvicorn

# ======================
# CPU OPTIMIZATION
# ======================
# Use half the available cores for intra-op parallelism so the host stays
# responsive; this is an inference-only server, so disable autograd globally.
torch.set_num_threads(max(1, os.cpu_count() // 2))
torch.set_grad_enabled(False)

# ======================
# DEVICE
# ======================
# Prefer CUDA when available, otherwise fall back to CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
    print("✅ GPU detectada. Usando CUDA.")
else:
    DEVICE = "cpu"
    print("⚠️ GPU no detectada. Usando CPU.")
# Hugging Face repository that hosts the model checkpoint, config, and tokenizer.
MODEL_REPO = "teszenofficial/mtp1"

# ======================
# MODEL DOWNLOAD
# ======================
print("--- SISTEMA MTP 1.1 ---")
# Download the full model repo (or reuse a locally cached copy) into
# ./mtptz_repo and get its local path.
repo_path = snapshot_download(
    repo_id=MODEL_REPO,
    repo_type="model",
    local_dir="mtptz_repo"
)
# The repo ships its own `model.py` / `tokenizer.py` modules; put the repo
# directory first on sys.path so the imports below resolve to them.
sys.path.insert(0, repo_path)
from model import MTPMiniModel
from tokenizer import MTPTokenizer
# ======================
# MODEL LOAD
# ======================
print("Cargando modelo...")
# Take the first .pkl file found in the repo as the checkpoint; raises
# StopIteration if none exists.
pkl_file = next(f for f in os.listdir(repo_path) if f.endswith(".pkl"))
# NOTE(review): pickle.load on downloaded data executes arbitrary code if the
# upstream repo is ever compromised — acceptable here only because MODEL_REPO
# is trusted; consider a safetensors checkpoint instead.
with open(os.path.join(repo_path, pkl_file), "rb") as f:
    model_data = pickle.load(f)
tokenizer = MTPTokenizer(os.path.join(repo_path, "mtp_tokenizer.model"))
# Hyperparameters come from the checkpoint's embedded config dict.
config = model_data["config"]
model = MTPMiniModel(
    vocab_size=model_data["vocab_size"],
    d_model=config["model"]["d_model"],
    n_layers=config["model"]["n_layers"],
    n_heads=config["model"]["n_heads"],
    d_ff=config["model"]["d_ff"],
    max_seq_len=config["model"]["max_seq_len"],
    dropout=0.0  # inference: no dropout
)
model.load_state_dict(model_data["model_state_dict"])
model.to(DEVICE)
model.eval()
# Authoritative vocab size is taken from the SentencePiece model; overwrite
# the model's attribute so both always agree.
VOCAB_SIZE = tokenizer.sp.get_piece_size()
model.vocab_size = VOCAB_SIZE
print(f"🚀 MTP 1.1 listo en {DEVICE.upper()}")
# ======================
# FASTAPI
# ======================
app = FastAPI(title="MTP 1.1")


class Prompt(BaseModel):
    """Request body for the generation endpoints."""

    # Raw user message; whitespace is stripped by the handlers.
    text: str
# ======================
# NORMAL GENERATION (NON-STREAMING)
# ======================
@app.post("/generate")  # FIX: the handler was defined but never registered as a route
def generate(prompt: Prompt):
    """Generate a complete reply for *prompt* in a single response.

    Returns ``{"reply": <text>}``. An empty/whitespace-only prompt yields an
    empty reply; any internal failure is logged and reported as a generic
    error message rather than a 500 with a traceback.
    """
    try:
        text = prompt.text.strip()
        if not text:
            return {"reply": ""}
        # Alpaca-style instruction template — presumably what the model was
        # fine-tuned on (TODO confirm against the training code).
        full_prompt = f"### Instrucción:\n{text}\n\n### Respuesta:\n"
        tokens = [tokenizer.bos_id()] + tokenizer.encode(full_prompt)
        input_ids = torch.tensor([tokens], device=DEVICE)
        with torch.no_grad():
            output = model.generate(
                input_ids,
                max_new_tokens=80,
                temperature=0.7,
                top_k=50,
                top_p=0.9
            )
        # Keep only the newly generated ids (drop the echoed prompt).
        gen = output[0, len(tokens):].tolist()
        # Filter out-of-vocabulary ids and the EOS marker before decoding.
        safe = [t for t in gen if 0 <= t < VOCAB_SIZE and t != tokenizer.eos_id()]
        reply = tokenizer.decode(safe).strip()
        return {"reply": reply}
    except Exception as e:
        print("❌ ERROR:", e)
        return {"reply": "Error interno."}
# ======================
# STREAMING GENERATION (CHATGPT-STYLE)
# ======================
@app.post("/generate_stream")  # FIX: route never registered, yet the frontend fetches this path
def generate_stream(prompt: Prompt):
    """Stream the reply to *prompt* token-by-token as plain text.

    Decoding is greedy: each step picks the argmax token. (Dividing the
    logits by 0.7 before softmax does not change the argmax, so the
    temperature here is effectively a no-op — kept for parity with the
    non-streaming path.) On error, a ``[error]`` marker is emitted and the
    stream ends.
    """
    def stream():
        try:
            text = prompt.text.strip()
            full_prompt = f"### Instrucción:\n{text}\n\n### Respuesta:\n"
            tokens = [tokenizer.bos_id()] + tokenizer.encode(full_prompt)
            input_ids = torch.tensor([tokens], device=DEVICE)
            for _ in range(80):  # cap the reply at 80 new tokens
                with torch.no_grad():
                    logits = model(input_ids)[:, -1, :]
                # Clamp to the tokenizer's vocab in case the model head is larger.
                logits = logits[:, :VOCAB_SIZE]
                probs = torch.softmax(logits / 0.7, dim=-1)
                next_id = torch.argmax(probs, dim=-1).item()
                if next_id == tokenizer.eos_id():
                    break
                if 0 <= next_id < VOCAB_SIZE:
                    token_text = tokenizer.decode([next_id])
                    yield token_text
                # Feed the chosen token back for the next step.
                input_ids = torch.cat(
                    [input_ids, torch.tensor([[next_id]], device=DEVICE)],
                    dim=1
                )
                time.sleep(0.015)  # small delay so the client sees a typing effect
        except Exception as e:
            print("❌ STREAM ERROR:", e)
            yield "\n[error]"
    return StreamingResponse(stream(), media_type="text/plain")
# ======================
# FULL HTML FRONTEND
# ======================
@app.get("/", response_class=HTMLResponse)  # FIX: the page was never served by any route
def ui():
    """Serve the single-page chat frontend.

    The page POSTs to ``/generate_stream`` and appends the streamed chunks
    to the last bot message as they arrive.
    """
    return """
<!DOCTYPE html>
<html lang="es">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width,initial-scale=1">
<title>MTP 1.1</title>
<style>
body{margin:0;background:#131314;color:#e3e3e3;font-family:Inter,system-ui}
#chat{max-width:900px;margin:auto;padding:20px}
.msg{margin:12px 0;white-space:pre-wrap}
.user{color:#8ab4f8}
.bot{color:#e3e3e3}
input{width:100%;padding:12px;border-radius:10px;border:none;background:#1e1f20;color:white}
button{margin-top:10px;padding:10px;border-radius:10px;border:none;background:#4a9eff;color:black;font-weight:bold}
</style>
</head>
<body>
<div id="chat">
<div class="msg bot">Hola, soy MTP 1.1.</div>
</div>
<input id="inp" placeholder="Escribe algo…" />
<button onclick="send()">Enviar</button>
<script>
async function send(){
const input=document.getElementById('inp');
const text=input.value.trim();
if(!text)return;
input.value="";
const chat=document.getElementById('chat');
chat.innerHTML+=`<div class="msg user">${text}</div>`;
const bot=document.createElement('div');
bot.className="msg bot";
chat.appendChild(bot);
const res=await fetch('/generate_stream',{
method:'POST',
headers:{'Content-Type':'application/json'},
body:JSON.stringify({text})
});
const reader=res.body.getReader();
const decoder=new TextDecoder();
while(true){
const {value,done}=await reader.read();
if(done)break;
bot.textContent+=decoder.decode(value);
window.scrollTo(0,document.body.scrollHeight);
}
}
</script>
</body>
</html>
"""
# ======================
# ENTRYPOINT
# ======================
if __name__ == "__main__":
    # Bind on all interfaces; 7860 is the conventional Hugging Face Spaces port.
    uvicorn.run(app, host="0.0.0.0", port=7860)