"""FastAPI service exposing a short-form Persian chat completion endpoint.

Loads a causal LM from Hugging Face once at import time and serves:

  GET  /health    -> liveness probe
  GET  /          -> usage hint
  POST /generate  -> short (1-2 sentence) reply for a chat-style prompt

A background keep-alive task pings our own /health every 5 minutes so
hosting platforms that sleep idle containers (e.g. HF Spaces on port
7860) keep this process warm.
"""

import asyncio
import os
import re
from contextlib import asynccontextmanager

import aiohttp
import torch
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM

MODEL_ID = os.getenv("MODEL_ID", "ai-forever/mGPT-1.3B-persian")

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

# Keep the CPU footprint low on shared hosts.
torch.set_num_threads(1)

tok = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
model = (
    AutoModelForCausalLM.from_pretrained(
        MODEL_ID, torch_dtype=dtype, low_cpu_mem_usage=True
    )
    .to(device)
    .eval()
)

# Token id used for both eos and pad during generation.
# Explicit `is None` check: `tok.eos_token_id or tok.pad_token_id` would
# wrongly fall through to the pad id when the eos token id is 0.
_EOS_ID = tok.eos_token_id if tok.eos_token_id is not None else tok.pad_token_id


async def _keepalive() -> None:
    """Ping our own /health endpoint forever so the host keeps us awake."""
    await asyncio.sleep(5)  # give the server a moment to start listening
    timeout = aiohttp.ClientTimeout(total=5)  # explicit ClientTimeout (bare ints are deprecated)
    async with aiohttp.ClientSession() as session:
        while True:
            try:
                await session.get("http://127.0.0.1:7860/health", timeout=timeout)
            except Exception:
                # Best-effort: a failed self-ping must never kill the task.
                pass
            await asyncio.sleep(300)


@asynccontextmanager
async def _lifespan(app: FastAPI):
    """Start the keep-alive task on startup; cancel it on shutdown.

    Replaces the deprecated @app.on_event("startup") hook. We must hold a
    reference to the task: asyncio keeps only weak references to tasks, so
    an unreferenced task can be garbage-collected mid-flight.
    """
    task = asyncio.create_task(_keepalive())
    try:
        yield
    finally:
        task.cancel()


app = FastAPI(lifespan=_lifespan)


class Req(BaseModel):
    """Request body for POST /generate."""

    prompt: str          # user message (truncated to 900 chars server-side)
    max_tokens: int = 160  # capped at 200 in generate()
    system: str = "تو یه دستیار فارسی خودمونی و سریع هستی؛ جواب‌ها کوتاه، رک و بامزه (۱–۲ جمله)."
    temperature: float = 0.65


@app.get("/health")
def health():
    """Liveness probe used by the keep-alive task and the platform."""
    return {"ok": True}


@app.get("/")
def root():
    """Usage hint for the root path."""
    return {"ok": True, "use": "POST /generate"}


# Hoisted once: sentence boundary after ./!/?/؟ and short bracketed role tags.
_SENTENCE_SPLIT = re.compile(r"(?<=[.!؟?])\s+")
_ROLE_TAG = re.compile(r"\[[^\]\n]{0,12}\]:")


def _clean(txt: str) -> str:
    """Strip role tags from raw model output and trim to at most 2 sentences.

    Falls back to the whole cleaned text when no sentence boundary is
    found; the result is capped at 220 characters.
    """
    txt = txt.replace("[دستیار]:", "").replace("[سیستم]:", "").replace("[کاربر]:", "")
    txt = _ROLE_TAG.sub("", txt).strip()
    parts = _SENTENCE_SPLIT.split(txt)
    short = " ".join(parts[:2]).strip() or txt
    return short[:220]


@app.post("/generate")
def generate(r: Req):
    """Generate a short Persian reply for the given prompt.

    Sync endpoint on purpose: FastAPI runs it in a threadpool, so the
    blocking model.generate call does not stall the event loop.
    """
    # Cap input lengths to bound prompt size and latency.
    sys_prompt = (r.system or "")[:400]
    user = r.prompt[:900]
    text_in = f"[سیستم]: {sys_prompt}\n[کاربر]: {user}\n[دستیار]:"
    inputs = tok(text_in, return_tensors="pt").to(device)
    with torch.no_grad():
        out = model.generate(
            **inputs,
            max_new_tokens=min(200, r.max_tokens),
            do_sample=True,
            temperature=r.temperature,
            top_p=0.9,
            repetition_penalty=1.12,
            eos_token_id=_EOS_ID,
            pad_token_id=_EOS_ID,
        )
    # Decode only the newly generated tokens (skip the echoed prompt).
    raw = tok.decode(out[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
    return {"text": _clean(raw)}