Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI | |
| from pydantic import BaseModel | |
| import torch, re, asyncio, aiohttp, os | |
| from transformers import AutoTokenizer, AutoModelForCausalLM | |
# --- Runtime configuration -------------------------------------------------
# Model checkpoint is overridable via the MODEL_ID environment variable.
MODEL_ID = os.getenv("MODEL_ID", "ai-forever/mGPT-1.3B-persian")

# Decide device once, then derive the matching dtype (fp16 only on GPU).
_use_gpu = torch.cuda.is_available()
device = "cuda" if _use_gpu else "cpu"
dtype = torch.float16 if _use_gpu else torch.float32

# Low-resource mode on CPU: a single thread keeps the Space within its quota.
torch.set_num_threads(1)
app = FastAPI()

# Fast (Rust) tokenizer; weights loaded in the dtype chosen above.
tok = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
_model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=dtype,
    low_cpu_mem_usage=True,
)
# Move to the target device and switch to inference mode.
model = _model.to(device).eval()
class Req(BaseModel):
    """Request body for the /generate endpoint."""
    prompt: str                # user message (truncated to 900 chars in generate())
    max_tokens: int = 160      # requested generation budget (capped at 200 in generate())
    system: str = "تو یه دستیار فارسی خودمونی و سریع هستی؛ جوابها کوتاه، رک و بامزه (۱–۲ جمله)."
    temperature: float = 0.65  # sampling temperature passed straight to model.generate
@app.get("/health")  # route restored: the internal keep-alive loop GETs /health
def health():
    """Liveness probe; also the target of the internal keep-alive pinger."""
    return {"ok": True}
@app.get("/")  # route restored: without it the Space root 404s
def root():
    """Root endpoint: tells clients where the real API lives."""
    return {"ok": True, "use": "POST /generate"}
| def _clean(txt: str) -> str: | |
| txt = txt.replace("[دستیار]:", "").replace("[سیستم]:", "").replace("[کاربر]:", "") | |
| txt = re.sub(r"\[[^\]\n]{0,12}\]:", "", txt).strip() | |
| parts = re.split(r"(?<=[.!؟?])\s+", txt) | |
| short = " ".join(parts[:2]).strip() or txt | |
| return short[:220] | |
@app.post("/generate")  # route restored: root() advertises "POST /generate"
def generate(r: Req):
    """Generate a short Persian chat reply for the given prompt.

    Builds a [system]/[user]/[assistant] style prompt, samples a completion,
    decodes only the newly generated tokens, and returns the cleaned text.
    """
    sys = (r.system or "")[:400]  # cap system prompt length
    user = r.prompt[:900]         # cap user prompt length
    text_in = f"[سیستم]: {sys}\n[کاربر]: {user}\n[دستیار]:"
    inputs = tok(text_in, return_tensors="pt").to(device)
    with torch.no_grad():
        out = model.generate(
            **inputs,
            # Clamp to [1, 200]: a zero/negative max_tokens in the request
            # would otherwise make model.generate raise.
            max_new_tokens=max(1, min(200, r.max_tokens)),
            do_sample=True,
            temperature=r.temperature,
            top_p=0.9,
            repetition_penalty=1.12,
            eos_token_id=tok.eos_token_id or tok.pad_token_id,
            pad_token_id=tok.eos_token_id or tok.pad_token_id,
        )
    # Decode only the continuation (skip the prompt tokens we fed in).
    raw = tok.decode(out[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
    return {"text": _clean(raw)}
# Internal keep-alive (so the Space does not go to sleep).
async def _keepalive():
    """Ping our own /health endpoint every 5 minutes, forever."""
    # Give uvicorn a moment to start listening before the first ping.
    await asyncio.sleep(5)
    async with aiohttp.ClientSession() as session:
        while True:
            try:
                await session.get("http://127.0.0.1:7860/health", timeout=5)
            except Exception:
                # Best effort: a failed ping must never kill the loop.
                pass
            await asyncio.sleep(300)
@app.on_event("startup")  # hook restored: otherwise the keep-alive task never runs
async def _on_startup():
    """Schedule the self-ping task once the event loop is running."""
    asyncio.create_task(_keepalive())