# Teszen_AI / app.py
# Author: teszenofficial — "Update app.py" (commit 2632d84, verified)
import os
import sys
import torch
import pickle
import time
from fastapi import FastAPI
from fastapi.responses import HTMLResponse, StreamingResponse
from pydantic import BaseModel
from huggingface_hub import snapshot_download
import uvicorn
# ======================
# CPU OPTIMIZATION
# ======================
# os.cpu_count() can return None (documented behavior on platforms where
# the count is undetermined); guard with `or 1` so `None // 2` never
# raises a TypeError at startup.
torch.set_num_threads(max(1, (os.cpu_count() or 1) // 2))
# Inference-only service: disable autograd globally to save memory/time.
torch.set_grad_enabled(False)
# ======================
# DEVICE
# ======================
if torch.cuda.is_available():
    DEVICE = "cuda"
    print("✅ GPU detectada. Usando CUDA.")
else:
    DEVICE = "cpu"
    print("⚠️ GPU no detectada. Usando CPU.")
# Hugging Face repo that ships both the weights and the model/tokenizer code.
MODEL_REPO = "teszenofficial/mtp1"
# ======================
# MODEL DOWNLOAD
# ======================
print("--- SISTEMA MTP 1.1 ---")
# Download the full repo snapshot into a local directory; returns the
# local path (subsequent runs reuse the hub cache).
repo_path = snapshot_download(
    repo_id=MODEL_REPO,
    repo_type="model",
    local_dir="mtptz_repo"
)
# The repo bundles its own model.py / tokenizer.py; the path insert MUST
# precede the two imports below so they resolve to the snapshot's files.
sys.path.insert(0, repo_path)
from model import MTPMiniModel
from tokenizer import MTPTokenizer
# ======================
# MODEL LOAD
# ======================
print("Cargando modelo...")
# First *.pkl file in the snapshot is assumed to be the checkpoint;
# raises StopIteration at startup if none exists.
pkl_file = next(f for f in os.listdir(repo_path) if f.endswith(".pkl"))
# NOTE(review): pickle.load executes arbitrary code embedded in the file —
# safe only as long as MODEL_REPO is trusted; confirm the repo is controlled.
with open(os.path.join(repo_path, pkl_file), "rb") as f:
    model_data = pickle.load(f)
# SentencePiece-backed tokenizer shipped with the repo (see .sp usage below).
tokenizer = MTPTokenizer(os.path.join(repo_path, "mtp_tokenizer.model"))
config = model_data["config"]
# Rebuild the architecture from the checkpoint's config; dropout forced to
# 0.0 because this process only ever runs inference.
model = MTPMiniModel(
    vocab_size=model_data["vocab_size"],
    d_model=config["model"]["d_model"],
    n_layers=config["model"]["n_layers"],
    n_heads=config["model"]["n_heads"],
    d_ff=config["model"]["d_ff"],
    max_seq_len=config["model"]["max_seq_len"],
    dropout=0.0
)
model.load_state_dict(model_data["model_state_dict"])
model.to(DEVICE)
model.eval()
# Tokenizer vocab size is authoritative; the model attribute is overwritten
# to match (checkpoint vocab_size may differ — presumably padded; verify).
VOCAB_SIZE = tokenizer.sp.get_piece_size()
model.vocab_size = VOCAB_SIZE
print(f"🚀 MTP 1.1 listo en {DEVICE.upper()}")
# ======================
# FASTAPI
# ======================
app = FastAPI(title="MTP 1.1")
class Prompt(BaseModel):
    """Request body shared by /generate and /generate_stream."""
    # Raw user message; wrapped into the instruction template server-side.
    text: str
# ======================
# NON-STREAMING GENERATION
# ======================
@app.post("/generate")
def generate(prompt: Prompt):
    """Produce a complete reply in a single response (no streaming).

    Returns {"reply": <text>}; an empty prompt yields an empty reply and
    any internal failure yields a generic error message.
    """
    try:
        user_text = prompt.text.strip()
        if not user_text:
            return {"reply": ""}

        # Wrap the message in the instruction template the model was tuned on.
        wrapped = f"### Instrucción:\n{user_text}\n\n### Respuesta:\n"
        prompt_ids = [tokenizer.bos_id()] + tokenizer.encode(wrapped)
        batch = torch.tensor([prompt_ids], device=DEVICE)

        with torch.no_grad():
            generated = model.generate(
                batch,
                max_new_tokens=80,
                temperature=0.7,
                top_k=50,
                top_p=0.9
            )

        # Keep only the freshly generated ids, dropping anything outside the
        # tokenizer vocab as well as the EOS marker itself.
        new_ids = generated[0, len(prompt_ids):].tolist()
        eos = tokenizer.eos_id()
        keep = [tid for tid in new_ids if 0 <= tid < VOCAB_SIZE and tid != eos]
        return {"reply": tokenizer.decode(keep).strip()}
    except Exception as e:
        # Top-level boundary: log and hide internals from the client.
        print("❌ ERROR:", e)
        return {"reply": "Error interno."}
# ======================
# STREAMING GENERATION (CHATGPT-STYLE)
# ======================
@app.post("/generate_stream")
def generate_stream(prompt: Prompt):
    """Stream the reply token-by-token as plain text.

    Decoding here is greedy (argmax), unlike /generate which samples with
    temperature/top-k/top-p. The previous code computed
    softmax(logits / 0.7) before the argmax, but argmax is invariant under
    positive temperature scaling and the (monotonic) softmax, so that work
    was dead — it is removed without changing the output. To actually match
    /generate's sampling, replace the argmax with torch.multinomial over the
    temperature-scaled probabilities.
    """
    def stream():
        try:
            text = prompt.text.strip()
            full_prompt = f"### Instrucción:\n{text}\n\n### Respuesta:\n"
            tokens = [tokenizer.bos_id()] + tokenizer.encode(full_prompt)
            input_ids = torch.tensor([tokens], device=DEVICE)
            eos = tokenizer.eos_id()  # hoisted out of the generation loop
            for _ in range(80):
                with torch.no_grad():
                    logits = model(input_ids)[:, -1, :]
                # Clamp to the tokenizer vocab in case the model head is wider;
                # after this slice argmax is guaranteed in [0, VOCAB_SIZE),
                # so the old range check on next_id was always true.
                logits = logits[:, :VOCAB_SIZE]
                next_id = torch.argmax(logits, dim=-1).item()
                if next_id == eos:
                    break
                yield tokenizer.decode([next_id])
                input_ids = torch.cat(
                    [input_ids, torch.tensor([[next_id]], device=DEVICE)],
                    dim=1
                )
                # Small pause so the browser paints tokens incrementally.
                time.sleep(0.015)
        except Exception as e:
            # Boundary handler: signal failure in-band, never crash the stream.
            print("❌ STREAM ERROR:", e)
            yield "\n[error]"
    return StreamingResponse(stream(), media_type="text/plain")
# ======================
# FULL HTML FRONTEND
# ======================
@app.get("/", response_class=HTMLResponse)
def ui():
    """Serve the single-page chat UI.

    Security fix: the user's message was previously interpolated straight
    into chat.innerHTML, so any typed HTML/JS was parsed and executed
    (self-XSS, and a real vector if transcripts are ever shared). Messages
    are now inserted via Node.textContent, which never parses markup.
    Also adds Enter-to-send for usability (backward compatible).
    """
    return """
<!DOCTYPE html>
<html lang="es">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width,initial-scale=1">
<title>MTP 1.1</title>
<style>
body{margin:0;background:#131314;color:#e3e3e3;font-family:Inter,system-ui}
#chat{max-width:900px;margin:auto;padding:20px}
.msg{margin:12px 0;white-space:pre-wrap}
.user{color:#8ab4f8}
.bot{color:#e3e3e3}
input{width:100%;padding:12px;border-radius:10px;border:none;background:#1e1f20;color:white}
button{margin-top:10px;padding:10px;border-radius:10px;border:none;background:#4a9eff;color:black;font-weight:bold}
</style>
</head>
<body>
<div id="chat">
<div class="msg bot">Hola, soy MTP 1.1.</div>
</div>
<input id="inp" placeholder="Escribe algo…" />
<button onclick="send()">Enviar</button>
<script>
// Append a chat bubble using textContent (never innerHTML) so user input
// is always treated as plain text — this is the XSS fix.
function appendMsg(cls, content){
  const chat=document.getElementById('chat');
  const div=document.createElement('div');
  div.className="msg "+cls;
  div.textContent=content;
  chat.appendChild(div);
  return div;
}
async function send(){
  const input=document.getElementById('inp');
  const text=input.value.trim();
  if(!text)return;
  input.value="";
  appendMsg("user", text);
  const bot=appendMsg("bot", "");
  const res=await fetch('/generate_stream',{
    method:'POST',
    headers:{'Content-Type':'application/json'},
    body:JSON.stringify({text})
  });
  const reader=res.body.getReader();
  const decoder=new TextDecoder();
  while(true){
    const {value,done}=await reader.read();
    if(done)break;
    bot.textContent+=decoder.decode(value);
    window.scrollTo(0,document.body.scrollHeight);
  }
}
// Enter key submits, mirroring the button.
document.getElementById('inp').addEventListener('keydown',e=>{
  if(e.key==='Enter')send();
});
</script>
</body>
</html>
"""
# ======================
# ENTRYPOINT
# ======================
if __name__ == "__main__":
    # Honor a platform-provided PORT (Docker / hosting platforms), keeping
    # 7860 — the Hugging Face Spaces default — as the fallback.
    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", "7860")))