import asyncio
import gc
import os
import pickle
import sys
import time

import torch
import uvicorn
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse, StreamingResponse
from huggingface_hub import snapshot_download
from pydantic import BaseModel, Field
| |
|
| | |
| | |
| | |
# Pick the compute device once at startup; later model/tensor moves use DEVICE.
if torch.cuda.is_available():
    DEVICE = "cuda"
    print("✅ GPU NVIDIA detectada. Usando CUDA.")
else:
    DEVICE = "cpu"
    print("⚠️ GPU no detectada. Usando CPU (puede ser más lento).")

# On CPU, use half of the logical cores for torch intra-op parallelism.
# os.cpu_count() returns None when the count cannot be determined, so fall
# back to 1 instead of crashing on `None // 2`.
if DEVICE == "cpu":
    torch.set_num_threads(max(1, (os.cpu_count() or 1) // 2))

# Inference-only server: disable autograd. This flag is thread-local in
# PyTorch, so it only affects the thread executing this module.
torch.set_grad_enabled(False)
| |
|
MODEL_REPO = "TeszenAI/MTP-3"

# Download the model repository into ./mtptz_repo (relative to the current
# working directory) and remember where it landed.
print(f"📦 Descargando modelo desde {MODEL_REPO}...")
repo_path = snapshot_download(repo_id=MODEL_REPO, repo_type="model", local_dir="mtptz_repo")

# Put the downloaded repo first on the import path; the two imports below are
# expected to resolve to modules shipped inside it.
sys.path.insert(0, repo_path)

from model import MTPModel  # noqa: E402  (requires repo_path on sys.path)
from tokenizer import MTPTokenizer  # noqa: E402
| |
|
print("🔧 Cargando tensores y configuración...")

# Deserialize onto the CPU regardless of DEVICE; the model is moved later.
map_location = torch.device('cpu')
checkpoint_path = os.path.join(repo_path, "mtp3.pkl")

# SECURITY NOTE(review): weights_only=False and the plain-pickle fallback can
# execute arbitrary code embedded in the checkpoint. Only acceptable for a
# trusted source (here the MODEL_REPO download) - never for untrusted files.
try:
    model_data = torch.load(
        checkpoint_path,
        map_location=map_location,
        weights_only=False,
        pickle_module=pickle
    )
except Exception as e1:
    print(f"⚠️ Error con torch.load: {e1}")
    print("🔧 Intentando método alternativo...")
    try:
        with open(checkpoint_path, "rb") as f:
            model_data = pickle.load(f)

        # pickle.load has no map_location, so move any tensors to the CPU by
        # hand. Reassigning existing keys while iterating is safe (no resize).
        state_dict = model_data.get("model_state_dict")
        if state_dict is not None:
            for key, value in state_dict.items():
                if torch.is_tensor(value):
                    state_dict[key] = value.to('cpu')
    except Exception as e2:
        print(f"❌ Error con pickle.load: {e2}")
        print("🔧 Intentando método final de emergencia...")

        # Last resort: recover only the config (weights stay random). PyYAML is
        # imported lazily because only this path needs it; if it or
        # config.yaml is missing, startup fails here - there is nothing to build.
        import yaml
        with open(os.path.join(repo_path, "config.yaml"), "r", encoding="utf-8") as f:
            config = yaml.safe_load(f)

        model_data = {
            "config": config,
            "model_state_dict": None
        }

# A checkpoint that carries no state dict is treated like the emergency path
# (random weights) instead of failing later with a KeyError.
if isinstance(model_data, dict):
    model_data.setdefault("model_state_dict", None)
| |
|
# Tokenizer shipped with the downloaded repo; its vocabulary size drives the
# model's embedding size and the output-token range filter in /generate.
tokenizer = MTPTokenizer(os.path.join(repo_path, "mtp_tokenizer.model"))
VOCAB_SIZE = tokenizer.vocab_size()
config = model_data["config"]

# Reported as enabled when the config sets model.use_swiglu, or when the text
# "SwiGLU" appears anywhere in the config's string form. Only printed below;
# it is not passed to the model constructor.
use_swiglu = (
    config.get("model", {}).get("use_swiglu", False)
    or "SwiGLU" in str(config)
)

print("🧠 Inicializando modelo...")
print(f" → Vocabulario: {VOCAB_SIZE}")
for _label, _key in (("Dimensión", "d_model"), ("Capas", "n_layers"), ("Cabezas", "n_heads")):
    print(f" → {_label}: {config['model'][_key]}")
print(f" → SwiGLU: {'✓' if use_swiglu else '✗'}")
| |
|
| | |
# Build the network from the checkpoint's hyperparameters.
model = MTPModel(
    vocab_size=VOCAB_SIZE,
    d_model=config['model']['d_model'],
    n_layers=config['model']['n_layers'],
    n_heads=config['model']['n_heads'],
    d_ff=config['model']['d_ff'],
    max_seq_len=config['model']['max_seq_len'],
    dropout=config['model'].get('dropout', 0.1)
)

# Load pretrained weights when present. A failed load is deliberately
# non-fatal: the server still starts, with random weights (and says so).
pretrained_state = model_data.get("model_state_dict")
if pretrained_state is not None:
    try:
        model.load_state_dict(pretrained_state)
        print("✅ Pesos del modelo cargados exitosamente")
    except Exception as e:
        print(f"⚠️ Error al cargar pesos: {e}")
        print("⚠️ Inicializando modelo con pesos aleatorios")
else:
    print("⚠️ Inicializando modelo con pesos aleatorios (sin pesos pre-entrenados)")

model.eval()

# Count parameters BEFORE quantization: quantize_dynamic swaps nn.Linear for
# dynamically-quantized modules whose packed int8 weights are not exposed via
# .parameters(), so counting afterwards would under-report the model size.
param_count = sum(p.numel() for p in model.parameters())

if DEVICE == "cpu":
    print("⚡ Aplicando optimizaciones para CPU...")
    try:
        # Dynamic int8 quantization of every nn.Linear layer for CPU inference.
        model = torch.quantization.quantize_dynamic(
            model,
            {torch.nn.Linear},
            dtype=torch.qint8
        )
        print(" ✓ Cuantización aplicada")
    except Exception as e:
        print(f" ⚠ No se pudo aplicar cuantización: {e}")

model.to(DEVICE)

print(f"✅ Modelo inicializado: {param_count:,} parámetros ({param_count/1e6:.1f}M)")
| |
|
| | |
| | |
| | |
# Application object; title/description/version are shown in the generated API docs.
app = FastAPI(
    version="3.5",
    title="MTP-3.5 API",
    description="API para modelo de lenguaje MTP-3.5 mejorado con RoPE, RMSNorm y SwiGLU",
)

# Wide-open CORS: any origin, method and header is allowed.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)
| |
|
# Request body for POST /generate. Pydantic enforces the bounds below when the
# body is parsed. Note that generate() additionally truncates text/context to
# 500 characters and caps max_tokens at 50, tighter than the schema limits.
class PromptRequest(BaseModel):
    # Required user instruction.
    text: str = Field(
        ...,
        max_length=1000,
        description="Texto de entrada (instrucción)",
    )
    # Optional grounding text; blank means "no context section" in the prompt.
    context: str = Field(
        default="",
        description="Contexto opcional para la respuesta",
    )
    # Sampling controls forwarded to model.generate().
    max_tokens: int = Field(
        default=50, ge=1, le=100,
        description="Tokens máximos a generar",
    )
    temperature: float = Field(
        default=0.7, ge=0.1, le=2.0,
        description="Temperatura de muestreo",
    )
    top_k: int = Field(
        default=40, ge=1, le=100,
        description="Top-k sampling",
    )
    top_p: float = Field(
        default=0.92, ge=0.1, le=1.0,
        description="Top-p (nucleus) sampling",
    )
    repetition_penalty: float = Field(
        default=1.15, ge=1.0, le=2.0,
        description="Penalización por repetición",
    )
    min_length: int = Field(
        default=10, ge=1, le=50,
        description="Longitud mínima de respuesta",
    )
| |
|
def build_prompt(user_input: str, context: str = "") -> str:
    """Format *user_input* (and optional *context*) in the model's prompt template.

    The context section is emitted only when *context* contains non-whitespace
    text; when present it is inserted verbatim (not stripped). The prompt always
    ends with the "### Respuesta:" header, which the model continues from.
    """
    sections = [f"### Instrucción:\n{user_input}\n\n"]
    if context and context.strip():
        sections.append(f"### Contexto:\n{context}\n\n")
    sections.append("### Respuesta:\n")
    return "".join(sections)
| |
|
| | |
| | |
| | |
# Admission control for /generate. The check-and-increment in generate() runs
# on the event loop with no `await` in between, so two requests cannot both
# pass the check.
ACTIVE_REQUESTS = 0
MAX_CONCURRENT_REQUESTS = 1
# Prompt token budget, BOS token included.
MAX_INPUT_TOKENS = 256


def _encode_prompt(user_input, context):
    """Return BOS + the tokenized prompt, capped at MAX_INPUT_TOKENS tokens.

    When the prompt is too long, the BOS token and the *tail* of the prompt are
    kept, so the trailing "### Respuesta:" cue that makes the model answer is
    never cut off.
    """
    tokens = [tokenizer.bos_id()] + tokenizer.encode(build_prompt(user_input, context))
    if len(tokens) > MAX_INPUT_TOKENS:
        tokens = tokens[:1] + tokens[-(MAX_INPUT_TOKENS - 1):]
        print(f"⚠️ Input truncado a {MAX_INPUT_TOKENS} tokens para CPU")
    return tokens


def _generate_reply(input_ids, prompt_len, req, max_new_tokens):
    """Run the blocking model call and decode the newly generated tokens.

    Returns ``(safe_tokens, response)``: the in-vocabulary tokens generated
    before the first EOS, and their decoded text cut at the first "###" (the
    start of a hallucinated next prompt section). Runs in a worker thread, so
    it opens its own no_grad context (torch grad mode is thread-local).
    """
    eos_id = tokenizer.eos_id()
    with torch.no_grad():
        output_ids = model.generate(
            input_ids,
            max_new_tokens=max_new_tokens,
            temperature=req.temperature,
            top_k=req.top_k,
            top_p=req.top_p,
            repetition_penalty=req.repetition_penalty,
            min_length=req.min_length,
            eos_token_id=eos_id
        )

    # Skip the echoed prompt; stop at EOS and drop out-of-vocabulary ids.
    safe_tokens = []
    for t in output_ids[0, prompt_len:].tolist():
        if t == eos_id:
            break
        if 0 <= t < VOCAB_SIZE:
            safe_tokens.append(t)

    response = tokenizer.decode(safe_tokens).strip()
    response = response.partition("###")[0].strip()
    return safe_tokens, response


@app.post("/generate")
async def generate(req: PromptRequest):
    """Generate a reply for ``req.text``, optionally grounded on ``req.context``.

    Always answers with a JSON dict whose ``reply`` key is user-facing text;
    failures add an ``error`` key rather than an HTTP error status. The
    blocking model call runs in the default thread-pool executor, so the event
    loop stays responsive (e.g. /health) while a generation is in progress and
    overlapping requests are actually rejected by the concurrency limit.
    """
    global ACTIVE_REQUESTS

    if ACTIVE_REQUESTS >= MAX_CONCURRENT_REQUESTS:
        return {
            "reply": "El servidor está ocupado. Por favor, intenta de nuevo en unos segundos.",
            "error": "too_many_requests",
            "active_requests": ACTIVE_REQUESTS
        }

    ACTIVE_REQUESTS += 1
    try:
        # Server-side caps, applied on every device.
        dyn_max_tokens = min(req.max_tokens, 50)
        user_input = req.text.strip()[:500]
        context = req.context.strip()[:500]

        if not user_input:
            return {"reply": "", "tokens_generated": 0}

        try:
            tokens = _encode_prompt(user_input, context)
            input_ids = torch.tensor([tokens], device=DEVICE)
        except Exception as e:
            return {"reply": f"Error al procesar la entrada: {str(e)}", "tokens_generated": 0}

        try:
            start_time = time.perf_counter()
            loop = asyncio.get_running_loop()
            safe_tokens, response = await loop.run_in_executor(
                None, _generate_reply, input_ids, len(tokens), req, dyn_max_tokens
            )
            generation_time = time.perf_counter() - start_time
            tokens_per_second = len(safe_tokens) / generation_time if generation_time > 0 else 0

            return {
                "reply": response,
                "tokens_generated": len(safe_tokens),
                "generation_time": round(generation_time, 2),
                "tokens_per_second": round(tokens_per_second, 1),
                "model": "MTP-3.5",
                "device": DEVICE,
                "context_used": bool(context),
                "note": "Usando CPU - respuesta limitada" if DEVICE == "cpu" else ""
            }
        except Exception as e:
            print(f"❌ Error durante generación: {e}")
            return {
                "reply": "Lo siento, ocurrió un error al procesar tu solicitud.",
                "error": str(e)
            }
        finally:
            gc.collect()
    finally:
        # Single release point for the slot taken above, on every exit path.
        ACTIVE_REQUESTS -= 1
| |
|
| | |
| | |
| | |
@app.get("/generate_sse")
def generate_sse():
    """Disabled streaming endpoint: emits a single SSE error event, then closes."""
    disabled_notice = "data:[ERROR: Streaming deshabilitado en CPU por rendimiento]\n\n"
    return StreamingResponse(iter((disabled_notice,)), media_type="text/event-stream")
| |
|
| | |
| | |
| | |
@app.get("/health")
def health_check():
    """Report service liveness, current load and basic model metadata."""
    return {
        "status": "healthy",
        "model": "MTP-3.5",
        "device": DEVICE,
        "active_requests": ACTIVE_REQUESTS,
        "max_concurrent_requests": MAX_CONCURRENT_REQUESTS,
        "vocab_size": VOCAB_SIZE,
        # Counted once at startup; no visible code modifies the model after
        # that, so re-walking model.parameters() on every probe is wasted work.
        "parameters": param_count,
        "performance_warning": "CPU-only mode - limited performance" if DEVICE == "cpu" else None
    }
| |
|
@app.get("/info")
def model_info():
    """Describe the served model and the request limits /generate enforces."""
    return {
        "model_name": "MTP-3.5",
        "version": "3.5",
        "device": DEVICE,
        "vocab_size": VOCAB_SIZE,
        "status": "running",
        # /generate caps output at 50 tokens, input at 256 tokens and admits
        # MAX_CONCURRENT_REQUESTS at a time on every device, so always report
        # them (previously GPU deployments claimed no limitations at all).
        "limitations": {
            "max_tokens": 50,
            "max_input_length": 256,
            "concurrent_requests": MAX_CONCURRENT_REQUESTS,
        },
    }
| |
|
| | |
| | |
| | |
@app.get("/", response_class=HTMLResponse)
def chat_ui():
    """Serve the single-page chat UI (inline HTML/CSS/JS) that calls /generate.

    Messages are rendered with textContent, never innerHTML: both the user's
    text and the model's reply are untrusted and must not be parsed as HTML.
    """
    return """
<!DOCTYPE html>
<html lang="es">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>MTP 3.5 - CPU Mode</title>
    <style>
        * {
            margin: 0;
            padding: 0;
            box-sizing: border-box;
        }
        body {
            font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, sans-serif;
            background: #0f0f0f;
            color: #fff;
            height: 100vh;
            display: flex;
            flex-direction: column;
        }
        header {
            background: #1a1a1a;
            padding: 1rem;
            border-bottom: 1px solid #333;
            display: flex;
            justify-content: space-between;
            align-items: center;
        }
        .logo {
            display: flex;
            align-items: center;
            gap: 0.5rem;
            font-weight: bold;
        }
        .badge {
            background: #f59e0b;
            color: #000;
            padding: 0.2rem 0.5rem;
            border-radius: 0.5rem;
            font-size: 0.8rem;
            font-weight: bold;
        }
        .chat-container {
            flex: 1;
            overflow-y: auto;
            padding: 1rem;
            display: flex;
            flex-direction: column;
            gap: 1rem;
        }
        .message {
            max-width: 80%;
            padding: 0.8rem 1rem;
            border-radius: 1rem;
            line-height: 1.4;
        }
        .user-message {
            background: #2563eb;
            align-self: flex-end;
            border-bottom-right-radius: 0.2rem;
        }
        .bot-message {
            background: #333;
            align-self: flex-start;
            border-bottom-left-radius: 0.2rem;
        }
        .input-area {
            padding: 1rem;
            background: #1a1a1a;
            border-top: 1px solid #333;
        }
        .input-wrapper {
            display: flex;
            gap: 0.5rem;
            max-width: 800px;
            margin: 0 auto;
        }
        textarea {
            flex: 1;
            background: #2d2d2d;
            border: 1px solid #444;
            color: #fff;
            padding: 0.8rem;
            border-radius: 0.5rem;
            font-family: inherit;
            font-size: 1rem;
            resize: none;
            min-height: 50px;
            max-height: 150px;
        }
        textarea:focus {
            outline: none;
            border-color: #2563eb;
        }
        button {
            background: #2563eb;
            color: white;
            border: none;
            padding: 0 1.5rem;
            border-radius: 0.5rem;
            cursor: pointer;
            font-weight: bold;
            transition: background 0.2s;
        }
        button:hover:not(:disabled) {
            background: #1d4ed8;
        }
        button:disabled {
            background: #555;
            cursor: not-allowed;
        }
        .warning {
            text-align: center;
            font-size: 0.8rem;
            color: #f59e0b;
            margin-top: 0.5rem;
        }
        .typing {
            display: inline-block;
            animation: typing 1s infinite;
        }
        @keyframes typing {
            0%, 100% { opacity: 1; }
            50% { opacity: 0.5; }
        }
    </style>
</head>
<body>
    <header>
        <div class="logo">
            <span>MTP 3.5</span>
            <span class="badge">CPU MODE</span>
        </div>
        <div style="font-size: 0.9rem; color: #aaa;">
            Modelo de lenguaje optimizado para CPU
        </div>
    </header>

    <div class="chat-container" id="chat">
        <div class="message bot-message">
            ¡Hola! Soy MTP 3.5 ejecutándose en modo CPU.
            Mis capacidades están limitadas por rendimiento, pero estoy listo para ayudarte.
            <br><br>
            <small style="color: #f59e0b;">⚠️ Limitaciones: Máximo 50 tokens por respuesta, 1 solicitud a la vez</small>
        </div>
    </div>

    <div class="input-area">
        <div class="input-wrapper">
            <textarea
                id="input"
                placeholder="Escribe tu mensaje aquí... (Máximo 50 tokens)"
                rows="1"
                maxlength="1000"
            ></textarea>
            <button id="sendBtn">Enviar</button>
        </div>
        <div class="warning">
            ⚠️ Las respuestas pueden ser lentas debido al uso de CPU
        </div>
    </div>

    <script>
        const chat = document.getElementById('chat');
        const input = document.getElementById('input');
        const sendBtn = document.getElementById('sendBtn');
        let isGenerating = false;

        // Auto-resize textarea
        input.addEventListener('input', function() {
            this.style.height = 'auto';
            this.style.height = Math.min(this.scrollHeight, 150) + 'px';
        });

        // Send message on Enter (without Shift)
        input.addEventListener('keydown', function(e) {
            if (e.key === 'Enter' && !e.shiftKey) {
                e.preventDefault();
                sendMessage();
            }
        });

        // Send button click
        sendBtn.addEventListener('click', sendMessage);

        async function sendMessage() {
            const text = input.value.trim();
            if (!text || isGenerating) return;

            // Add user message
            addMessage(text, 'user');
            input.value = '';
            input.style.height = 'auto';

            // Disable input
            isGenerating = true;
            input.disabled = true;
            sendBtn.disabled = true;
            sendBtn.textContent = 'Procesando...';

            // Declared outside try so the catch block can remove the whole bubble
            let typingMsg = null;
            try {
                typingMsg = addTypingIndicator();

                const response = await fetch('/generate', {
                    method: 'POST',
                    headers: {
                        'Content-Type': 'application/json'
                    },
                    body: JSON.stringify({
                        text: text,
                        context: '',
                        max_tokens: 50,
                        temperature: 0.7,
                        top_k: 40,
                        top_p: 0.92,
                        repetition_penalty: 1.15,
                        min_length: 10
                    })
                });

                const data = await response.json();

                // Remove typing indicator
                typingMsg.remove();

                // Add bot response
                const botMsg = addMessage(data.reply || 'No pude generar una respuesta.', 'bot');

                // Show stats if available
                if (data.tokens_generated) {
                    const stats = document.createElement('div');
                    stats.style.fontSize = '0.8rem';
                    stats.style.color = '#888';
                    stats.style.marginTop = '0.5rem';
                    stats.textContent = `${data.tokens_generated} tokens • ${data.tokens_per_second || '0'} t/s • ${data.generation_time || '?'}s`;
                    botMsg.appendChild(stats);
                }

            } catch (error) {
                console.error('Error:', error);
                // Removing an already-detached element is a no-op
                if (typingMsg) typingMsg.remove();
                addMessage('Error de conexión. Intenta nuevamente.', 'bot');
            } finally {
                // Re-enable input
                isGenerating = false;
                input.disabled = false;
                sendBtn.disabled = false;
                sendBtn.textContent = 'Enviar';
                input.focus();
            }
        }

        // Append a chat bubble. textContent (not innerHTML): user input and
        // model output are untrusted and must never be parsed as HTML.
        function addMessage(text, sender) {
            const msg = document.createElement('div');
            msg.className = `message ${sender}-message`;
            msg.textContent = text;
            msg.style.whiteSpace = 'pre-wrap';
            chat.appendChild(msg);
            chat.scrollTop = chat.scrollHeight;
            return msg;
        }

        // Append an empty bot bubble holding the animated "thinking" label.
        function addTypingIndicator() {
            const msg = addMessage('', 'bot');
            const label = document.createElement('span');
            label.className = 'typing';
            label.textContent = 'MTP está pensando...';
            msg.appendChild(label);
            return msg;
        }
    </script>
</body>
</html>
"""
| |
|
if __name__ == "__main__":
    # Port comes from $PORT (e.g. set by the hosting platform), default 7860.
    port = int(os.environ.get("PORT", 7860))

    # Report the device actually selected at startup instead of always "CPU".
    mode = "CPU" if DEVICE == "cpu" else DEVICE.upper()
    print(f"\n🚀 Iniciando servidor MTP-3.5 en modo {mode}...")
    print(f"🌐 Interfaz web: http://0.0.0.0:{port}")
    print(f"📡 API docs: http://0.0.0.0:{port}/docs")
    print(f"📊 Health check: http://0.0.0.0:{port}/health")
    if DEVICE == "cpu":
        print("\n⚠️ ADVERTENCIA: Ejecutando en CPU - rendimiento limitado")
    # /generate enforces these limits on every device.
    print(f"⚠️ Límites: 50 tokens máx, 256 tokens entrada, {MAX_CONCURRENT_REQUESTS} request concurrente")
    print("\n✅ Sistema listo. Presiona Ctrl+C para detener.")

    uvicorn.run(
        app,
        host="0.0.0.0",
        port=port,
        log_level="info"
    )