"""FastAPI inference server for the MTP-4 language model.

On import this module downloads the model repository from the Hugging Face
Hub, loads the weights and tokenizer, optionally applies dynamic int8
quantization (CPU only) and exposes:

* ``POST /generate``      - one-shot text generation
* ``GET  /generate_sse``  - token streaming over Server-Sent Events
* ``GET  /health``, ``/info``, ``/config`` - service and model metadata
* ``GET  /``              - chat page
"""

import gc
import os
import pickle
import sys
import threading
import time
import traceback

import torch
import uvicorn
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse, StreamingResponse
from huggingface_hub import snapshot_download
from pydantic import BaseModel, Field

# ======================
# DEVICE CONFIGURATION
# ======================
if torch.cuda.is_available():
    DEVICE = "cuda"
    print("✅ GPU NVIDIA detectada. Usando CUDA.")
else:
    DEVICE = "cpu"
    print("⚠️ GPU no detectada. Usando CPU (puede ser más lento).")

# Thread tuning for CPU inference.
if DEVICE == "cpu":
    # os.cpu_count() may return None when the count cannot be determined.
    torch.set_num_threads(max(1, (os.cpu_count() or 2) // 2))

# NOTE: grad mode is thread-local in PyTorch, so this only covers the main
# thread; request handlers still wrap inference in torch.no_grad().
torch.set_grad_enabled(False)

MODEL_REPO = "TeszenAI/MTP-4"

# ======================
# MODEL DOWNLOAD AND LOADING
# ======================
print(f"📦 Descargando modelo desde {MODEL_REPO}...")
repo_path = snapshot_download(
    repo_id=MODEL_REPO,
    repo_type="model",
    local_dir="mtp_repo",
)

# The model/tokenizer implementations ship inside the downloaded repository.
sys.path.insert(0, repo_path)

from model import MTPMiniModel  # noqa: E402
from tokenizer import MTPTokenizer  # noqa: E402

print("🔧 Cargando tensores y configuración...")
# SECURITY NOTE(review): pickle.load executes arbitrary code embedded in the
# file. This is only acceptable while MODEL_REPO is fully trusted; consider a
# safer serialization format for the checkpoint.
with open(os.path.join(repo_path, "mtp_mini.pkl"), "rb") as f:
    model_data = pickle.load(f)

tokenizer = MTPTokenizer(os.path.join(repo_path, "mtp_tokenizer.model"))
VOCAB_SIZE = tokenizer.sp.get_piece_size()
config = model_data["config"]
MAX_SEQ_LEN = config["model"]["max_seq_len"]

# Detect whether the checkpoint was trained with SwiGLU.
use_swiglu = config["model"].get("use_swiglu", False)

print("🧠 Inicializando modelo MTP 4...")
print(f" → Vocabulario: {VOCAB_SIZE}")
print(f" → Dimensión: {config['model']['d_model']}")
print(f" → Capas: {config['model']['n_layers']}")
print(f" → Cabezas: {config['model']['n_heads']}")
print(f" → SwiGLU: {'✓' if use_swiglu else '✗'}")

model = MTPMiniModel(
    vocab_size=VOCAB_SIZE,
    d_model=config["model"]["d_model"],
    n_layers=config["model"]["n_layers"],
    n_heads=config["model"]["n_heads"],
    d_ff=config["model"]["d_ff"],
    max_seq_len=MAX_SEQ_LEN,
    dropout=0.0,
    use_swiglu=use_swiglu,
)
model.load_state_dict(model_data["model_state_dict"])
model.eval()

# Count parameters BEFORE quantization: dynamically quantized Linear layers
# keep their weights in packed buffers that model.parameters() does not
# report, which would understate the size on CPU.
PARAM_COUNT = sum(p.numel() for p in model.parameters())

# Dynamic int8 quantization for CPU inference.
if DEVICE == "cpu":
    print("⚡ Aplicando cuantización dinámica para CPU...")
    model = torch.quantization.quantize_dynamic(
        model,
        {torch.nn.Linear},
        dtype=torch.qint8,
    )

model.to(DEVICE)

param_count = PARAM_COUNT
print(f"✅ Modelo cargado: {param_count:,} parámetros ({param_count/1e6:.1f}M)")

# ======================
# API CONFIG
# ======================
app = FastAPI(
    title="MTP 4 API",
    description="API para modelo de lenguaje MTP 4 con RoPE, RMSNorm y SwiGLU",
    version="4.0",
)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)


class PromptRequest(BaseModel):
    """Request body for ``POST /generate``; bounds are enforced by pydantic."""

    text: str = Field(..., max_length=2000, description="Texto de entrada")
    max_tokens: int = Field(default=150, ge=10, le=300, description="Tokens máximos a generar")
    temperature: float = Field(default=0.7, ge=0.1, le=2.0, description="Temperatura de muestreo")
    top_k: int = Field(default=40, ge=1, le=100, description="Top-k sampling")
    top_p: float = Field(default=0.92, ge=0.1, le=1.0, description="Top-p (nucleus) sampling")
    repetition_penalty: float = Field(default=1.15, ge=1.0, le=2.0, description="Penalización por repetición")
    min_length: int = Field(default=20, ge=5, le=100, description="Longitud mínima de respuesta")


def build_prompt(user_input: str) -> str:
    """Wrap *user_input* in the instruction/response template the model expects."""
    return f"### Instrucción:\n{user_input}\n\n### Respuesta:\n"


# ======================
# ⚡ LOAD MANAGEMENT
# ======================
ACTIVE_REQUESTS = 0
MAX_CONCURRENT_REQUESTS = 3

# Both generation endpoints run in worker threads, so the counter needs a
# lock to make "check the limit, then increment" atomic.
_slots_lock = threading.Lock()


def _try_acquire_slot() -> int:
    """Reserve a generation slot.

    Returns the number of active requests including this one, or 0 when the
    server is already at MAX_CONCURRENT_REQUESTS.
    """
    global ACTIVE_REQUESTS
    with _slots_lock:
        if ACTIVE_REQUESTS >= MAX_CONCURRENT_REQUESTS:
            return 0
        ACTIVE_REQUESTS += 1
        return ACTIVE_REQUESTS


def _release_slot() -> None:
    """Give back a slot taken by `_try_acquire_slot` and free cached memory."""
    global ACTIVE_REQUESTS
    with _slots_lock:
        ACTIVE_REQUESTS = max(0, ACTIVE_REQUESTS - 1)
    if DEVICE == "cuda":
        torch.cuda.empty_cache()
    gc.collect()


def _clean_response(text: str) -> str:
    """Cut the reply at the first section marker and drop a trailing ellipsis."""
    if "###" in text:
        text = text.split("###")[0].strip()
    if text.endswith(("...", ". . .", "…")):
        # Strip every trailing dot / space / ellipsis character so that all
        # three detected forms are actually removed (rstrip(".") alone left
        # ". . ." and "…" untouched).
        text = text.rstrip(". …")
    return text


def _sample_next_token(
    logits: torch.Tensor,
    seen_ids: torch.Tensor,
    temperature: float,
    top_k: int,
    top_p: float,
    repetition_penalty: float,
) -> int:
    """Sample one token id from last-position *logits* (indexed as ``[0, vocab]``).

    Applies, in order: repetition penalty over the ids in ``seen_ids[0]``,
    temperature scaling, top-k filtering and top-p (nucleus) filtering.
    *logits* is modified in place, so pass a private copy.
    """
    # Repetition penalty: push already-seen tokens away from being sampled
    # (negative logits get more negative, positive ones shrink).
    if repetition_penalty != 1.0:
        seen = torch.unique(seen_ids[0])
        seen = seen[seen < logits.size(-1)]
        picked = logits[0, seen]
        logits[0, seen] = torch.where(
            picked < 0,
            picked * repetition_penalty,
            picked / repetition_penalty,
        )

    # Temperature scaling.
    logits = logits / temperature

    # Top-k filtering: keep only the k largest logits.
    if top_k > 0:
        kth_values, _ = torch.topk(logits, min(top_k, logits.size(-1)))
        logits[logits < kth_values[:, [-1]]] = float("-inf")

    # Top-p (nucleus) filtering.
    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift right so the first token crossing the threshold is kept, and
        # always keep the single most likely token.
        sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
        sorted_indices_to_remove[:, 0] = 0
        indices_to_remove = sorted_indices_to_remove.scatter(
            1, sorted_indices, sorted_indices_to_remove
        )
        logits[indices_to_remove] = float("-inf")

    probs = torch.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1).item()


