import os import sys import torch import pickle import time import gc import asyncio import aiohttp from typing import Optional, Dict, List from fastapi import FastAPI, Request from fastapi.responses import HTMLResponse, StreamingResponse from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel, Field from huggingface_hub import snapshot_download import uvicorn import wikipedia from duckduckgo_search import DDGS # ====================== # CONFIGURACIÓN DE DISPOSITIVO # ====================== if torch.cuda.is_available(): DEVICE = "cuda" print("✅ GPU NVIDIA detectada. Usando CUDA.") else: DEVICE = "cpu" print("⚠️ GPU no detectada. Usando CPU (puede ser más lento).") if DEVICE == "cpu": torch.set_num_threads(max(1, os.cpu_count() // 2)) torch.set_grad_enabled(False) MODEL_REPO = "TeszenAI/MTPw" # ====================== # MOTOR DE BÚSQUEDA WEB # ====================== class WebSearchEngine: """Motor de búsqueda integrado con Wikipedia y DuckDuckGo""" def __init__(self): self.ddgs = DDGS() wikipedia.set_lang('es') async def search_wikipedia(self, query: str, sentences: int = 4) -> Optional[Dict]: """Buscar en Wikipedia""" try: search_results = wikipedia.search(query, results=3) if not search_results: return None page = wikipedia.page(search_results[0], auto_suggest=False) summary = wikipedia.summary(search_results[0], sentences=sentences) return { "source": "Wikipedia", "title": page.title, "url": page.url, "summary": summary, "success": True } except wikipedia.exceptions.DisambiguationError as e: try: page = wikipedia.page(e.options[0], auto_suggest=False) summary = wikipedia.summary(e.options[0], sentences=sentences) return { "source": "Wikipedia", "title": page.title, "url": page.url, "summary": summary, "success": True } except: return None except: return None async def search_duckduckgo(self, query: str, max_results: int = 5) -> List[Dict]: """Buscar en DuckDuckGo""" try: results = [] ddg_results = self.ddgs.text(query, max_results=max_results) for r in ddg_results: results.append({ "title": r.get("title", ""), "url": r.get("href", ""), "snippet": r.get("body", "") }) return results except Exception as e: print(f"Error en DuckDuckGo: {e}") return [] async def search(self, query: str) -> Dict: """Búsqueda combinada""" wiki_task = self.search_wikipedia(query) ddg_task = self.search_duckduckgo(query) wiki_result, ddg_results = await asyncio.gather(wiki_task, ddg_task) return { "query": query, "wikipedia": wiki_result, "web_results": ddg_results, "timestamp": time.time() } # Inicializar motor de búsqueda search_engine = WebSearchEngine() # ====================== # DESCARGA Y CARGA DEL MODELO # ====================== print(f"📦 Descargando modelo desde {MODEL_REPO}...") repo_path = snapshot_download( repo_id=MODEL_REPO, repo_type="model", local_dir="mtptz_repo" ) sys.path.insert(0, repo_path) from model import MTPMiniModel from tokenizer import MTPTokenizer print("🔧 Cargando tensores y configuración...") with open(os.path.join(repo_path, "mtp_mini.pkl"), "rb") as f: model_data = pickle.load(f) tokenizer = MTPTokenizer(os.path.join(repo_path, "mtp_tokenizer.model")) VOCAB_SIZE = tokenizer.sp.get_piece_size() config = model_data["config"] use_swiglu = config["model"].get("use_swiglu", False) print(f"🧠 Inicializando modelo...") print(f" → Vocabulario: {VOCAB_SIZE}") print(f" → Dimensión: {config['model']['d_model']}") print(f" → Capas: {config['model']['n_layers']}") print(f" → Cabezas: {config['model']['n_heads']}") print(f" → SwiGLU: {'✓' if use_swiglu else '✗'}") model = MTPMiniModel( vocab_size=VOCAB_SIZE, d_model=config["model"]["d_model"], n_layers=config["model"]["n_layers"], n_heads=config["model"]["n_heads"], d_ff=config["model"]["d_ff"], max_seq_len=config["model"]["max_seq_len"], dropout=0.0, use_swiglu=use_swiglu ) model.load_state_dict(model_data["model_state_dict"]) model.eval() if DEVICE == "cpu": print("⚡ Aplicando cuantización dinámica para CPU...") model = torch.quantization.quantize_dynamic( model, {torch.nn.Linear}, dtype=torch.qint8 ) model.to(DEVICE) param_count = sum(p.numel() for p in model.parameters()) print(f"✅ Modelo cargado: {param_count:,} parámetros ({param_count/1e6:.1f}M)") print(f"🔍 Motor de búsqueda web inicializado (Wikipedia + DuckDuckGo)") # ====================== # API CONFIG # ====================== app = FastAPI( title="MTP-3.5 Enhanced API", description="API mejorada con capacidades de búsqueda web (Wikipedia + DuckDuckGo)", version="3.5-web" ) app.add_middleware( CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"], ) class PromptRequest(BaseModel): text: str = Field(..., max_length=2000, description="Texto de entrada") max_tokens: int = Field(default=150, ge=10, le=300, description="Tokens máximos a generar") temperature: float = Field(default=0.7, ge=0.1, le=2.0, description="Temperatura de muestreo") top_k: int = Field(default=40, ge=1, le=100, description="Top-k sampling") top_p: float = Field(default=0.92, ge=0.1, le=1.0, description="Top-p (nucleus) sampling") repetition_penalty: float = Field(default=1.15, ge=1.0, le=2.0, description="Penalización por repetición") min_length: int = Field(default=20, ge=5, le=100, description="Longitud mínima de respuesta") use_web_search: bool = Field(default=False, description="Activar búsqueda web") class SearchRequest(BaseModel): query: str = Field(..., max_length=500, description="Consulta de búsqueda") def build_prompt(user_input: str, web_context: Optional[str] = None) -> str: """Construye el prompt con contexto web opcional""" if web_context: # Formato especial para búsqueda web con instrucciones claras return f"""### Instrucción del sistema: Eres un asistente que busca información en internet. Debes resumir la información encontrada de forma clara y útil. ### Información encontrada en la web: {web_context} ### Pregunta del usuario: {user_input} ### Respuesta: Hola, encontré esto en la web: """ return f"### Instrucción:\n{user_input}\n\n### Respuesta:\n" def format_search_results(search_data: Dict) -> str: """Formatea resultados de búsqueda para el contexto del modelo""" context_parts = [] if search_data.get("wikipedia") and search_data["wikipedia"].get("success"): wiki = search_data["wikipedia"] context_parts.append(f"[Wikipedia - {wiki['title']}]\n{wiki['summary']}\nFuente: {wiki['url']}") if search_data.get("web_results"): for i, result in enumerate(search_data["web_results"][:4], 1): snippet = result['snippet'][:300].strip() context_parts.append(f"[Resultado {i}: {result['title']}]\n{snippet}\nFuente: {result['url']}") return "\n\n".join(context_parts) if context_parts else "" # ====================== # ⚡ GESTIÓN DE CARGA # ====================== ACTIVE_REQUESTS = 0 MAX_CONCURRENT_REQUESTS = 3 @app.post("/search") async def web_search(req: SearchRequest): """Endpoint de búsqueda web""" try: search_results = await search_engine.search(req.query) formatted_context = format_search_results(search_results) return { "query": req.query, "results": search_results, "formatted_context": formatted_context, "has_results": bool(formatted_context), "sources_used": [] } except Exception as e: print(f"Error en búsqueda: {e}") return { "query": req.query, "error": str(e), "has_results": False } @app.post("/generate") async def generate(req: PromptRequest): """Endpoint principal con búsqueda web integrada""" global ACTIVE_REQUESTS if ACTIVE_REQUESTS >= MAX_CONCURRENT_REQUESTS: return { "reply": "El servidor está ocupado. Por favor, intenta de nuevo en unos segundos.", "error": "too_many_requests", "active_requests": ACTIVE_REQUESTS } ACTIVE_REQUESTS += 1 dyn_max_tokens = req.max_tokens dyn_temperature = req.temperature if ACTIVE_REQUESTS > 1: print(f"⚠️ Carga alta ({ACTIVE_REQUESTS} requests). Ajustando parámetros.") dyn_max_tokens = min(dyn_max_tokens, 120) dyn_temperature = max(0.6, dyn_temperature * 0.95) user_input = req.text.strip() if not user_input: ACTIVE_REQUESTS -= 1 return {"reply": "", "tokens_generated": 0} web_context = "" search_results = None # Realizar búsqueda web si está activada if req.use_web_search: try: # Extraer la consulta de búsqueda del mensaje del usuario search_query = user_input # Si el mensaje es una pregunta larga, extraer palabras clave if len(user_input.split()) > 8: # Usar las primeras palabras más relevantes words = user_input.lower().split() # Filtrar palabras comunes stop_words = {'qué', 'cuál', 'cómo', 'dónde', 'cuándo', 'por', 'para', 'el', 'la', 'los', 'las', 'un', 'una', 'es', 'sobre', 'me', 'puedes', 'explicar', 'decir', 'información'} keywords = [w for w in words if w not in stop_words][:5] search_query = ' '.join(keywords) search_results = await search_engine.search(search_query) web_context = format_search_results(search_results) if web_context: print(f"🔍 Búsqueda web completada para: '{search_query}'") print(f" Contexto agregado: {len(web_context)} caracteres") except Exception as e: print(f"Error en búsqueda web: {e}") full_prompt = build_prompt(user_input, web_context if web_context else None) tokens = [tokenizer.bos_id()] + tokenizer.encode(full_prompt) input_ids = torch.tensor([tokens], device=DEVICE) try: start_time = time.time() with torch.no_grad(): output_ids = model.generate( input_ids, max_new_tokens=dyn_max_tokens, temperature=dyn_temperature, top_k=req.top_k, top_p=req.top_p, repetition_penalty=req.repetition_penalty, min_length=req.min_length, eos_token_id=tokenizer.eos_id() ) gen_tokens = output_ids[0, len(tokens):].tolist() safe_tokens = [] for t in gen_tokens: if 0 <= t < VOCAB_SIZE and t != tokenizer.eos_id(): safe_tokens.append(t) elif t == tokenizer.eos_id(): break response = tokenizer.decode(safe_tokens).strip() if "###" in response: response = response.split("###")[0].strip() if response.endswith(("...", ". . .", "…")): response = response.rstrip(".") generation_time = time.time() - start_time tokens_per_second = len(safe_tokens) / generation_time if generation_time > 0 else 0 result = { "reply": response, "tokens_generated": len(safe_tokens), "generation_time": round(generation_time, 2), "tokens_per_second": round(tokens_per_second, 1), "model": "MTP-3.5", "device": DEVICE, "web_search_used": req.use_web_search } if req.use_web_search and search_results: sources = [] if search_results.get("wikipedia") and search_results["wikipedia"].get("success"): sources.append({ "type": "wikipedia", "title": search_results["wikipedia"]["title"], "url": search_results["wikipedia"]["url"] }) if search_results.get("web_results"): for r in search_results["web_results"][:3]: sources.append({ "type": "web", "title": r["title"], "url": r["url"] }) result["sources"] = sources result["search_query"] = user_input return result except Exception as e: print(f"❌ Error durante generación: {e}") import traceback traceback.print_exc() return { "reply": "Lo siento, ocurrió un error al procesar tu solicitud.", "error": str(e) } finally: ACTIVE_REQUESTS -= 1 if DEVICE == "cuda": torch.cuda.empty_cache() gc.collect() # ====================== # 📊 ENDPOINTS DE INFORMACIÓN # ====================== @app.get("/health") def health_check(): """Check del estado del servicio""" memory_info = {} if DEVICE == "cuda": memory_info = { "gpu_memory_allocated_mb": round(torch.cuda.memory_allocated() / 1024**2, 2), "gpu_memory_reserved_mb": round(torch.cuda.memory_reserved() / 1024**2, 2) } return { "status": "healthy", "model": "MTP-3.5-Web", "device": DEVICE, "active_requests": ACTIVE_REQUESTS, "max_concurrent_requests": MAX_CONCURRENT_REQUESTS, "vocab_size": VOCAB_SIZE, "parameters": sum(p.numel() for p in model.parameters()), "web_search_enabled": True, **memory_info } @app.get("/info") def model_info(): """Información detallada del modelo""" improvements = [ "RoPE (Rotary Position Embedding)", "RMSNorm (Root Mean Square Normalization)", "Label Smoothing (0.1)", "Repetition Penalty", "Early Stopping", "EOS Loss Weight", "Length Control", "Gradient Accumulation", "Web Search Integration (Wikipedia + DuckDuckGo)" ] if config["model"].get("use_swiglu", False): improvements.append("SwiGLU Activation") return { "model_name": "MTP-3.5-Web", "version": "3.5-web", "architecture": { "d_model": config["model"]["d_model"], "n_layers": config["model"]["n_layers"], "n_heads": config["model"]["n_heads"], "d_ff": config["model"]["d_ff"], "max_seq_len": config["model"]["max_seq_len"], "vocab_size": VOCAB_SIZE, "use_swiglu": config["model"].get("use_swiglu", False), "dropout": config["model"]["dropout"] }, "parameters": sum(p.numel() for p in model.parameters()), "parameters_human": f"{sum(p.numel() for p in model.parameters())/1e6:.1f}M", "device": DEVICE, "improvements": improvements, "web_search": { "enabled": True, "sources": ["Wikipedia (ES)", "DuckDuckGo"] } } # ====================== # 🎨 INTERFAZ WEB MEJORADA CON BÚSQUEDA # ====================== @app.get("/", response_class=HTMLResponse) def chat_ui(): return """