# --- Hugging Face Spaces page residue (not code); commented out so the file parses ---
# Spaces: Sleeping
# Sleeping
# File size: 2,493 Bytes
# fc713d0 9bd0be6 fc713d0 dd43c02 fc713d0 9bd0be6 fc713d0 9bd0be6 fc713d0 9bd0be6 fc713d0 9bd0be6
# 1..76 (line-number gutter from the page extraction)
from fastapi import FastAPI
from pydantic import BaseModel
import torch, re, asyncio, aiohttp, os
from transformers import AutoTokenizer, AutoModelForCausalLM
MODEL_ID = os.getenv("MODEL_ID", "ai-forever/mGPT-1.3B-persian")
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32
# کممصرف روی CPU
torch.set_num_threads(1)
app = FastAPI()
tok = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
model = AutoModelForCausalLM.from_pretrained(
MODEL_ID,
torch_dtype=dtype,
low_cpu_mem_usage=True
).to(device).eval()
class Req(BaseModel):
prompt: str
max_tokens: int = 160
system: str = "تو یه دستیار فارسی خودمونی و سریع هستی؛ جوابها کوتاه، رک و بامزه (۱–۲ جمله)."
temperature: float = 0.65
@app.get("/health")
def health():
return {"ok": True}
@app.get("/")
def root():
return {"ok": True, "use": "POST /generate"}
def _clean(txt: str) -> str:
txt = txt.replace("[دستیار]:", "").replace("[سیستم]:", "").replace("[کاربر]:", "")
txt = re.sub(r"\[[^\]\n]{0,12}\]:", "", txt).strip()
parts = re.split(r"(?<=[.!؟?])\s+", txt)
short = " ".join(parts[:2]).strip() or txt
return short[:220]
@app.post("/generate")
def generate(r: Req):
sys = (r.system or "")[:400]
user = r.prompt[:900]
text_in = f"[سیستم]: {sys}\n[کاربر]: {user}\n[دستیار]:"
inputs = tok(text_in, return_tensors="pt").to(device)
with torch.no_grad():
out = model.generate(
**inputs,
max_new_tokens=min(200, r.max_tokens),
do_sample=True,
temperature=r.temperature,
top_p=0.9,
repetition_penalty=1.12,
eos_token_id=tok.eos_token_id or tok.pad_token_id,
pad_token_id=tok.eos_token_id or tok.pad_token_id,
)
raw = tok.decode(out[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
return {"text": _clean(raw)}
# Internal keep-alive (to keep the Space from going to sleep)
async def _keepalive():
await asyncio.sleep(5)
async with aiohttp.ClientSession() as s:
while True:
try:
await s.get("http://127.0.0.1:7860/health", timeout=5)
except Exception:
pass
await asyncio.sleep(300)
@app.on_event("startup")
async def _on_startup():
asyncio.create_task(_keepalive()) |