# Declared as a plain ``def`` on purpose: inference is blocking, and FastAPI
# runs sync path operations in its threadpool instead of on the event loop.
@app.post("/generate")
def generate(req: PromptRequest):
    """Generate a reply for ``req.text`` and return it with timing statistics.

    Returns a JSON object. When the server is saturated or generation fails
    the object carries an ``error`` key next to a user-facing ``reply``.
    """
    active = _try_acquire_slot()
    if not active:
        return {
            "reply": "El servidor está ocupado. Por favor, intenta de nuevo en unos segundos.",
            "error": "too_many_requests",
            "active_requests": ACTIVE_REQUESTS,
        }

    # Everything after acquiring the slot lives inside try/finally so the
    # slot is always released, including on tokenization errors.
    try:
        # Dynamic adjustment under load.
        dyn_max_tokens = req.max_tokens
        dyn_temperature = req.temperature
        if active > 1:
            print(f"⚠️ Carga alta ({active} requests). Ajustando parámetros.")
            dyn_max_tokens = min(dyn_max_tokens, 120)
            dyn_temperature = max(0.6, dyn_temperature * 0.95)

        user_input = req.text.strip()
        if not user_input:
            return {"reply": "", "tokens_generated": 0}

        full_prompt = build_prompt(user_input)
        tokens = [tokenizer.bos_id()] + tokenizer.encode(full_prompt)
        input_ids = torch.tensor([tokens], device=DEVICE)
        eos_id = tokenizer.eos_id()

        start_time = time.perf_counter()
        with torch.no_grad():
            output_ids = model.generate(
                input_ids,
                max_new_tokens=dyn_max_tokens,
                temperature=dyn_temperature,
                top_k=req.top_k,
                top_p=req.top_p,
                repetition_penalty=req.repetition_penalty,
                min_length=req.min_length,
                eos_token_id=eos_id,
            )

        gen_tokens = output_ids[0, len(tokens):].tolist()

        # Safety filter: stop at EOS and drop ids outside the vocabulary.
        safe_tokens = []
        for token_id in gen_tokens:
            if token_id == eos_id:
                break
            if 0 <= token_id < VOCAB_SIZE:
                safe_tokens.append(token_id)

        response = _clean_response(tokenizer.decode(safe_tokens).strip())

        generation_time = time.perf_counter() - start_time
        tokens_per_second = len(safe_tokens) / generation_time if generation_time > 0 else 0

        return {
            "reply": response,
            "tokens_generated": len(safe_tokens),
            "generation_time": round(generation_time, 2),
            "tokens_per_second": round(tokens_per_second, 1),
            "model": "MTP-4",
            "device": DEVICE,
        }
    except Exception as e:  # top-level request boundary: log and report
        print(f"❌ Error durante generación: {e}")
        traceback.print_exc()
        return {
            "reply": "Lo siento, ocurrió un error al procesar tu solicitud.",
            "error": str(e),
        }
    finally:
        _release_slot()


