import gc
import os
import pickle
import sys
import time

import torch
import uvicorn
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse, StreamingResponse
from huggingface_hub import snapshot_download
from pydantic import BaseModel, Field

# ======================
# DEVICE CONFIGURATION
# ======================
if torch.cuda.is_available():
    DEVICE = "cuda"
    print("✅ GPU NVIDIA detectada. Usando CUDA.")
else:
    DEVICE = "cpu"
    print("⚠️ GPU no detectada. Usando CPU (puede ser más lento).")

# Thread tuning for CPU inference. os.cpu_count() may return None on some
# platforms, so guard it before the integer division.
if DEVICE == "cpu":
    torch.set_num_threads(max(1, (os.cpu_count() or 2) // 2))

# Inference-only server: disable autograd globally.
torch.set_grad_enabled(False)

MODEL_REPO = "TeszenAI/mtp-3.1"

# ======================
# MODEL DOWNLOAD AND LOADING
# ======================
print(f"📦 Descargando modelo desde {MODEL_REPO}...")
repo_path = snapshot_download(
    repo_id=MODEL_REPO,
    repo_type="model",
    local_dir="mtptz_repo"
)
sys.path.insert(0, repo_path)

# The model and tokenizer classes ship inside the downloaded repository,
# hence the deferred import after extending sys.path.
from model import MTPMiniModel
from tokenizer import MTPTokenizer

print("🔧 Cargando tensores y configuración...")
# SECURITY NOTE: pickle.load executes arbitrary code embedded in the file.
# Only load checkpoints from a repository you trust.
with open(os.path.join(repo_path, "mtp_mini.pkl"), "rb") as f:
    model_data = pickle.load(f)

tokenizer = MTPTokenizer(os.path.join(repo_path, "mtp_tokenizer.model"))
VOCAB_SIZE = tokenizer.sp.get_piece_size()
config = model_data["config"]

# Whether the checkpoint was trained with SwiGLU feed-forward layers.
use_swiglu = config["model"].get("use_swiglu", False)

print(f"🧠 Inicializando modelo...")
print(f" → Vocabulario: {VOCAB_SIZE}")
print(f" → Dimensión: {config['model']['d_model']}")
print(f" → Capas: {config['model']['n_layers']}")
print(f" → SwiGLU: {'✓' if use_swiglu else '✗'}")

model = MTPMiniModel(
    vocab_size=VOCAB_SIZE,
    d_model=config["model"]["d_model"],
    n_layers=config["model"]["n_layers"],
    n_heads=config["model"]["n_heads"],
    d_ff=config["model"]["d_ff"],
    max_seq_len=config["model"]["max_seq_len"],
    dropout=0.0,  # inference only
    use_swiglu=use_swiglu
)
model.load_state_dict(model_data["model_state_dict"])
model.eval()

# Dynamic int8 quantization of Linear layers speeds up CPU inference.
if DEVICE == "cpu":
    print("⚡ Aplicando cuantización dinámica para CPU...")
    model = torch.quantization.quantize_dynamic(
        model, {torch.nn.Linear}, dtype=torch.qint8
    )

model.to(DEVICE)
param_count = sum(p.numel() for p in model.parameters())
print(f"✅ Modelo cargado: {param_count:,} parámetros ({param_count/1e6:.1f}M)")

# ======================
# API CONFIG
# ======================
app = FastAPI(
    title="MTP-3 API",
    description="API para modelo de lenguaje MTP-3 mejorado",
    version="3.0"
)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)


class PromptRequest(BaseModel):
    """Validated request body for /generate."""
    text: str = Field(..., max_length=2000, description="Texto de entrada")
    max_tokens: int = Field(default=150, ge=10, le=300, description="Tokens máximos a generar")
    temperature: float = Field(default=0.7, ge=0.1, le=2.0, description="Temperatura de muestreo")
    top_k: int = Field(default=50, ge=1, le=100, description="Top-k sampling")
    top_p: float = Field(default=0.9, ge=0.1, le=1.0, description="Top-p (nucleus) sampling")
    repetition_penalty: float = Field(default=1.1, ge=1.0, le=2.0, description="Penalización por repetición")


def build_prompt(user_input: str) -> str:
    """Wrap the raw user text in the instruction format the model was trained on."""
    return f"### Instrucción:\n{user_input}\n\n### Respuesta:\n"


# ======================
# ⚡ LOAD MANAGEMENT
# ======================
# Best-effort in-flight request counter used only for soft load shedding;
# exact accuracy is not required, so no lock is taken around the updates.
ACTIVE_REQUESTS = 0


@app.post("/generate")
def generate(req: PromptRequest):
    """Main text-generation endpoint.

    Declared as a *sync* function on purpose: FastAPI runs sync path
    operations in a worker thread, so the blocking torch call does not
    stall the event loop (the original `async def` did exactly that, which
    also prevented the ACTIVE_REQUESTS load check from ever triggering).
    """
    global ACTIVE_REQUESTS
    ACTIVE_REQUESTS += 1
    try:
        # Degrade gracefully under load: shorter, cooler generations.
        dyn_max_tokens = req.max_tokens
        dyn_temperature = req.temperature
        if ACTIVE_REQUESTS > 2:
            print(f"⚠️ Carga alta ({ACTIVE_REQUESTS} requests). Ajustando parámetros.")
            dyn_max_tokens = min(dyn_max_tokens, 120)
            dyn_temperature = max(0.5, dyn_temperature * 0.9)

        user_input = req.text.strip()
        if not user_input:
            return {"reply": "", "tokens_generated": 0}

        full_prompt = build_prompt(user_input)
        tokens = [tokenizer.bos_id()] + tokenizer.encode(full_prompt)
        input_ids = torch.tensor([tokens], device=DEVICE)

        try:
            with torch.no_grad():
                output_ids = model.generate(
                    input_ids,
                    max_new_tokens=dyn_max_tokens,
                    temperature=dyn_temperature,
                    top_k=req.top_k,
                    top_p=req.top_p,
                    repetition_penalty=req.repetition_penalty
                )

            # Keep only the newly generated tokens (drop the prompt echo).
            gen_tokens = output_ids[0, len(tokens):].tolist()

            # Safety filter: out-of-vocab ids and EOS are dropped before decode.
            safe_tokens = [
                t for t in gen_tokens
                if 0 <= t < VOCAB_SIZE and t != tokenizer.eos_id()
            ]
            response = tokenizer.decode(safe_tokens).strip()

            # Truncate at the next section marker if the model keeps going.
            if "###" in response:
                response = response.split("###")[0].strip()

            return {
                "reply": response,
                "tokens_generated": len(safe_tokens),
                "model": "MTP-3"
            }
        except Exception as e:
            print(f"❌ Error durante generación: {e}")
            return {
                "reply": "Lo siento, ocurrió un error al procesar tu solicitud.",
                "error": str(e)
            }
    finally:
        # Always release the slot — the original leaked the counter when
        # tokenization raised before its try block was entered.
        ACTIVE_REQUESTS -= 1
        if DEVICE == "cuda":
            torch.cuda.empty_cache()
        gc.collect()


# ======================
# 📡 SSE STREAMING
# ======================
@app.get("/generate_sse")
def generate_sse(
    text: str,
    max_tokens: int = 150,
    temperature: float = 0.7
):
    """Token-by-token streaming endpoint using Server-Sent Events."""
    global ACTIVE_REQUESTS
    ACTIVE_REQUESTS += 1

    def event_stream():
        try:
            full_prompt = build_prompt(text)
            tokens = [tokenizer.bos_id()] + tokenizer.encode(full_prompt)
            input_ids = torch.tensor([tokens], device=DEVICE)

            # Soft load shedding, mirroring /generate.
            limit = 100 if ACTIVE_REQUESTS > 2 else max_tokens
            temp = max(0.5, temperature * 0.9) if ACTIVE_REQUESTS > 2 else temperature

            for step in range(limit):
                with torch.no_grad():
                    # assumes model(...) returns (logits, aux) — matches the
                    # repo's MTPMiniModel forward signature
                    logits, _ = model(input_ids)
                    logits = logits[:, -1, :VOCAB_SIZE]

                    # Temperature sampling over the last position.
                    probs = torch.softmax(logits / temp, dim=-1)
                    next_id = torch.multinomial(probs, num_samples=1).item()

                if next_id == tokenizer.eos_id():
                    break

                if 0 <= next_id < VOCAB_SIZE:
                    token_text = tokenizer.decode([next_id])

                    # Stop at section markers instead of streaming them out.
                    if "###" in token_text:
                        break

                    yield f"data:{token_text}\n\n"

                input_ids = torch.cat(
                    [input_ids, torch.tensor([[next_id]], device=DEVICE)],
                    dim=1
                )
                # Small pacing delay; this is a sync generator running in a
                # worker thread, so sleeping here does not block the loop.
                time.sleep(0.01)

            yield "data:[DONE]\n\n"
        except Exception as e:
            yield f"data:[ERROR: {str(e)}]\n\n"
        finally:
            ACTIVE_REQUESTS -= 1
            if DEVICE == "cuda":
                torch.cuda.empty_cache()

    return StreamingResponse(event_stream(), media_type="text/event-stream")


# ======================
# 📊 INFO ENDPOINTS
# ======================
@app.get("/health")
def health_check():
    """Liveness/readiness probe with basic runtime stats."""
    return {
        "status": "healthy",
        "model": "MTP-3",
        "device": DEVICE,
        "active_requests": ACTIVE_REQUESTS,
        "vocab_size": VOCAB_SIZE,
        "parameters": sum(p.numel() for p in model.parameters())
    }


@app.get("/info")
def model_info():
    """Detailed model/architecture information."""
    improvements = [
        "RoPE (Rotary Position Embedding)",
        "RMSNorm",
        "Label Smoothing",
        "Repetition Penalty",
    ]
    if config["model"].get("use_swiglu"):
        improvements.append("SwiGLU (opcional)")
    return {
        "model_name": "MTP-3",
        "version": "3.0",
        "architecture": {
            "d_model": config["model"]["d_model"],
            "n_layers": config["model"]["n_layers"],
            "n_heads": config["model"]["n_heads"],
            "d_ff": config["model"]["d_ff"],
            "max_seq_len": config["model"]["max_seq_len"],
            "vocab_size": VOCAB_SIZE,
            "use_swiglu": config["model"].get("use_swiglu", False)
        },
        "parameters": sum(p.numel() for p in model.parameters()),
        "device": DEVICE,
        # Fix: the original appended a literal None to this list when SwiGLU
        # was disabled, which serialized as `null` in the JSON response.
        "improvements": improvements
    }


# ======================
# 🎨 WEB UI
# ======================
@app.get("/", response_class=HTMLResponse)
def chat_ui():
    # NOTE(review): the original HTML/CSS/JS markup appears truncated in the
    # version under review; only the visible text content is reproduced here.
    # Restore the full page template from version control before deploying.
    return """ MTP 3
MTP 3
¡Hola! Soy MTP 3. ¿En qué puedo ayudarte hoy?
"""


if __name__ == "__main__":
    port = int(os.environ.get("PORT", 7860))
    print(f"\n🚀 Iniciando servidor en puerto {port}...")
    print(f"🌐 Interfaz web: http://0.0.0.0:{port}")
    print(f"📡 API docs: http://0.0.0.0:{port}/docs")
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=port,
        log_level="info"
    )