import os
import sys
import torch
import json
import time
import gc
import re
from fastapi import FastAPI, Request
from fastapi.responses import HTMLResponse, StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from huggingface_hub import snapshot_download
import uvicorn
import math
import torch.nn as nn
import torch.nn.functional as F
import sentencepiece as spm

# ======================
# DEVICE CONFIGURATION
# ======================
if torch.cuda.is_available():
    DEVICE = "cuda"
    print("✅ NVIDIA GPU detected. Using CUDA.")
else:
    DEVICE = "cpu"
    print("⚠️ No GPU detected. Using CPU (may be slower).")

if DEVICE == "cpu":
    torch.set_num_threads(max(1, os.cpu_count() // 2))

torch.set_grad_enabled(False)

MODEL_REPO = "TeszenAI/MTP-2.5"

# ======================
# MODEL ARCHITECTURE DEFINITION (MTP-1.1)
# ======================
class LayerNorm(nn.Module):
    def __init__(self, d_model: int, eps: float = 1e-5):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(d_model))
        self.bias = nn.Parameter(torch.zeros(d_model))
        self.eps = eps

    def forward(self, x):
        mean = x.mean(-1, keepdim=True)
        std = x.std(-1, keepdim=True)
        return self.weight * (x - mean) / (std + self.eps) + self.bias


class MultiHeadAttention(nn.Module):
    def __init__(self, d_model: int, n_heads: int, dropout: float = 0.1):
        super().__init__()
        assert d_model % n_heads == 0
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)
        self.scale = math.sqrt(self.d_k)

    def forward(self, x, mask=None):
        batch_size, seq_len, _ = x.shape
        # Project and reshape to (batch, n_heads, seq_len, d_k)
        Q = self.w_q(x).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
        K = self.w_k(x).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
        V = self.w_v(x).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attn_weights = F.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)
        attn_output = torch.matmul(attn_weights, V)
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        return self.w_o(attn_output)


class FeedForward(nn.Module):
    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.linear2(self.dropout(F.gelu(self.linear1(x))))


class TransformerBlock(nn.Module):
    def __init__(self, d_model: int, n_heads: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.attention = MultiHeadAttention(d_model, n_heads, dropout)
        self.feed_forward = FeedForward(d_model, d_ff, dropout)
        self.norm1 = LayerNorm(d_model)
        self.norm2 = LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        # Post-norm residual blocks: self-attention, then feed-forward
        attn_output = self.attention(x, mask)
        x = x + self.dropout1(attn_output)
        x = self.norm1(x)
        ff_output = self.feed_forward(x)
        x = x + self.dropout2(ff_output)
        x = self.norm2(x)
        return x


class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, max_len: int = 5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
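        # Standard sinusoidal encoding from "Attention Is All You Need":
        #   PE(pos, 2i)   = sin(pos / 10000^(2i / d_model))
        #   PE(pos, 2i+1) = cos(pos / 10000^(2i / d_model))
        # div_term below computes 1 / 10000^(2i / d_model) in log space for numerical stability.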
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        return x + self.pe[:, :x.size(1), :]


class MTPModel(nn.Module):
    def __init__(self, vocab_size: int, d_model: int = 128, n_heads: int = 4,
                 n_layers: int = 4, d_ff: int = 512, dropout: float = 0.1, max_len: int = 256):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.max_len = max_len
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model, max_len)
        self.blocks = nn.ModuleList([
            TransformerBlock(d_model, n_heads, d_ff, dropout) for _ in range(n_layers)
        ])
        self.norm = LayerNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size)

    def forward(self, x, mask=None):
        if mask is None:
            # Causal mask: each position attends only to itself and earlier positions
            mask = torch.tril(torch.ones(x.size(1), x.size(1))).unsqueeze(0).unsqueeze(0).to(x.device)
        x = self.token_embedding(x) * math.sqrt(self.d_model)
        x = self.pos_encoding(x)
        for block in self.blocks:
            x = block(x, mask)
        x = self.norm(x)
        logits = self.lm_head(x)
        return logits

    def generate(self, input_ids, max_new_tokens=100, temperature=0.7, top_k=50,
                 top_p=0.9, repetition_penalty=1.1, eos_token_id=3):
        """Improved generation method with clean stopping."""
        generated = input_ids
        eos_detected = False
        for _ in range(max_new_tokens):
            with torch.no_grad():
                # Truncate the context to the model's maximum sequence length so the
                # positional-encoding table is never indexed out of range
                logits = self(generated[:, -self.max_len:])
            next_logits = logits[0, -1, :] / temperature

            # Repetition penalty: divide positive logits, multiply negative ones,
            # so already-seen tokens are always made less likely
            if repetition_penalty != 1.0:
                for token_id in set(generated[0].tolist()):
                    if next_logits[token_id] > 0:
                        next_logits[token_id] /= repetition_penalty
                    else:
                        next_logits[token_id] *= repetition_penalty

            # Top-k filtering
            if top_k > 0:
                indices_to_remove = next_logits < torch.topk(next_logits, top_k)[0][..., -1, None]
                next_logits[indices_to_remove] = float('-inf')

            # Top-p (nucleus) filtering
            if top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(next_logits, descending=True)
                cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
                sorted_indices_to_remove = cumulative_probs > top_p
                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                sorted_indices_to_remove[..., 0] = 0
                indices_to_remove = sorted_indices[sorted_indices_to_remove]
                next_logits[indices_to_remove] = float('-inf')

            probs = F.softmax(next_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1).item()

            # Stop cleanly on EOS
            if next_token == eos_token_id:
                eos_detected = True
                break

            # Stop if the same token has been repeated excessively
            if len(generated[0]) > 10:
                last_tokens = generated[0][-10:].tolist()
                if len(set(last_tokens)) == 1:
                    break

            generated = torch.cat(
                [generated, torch.tensor([[next_token]], device=generated.device)], dim=1
            )
        return generated, eos_detected


# ======================
# MODEL DOWNLOAD AND LOADING
# ======================
print(f"📦 Downloading model from {MODEL_REPO}...")
repo_path = snapshot_download(
    repo_id=MODEL_REPO,
    repo_type="model",
    local_dir="mtp_repo"
)

# Load configuration
config_path = os.path.join(repo_path, "config.json")
if os.path.exists(config_path):
    with open(config_path, "r") as f:
        config = json.load(f)
else:
    config = {
        "vocab_size": 5000,
        "d_model": 128,
        "n_heads": 4,
        "n_layers": 4,
        "d_ff": 512,
        "dropout": 0.1,
        "max_len": 256
    }

# Load tokenizer
tokenizer_path = os.path.join(repo_path, "mtp_tokenizer.model")
sp = spm.SentencePieceProcessor()
sp.load(tokenizer_path)

VOCAB_SIZE = sp.get_piece_size()
EOS_TOKEN_ID = sp.eos_id()
BOS_TOKEN_ID = sp.bos_id()

# Override vocab_size in the config with the tokenizer's actual vocabulary size
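# (The embedding and lm_head layers are sized from this value, so it must match the
#  tokenizer; otherwise load_state_dict below would fail with a size mismatch.)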
config["vocab_size"] = VOCAB_SIZE print(f"🧠 Inicializando modelo MTP-1.1...") print(f" → Vocabulario: {VOCAB_SIZE}") print(f" → EOS token ID: {EOS_TOKEN_ID}") print(f" → BOS token ID: {BOS_TOKEN_ID}") print(f" → Dimensión: {config['d_model']}") print(f" → Capas: {config['n_layers']}") print(f" → Heads: {config['n_heads']}") model = MTPModel(**config) model.to(DEVICE) # Cargar pesos del modelo model_path = os.path.join(repo_path, "mtp_model.pt") if os.path.exists(model_path): state_dict = torch.load(model_path, map_location=DEVICE) model.load_state_dict(state_dict) print("✅ Pesos del modelo cargados") else: print("⚠️ No se encontró mtp_model.pt, usando pesos aleatorios") model.eval() # Cuantización para CPU if DEVICE == "cpu": print("⚡ Aplicando cuantización dinámica para CPU...") model = torch.quantization.quantize_dynamic( model, {nn.Linear}, dtype=torch.qint8 ) param_count = sum(p.numel() for p in model.parameters()) print(f"✅ Modelo cargado: {param_count:,} parámetros ({param_count/1e6:.1f}M)") # ====================== # FUNCIONES DE LIMPIEZA DE RESPUESTAS # ====================== def clean_response(text: str, original_prompt: str = None) -> str: """Limpia la respuesta generada eliminando basura y repeticiones""" if not text: return "Lo siento, no pude generar una respuesta." # Eliminar el prompt original si aparece al inicio if original_prompt: prompt_clean = original_prompt.strip().lower() text_lower = text.lower() if text_lower.startswith(prompt_clean): text = text[len(original_prompt):].strip() elif prompt_clean in text_lower[:50]: # Buscar después del prompt idx = text_lower.find(prompt_clean) if idx != -1: text = text[idx + len(original_prompt):].strip() # Eliminar partes que contienen "###" if "###" in text: text = text.split("###")[0].strip() # Eliminar repeticiones absurdas (patrones como "xxx" repetido) words = text.split() if len(words) > 10: unique_words = [] last_word = None repeat_count = 0 for w in words: if w == last_word: repeat_count += 1 if repeat_count > 2: continue else: repeat_count = 0 unique_words.append(w) last_word = w text = " ".join(unique_words) # Eliminar fragmentos que parecen basura (patrones sin sentido) garbage_patterns = [ r'[a-z]{20,}', # Palabras muy largas sin sentido r'\d{5,}', # Números muy largos r'[^\w\s\.\,\!\?\-áéíóúüñ]{10,}', # Caracteres extraños repetidos ] for pattern in garbage_patterns: text = re.sub(pattern, '', text) # Limpiar espacios múltiples text = re.sub(r'\s+', ' ', text).strip() # Capitalizar primera letra if text and len(text) > 0: text = text[0].upper() + text[1:] if len(text) > 1 else text.upper() # Si la respuesta es demasiado corta o vacía, dar mensaje por defecto if len(text) < 3: return "Entendido. ¿Algo más en lo que pueda ayudarte?" 
    return text


# ======================
# API CONFIG
# ======================
app = FastAPI(
    title="MTP-1.1 API",
    description="API for the MTP-1.1 language model",
    version="1.1"
)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)


class PromptRequest(BaseModel):
    text: str = Field(..., max_length=2000, description="Input text")
    max_tokens: int = Field(default=100, ge=10, le=200, description="Maximum tokens to generate")
    temperature: float = Field(default=0.7, ge=0.1, le=1.5, description="Sampling temperature")
    top_k: int = Field(default=50, ge=1, le=100, description="Top-k sampling")
    top_p: float = Field(default=0.9, ge=0.1, le=1.0, description="Top-p (nucleus) sampling")
    repetition_penalty: float = Field(default=1.1, ge=1.0, le=2.0, description="Repetition penalty")


def build_prompt(user_input: str) -> str:
    """Build the prompt in the model's training format."""
    return f"### Instrucción:\n{user_input}\n\n### Respuesta:\n"


# ======================
# LOAD MANAGEMENT
# ======================
ACTIVE_REQUESTS = 0


class MTPTokenizer:
    """Wrapper around the SentencePiece tokenizer."""

    def __init__(self, sp_model):
        self.sp = sp_model

    def encode(self, text):
        return self.sp.encode(text)

    def decode(self, tokens):
        return self.sp.decode(tokens)

    def bos_id(self):
        return self.sp.bos_id()

    def eos_id(self):
        return self.sp.eos_id()


tokenizer_wrapper = MTPTokenizer(sp)


@app.post("/generate")
async def generate(req: PromptRequest):
    """Main text generation endpoint."""
    global ACTIVE_REQUESTS
    ACTIVE_REQUESTS += 1
    try:
        user_input = req.text.strip()
        if not user_input:
            return {"reply": "", "tokens_generated": 0}

        # Build the prompt
        full_prompt = build_prompt(user_input)
        tokens = [tokenizer_wrapper.bos_id()] + tokenizer_wrapper.encode(full_prompt)
        input_ids = torch.tensor([tokens], device=DEVICE)

        # Adjust parameters dynamically according to server load
        dyn_max_tokens = req.max_tokens
        dyn_temperature = req.temperature
        if ACTIVE_REQUESTS > 2:
            dyn_max_tokens = min(dyn_max_tokens, 80)
            dyn_temperature = max(0.5, dyn_temperature * 0.9)

        # Generate
        with torch.no_grad():
            output_ids, eos_detected = model.generate(
                input_ids,
                max_new_tokens=dyn_max_tokens,
                temperature=dyn_temperature,
                top_k=req.top_k,
                top_p=req.top_p,
                repetition_penalty=req.repetition_penalty,
                eos_token_id=tokenizer_wrapper.eos_id()
            )

        # Extract only the generated tokens (excluding the prompt)
        gen_tokens = output_ids[0, len(tokens):].tolist()

        # Filter out invalid tokens
        safe_tokens = [
            t for t in gen_tokens
            if 0 <= t < VOCAB_SIZE and t != tokenizer_wrapper.eos_id()
        ]

        # Decode
        raw_response = tokenizer_wrapper.decode(safe_tokens).strip()

        # Clean the response
        clean_reply = clean_response(raw_response, user_input)

        # If EOS was not detected and the response looks incomplete, truncate it
        if not eos_detected and len(clean_reply) > 200:
            # Truncate at the last sentence-ending period
            last_period = clean_reply.rfind('.')
            if last_period > 100:
                clean_reply = clean_reply[:last_period + 1]

        # Remove common nonsense phrases
        nonsense_phrases = [
            "foompañances", "ciudadores", "mejtedon", "calportedon",
            "rápidodcor", "rápidodarse", "miel", "baon", "domol"
        ]
        for phrase in nonsense_phrases:
            clean_reply = clean_reply.replace(phrase, "")

        # Collapse double spaces again
        clean_reply = re.sub(r'\s+', ' ', clean_reply).strip()

        # If the response is still very long and has no period near the end, cut it
        if len(clean_reply) > 300 and '.' not in clean_reply[-50:]:
            clean_reply = clean_reply[:250] + "..."
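
        # Response shape (illustrative example, not actual model output):
        #   {"reply": "...", "tokens_generated": 42, "model": "MTP-1.1", "eos_detected": true}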
        return {
            "reply": clean_reply,
            "tokens_generated": len(safe_tokens),
            "model": "MTP-1.1",
            "eos_detected": eos_detected
        }

    except Exception as e:
        print(f"❌ Error during generation: {e}")
        return {
            "reply": "Lo siento, ocurrió un error al procesar tu solicitud.",
            "error": str(e)
        }
    finally:
        ACTIVE_REQUESTS -= 1
        if DEVICE == "cuda":
            torch.cuda.empty_cache()
        gc.collect()


# ======================
# INFO ENDPOINTS
# ======================
@app.get("/health")
def health_check():
    return {
        "status": "healthy",
        "model": "MTP-1.1",
        "device": DEVICE,
        "active_requests": ACTIVE_REQUESTS,
        "vocab_size": VOCAB_SIZE
    }


@app.get("/info")
def model_info():
    return {
        "model_name": "MTP-1.1",
        "version": "1.1",
        "architecture": config,
        "parameters": sum(p.numel() for p in model.parameters()),
        "device": DEVICE
    }


# ======================
# WEB UI (MODERN)
# ======================
@app.get("/", response_class=HTMLResponse)
def chat_ui():
    return """