# ======================
# 📡 SSE STREAMING
# ======================
@app.get("/generate_sse")
def generate_sse(
    text: str,
    max_tokens: int = 150,
    temperature: float = 0.7,
    top_k: int = 40,
    top_p: float = 0.92,
    repetition_penalty: float = 1.15,
):
    """Stream generated tokens as Server-Sent Events.

    Each token is sent as one ``data:<token>`` event; the stream ends with
    ``data:[DONE]`` or, on failure, ``data:[ERROR: ...]``.
    """

    def event_stream():
        # The slot is taken inside the generator so that it is only held
        # while the stream is actually being consumed, and the ``finally``
        # below is guaranteed to pair with the acquisition.
        active = _try_acquire_slot()
        if not active:
            yield "data:[ERROR: Servidor ocupado]\n\n"
            return

        try:
            full_prompt = build_prompt(text)
            tokens = [tokenizer.bos_id()] + tokenizer.encode(full_prompt)
            input_ids = torch.tensor([tokens], device=DEVICE)
            eos_id = tokenizer.eos_id()
            generated_tokens = []

            # Dynamic adjustment under load.
            under_load = active > 1
            limit = min(100 if under_load else max_tokens, 200)
            temp = max(0.6, temperature * 0.95) if under_load else temperature
            # Guard against division by zero for temperature <= 0.
            temp = max(temp, 1e-4)

            for _ in range(limit):
                with torch.no_grad():
                    # Feed at most max_seq_len tokens so long generations do
                    # not outgrow the model's configured context window.
                    logits, _ = model(input_ids[:, -MAX_SEQ_LEN:])
                    logits = logits[:, -1, :VOCAB_SIZE].clone()
                    next_id = _sample_next_token(
                        logits, input_ids, temp, top_k, top_p, repetition_penalty
                    )

                if next_id == eos_id:
                    break

                if 0 <= next_id < VOCAB_SIZE:
                    generated_tokens.append(next_id)
                    token_text = tokenizer.decode([next_id])

                    # Stop as soon as the model starts a new section marker.
                    if "###" in token_text:
                        break

                    # NOTE(review): a token containing "\n" breaks SSE event
                    # framing; the wire format is left unchanged because the
                    # client-side parser is not visible here.
                    yield f"data:{token_text}\n\n"

                input_ids = torch.cat(
                    [input_ids, torch.tensor([[next_id]], device=DEVICE)],
                    dim=1,
                )
                time.sleep(0.02)  # pacing

            yield "data:[DONE]\n\n"
        except Exception as e:  # top-level stream boundary: log and report
            print(f"❌ Error en streaming: {e}")
            yield f"data:[ERROR: {str(e)}]\n\n"
        finally:
            _release_slot()

    return StreamingResponse(event_stream(), media_type="text/event-stream")


# ======================
# 📊 INFORMATION ENDPOINTS
# ======================
@app.get("/health")
def health_check():
    """Report service status, load and (on CUDA) GPU memory usage."""
    memory_info = {}
    if DEVICE == "cuda":
        memory_info = {
            "gpu_memory_allocated_mb": round(torch.cuda.memory_allocated() / 1024**2, 2),
            "gpu_memory_reserved_mb": round(torch.cuda.memory_reserved() / 1024**2, 2),
        }

    return {
        "status": "healthy",
        "model": "MTP-4",
        "device": DEVICE,
        "active_requests": ACTIVE_REQUESTS,
        "max_concurrent_requests": MAX_CONCURRENT_REQUESTS,
        "vocab_size": VOCAB_SIZE,
        "parameters": PARAM_COUNT,
        **memory_info,
    }


@app.get("/info")
def model_info():
    """Return architecture, parameter count and training settings of the model."""
    improvements = [
        "RoPE (Rotary Position Embedding)",
        "RMSNorm (Root Mean Square Normalization)",
        "Label Smoothing (0.1)",
        "Repetition Penalty",
        "Early Stopping",
        "EOS Loss Weight",
        "Length Control",
        "Gradient Accumulation",
    ]
    if use_swiglu:
        improvements.append("SwiGLU Activation")

    model_cfg = config["model"]
    training_cfg = config["training"]

    return {
        "model_name": "MTP-4",
        "version": "4.0",
        "architecture": {
            "d_model": model_cfg["d_model"],
            "n_layers": model_cfg["n_layers"],
            "n_heads": model_cfg["n_heads"],
            "d_ff": model_cfg["d_ff"],
            "max_seq_len": model_cfg["max_seq_len"],
            "vocab_size": VOCAB_SIZE,
            "use_swiglu": use_swiglu,
            "dropout": model_cfg["dropout"],
        },
        "parameters": PARAM_COUNT,
        "parameters_human": f"{PARAM_COUNT/1e6:.1f}M",
        "device": DEVICE,
        "improvements": improvements,
        "training_config": {
            "batch_size": training_cfg["batch_size"],
            "accumulation_steps": training_cfg["accumulation_steps"],
            "learning_rate": training_cfg["learning_rate"],
            "weight_decay": training_cfg["weight_decay"],
            "epochs": training_cfg["epochs"],
        },
    }


@app.get("/config")
def get_config():
    """Return the configuration sections stored with the checkpoint."""
    return {
        "model": config["model"],
        "training": config["training"],
        "data": config["data"],
        "generation": config.get("generation", {}),
    }


# ======================
# 🎨 WEB INTERFACE
# ======================
@app.get("/", response_class=HTMLResponse)
def chat_ui():
    """Serve the chat page."""
    # NOTE(review): only the text content of the page is present here; the
    # HTML/CSS/JS markup appears to have been stripped before this review.
    # Restore it from the original source if available.
    return """ MTP 4 - Chat Interface
MTP 4.0
¡Hola! Soy MTP 4, un modelo de lenguaje avanzado con arquitectura Transformer optimizada. Características principales: • RoPE - Rotary Position Embedding para mejor contexto • RMSNorm - Normalización estable y eficiente • SwiGLU - Función de activación mejorada • Control inteligente de repetición y coherencia • Generación fluida y natural ¿En qué puedo ayudarte hoy?
"""


if __name__ == "__main__":
    port = int(os.environ.get("PORT", 7860))
    print(f"\n🚀 Iniciando servidor MTP 4...")
    print(f"🌐 Interfaz web: http://0.0.0.0:{port}")
    print(f"📡 API docs: http://0.0.0.0:{port}/docs")
    print(f"📊 Health check: http://0.0.0.0:{port}/health")
    print(f"ℹ️ Model info: http://0.0.0.0:{port}/info")
    print(f"\n✅ Sistema listo. Presiona Ctrl+C para detener.")

    uvicorn.run(
        app,
        host="0.0.0.0",
        port=port,
        log_level="info",
    )