File size: 2,493 Bytes
fc713d0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9bd0be6
 
 
fc713d0
 
 
dd43c02
 
 
 
fc713d0
 
 
9bd0be6
fc713d0
 
 
9bd0be6
fc713d0
 
9bd0be6
 
 
 
 
fc713d0
 
 
9bd0be6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
from fastapi import FastAPI
from pydantic import BaseModel
import torch, re, asyncio, aiohttp, os
from transformers import AutoTokenizer, AutoModelForCausalLM

# Model/device configuration: model id comes from the environment, with a
# Persian GPT as the default. Use fp16 on GPU, fp32 on CPU.
MODEL_ID = os.getenv("MODEL_ID", "ai-forever/mGPT-1.3B-persian")
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

# Keep CPU usage low: restrict torch to a single thread.
torch.set_num_threads(1)

app = FastAPI()

# Load tokenizer and model once at import time (blocking; downloads on first
# run). .eval() disables dropout/batch-norm training behavior for inference.
tok = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=dtype,
    low_cpu_mem_usage=True
).to(device).eval()

class Req(BaseModel):
    """Request body for POST /generate."""
    prompt: str  # user message; truncated to 900 chars in generate()
    max_tokens: int = 160  # requested completion length; capped at 200 in generate()
    # Persian system prompt ("you are a fast, casual Persian assistant;
    # answer short, blunt and funny, 1-2 sentences"); truncated to 400 chars.
    system: str = "تو یه دستیار فارسی خودمونی و سریع هستی؛ جواب‌ها کوتاه، رک و بامزه (۱–۲ جمله)."
    temperature: float = 0.65  # sampling temperature passed to model.generate

@app.get("/health")
def health():
    """Liveness probe; also the target of the internal keep-alive pinger."""
    status = {"ok": True}
    return status

@app.get("/")
def root():
    """Root endpoint: points callers at the generate route."""
    payload = {"ok": True}
    payload["use"] = "POST /generate"
    return payload

def _clean(txt: str) -> str:
    txt = txt.replace("[دستیار]:", "").replace("[سیستم]:", "").replace("[کاربر]:", "")
    txt = re.sub(r"\[[^\]\n]{0,12}\]:", "", txt).strip()
    parts = re.split(r"(?<=[.!؟?])\s+", txt)
    short = " ".join(parts[:2]).strip() or txt
    return short[:220]

@app.post("/generate")
def generate(r: Req):
    """Generate a short Persian chat reply for the given prompt.

    Builds a [system]/[user]/[assistant]-tagged prompt, samples a completion
    from the model, and returns only the cleaned assistant text.
    """
    sys = (r.system or "")[:400]   # cap system prompt length
    user = r.prompt[:900]          # cap user prompt length
    text_in = f"[سیستم]: {sys}\n[کاربر]: {user}\n[دستیار]:"
    inputs = tok(text_in, return_tensors="pt").to(device)
    # Clamp client-supplied knobs: temperature <= 0 makes sampling raise,
    # and max_new_tokens must be a positive int (capped at 200 as before).
    temperature = min(max(r.temperature, 0.05), 2.0)
    max_new = max(1, min(200, r.max_tokens))
    with torch.no_grad():
        out = model.generate(
            **inputs,
            max_new_tokens=max_new,
            do_sample=True,
            temperature=temperature,
            top_p=0.9,
            repetition_penalty=1.12,
            # Fall back to pad_token_id if the tokenizer defines no EOS.
            eos_token_id=tok.eos_token_id or tok.pad_token_id,
            pad_token_id=tok.eos_token_id or tok.pad_token_id,
        )
    # Decode only the newly generated tokens, skipping the echoed prompt.
    raw = tok.decode(out[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
    return {"text": _clean(raw)}

# Internal keep-alive (to keep the Space awake)
async def _keepalive():
    """Ping our own /health endpoint every 5 minutes so the host keeps the
    service awake."""
    await asyncio.sleep(5)  # give the server a moment to start listening
    async with aiohttp.ClientSession() as s:
        while True:
            try:
                # Use "async with" so the response is closed and the
                # connection is released back to the pool (bare s.get()
                # leaks the response).
                async with s.get("http://127.0.0.1:7860/health", timeout=5):
                    pass
            except Exception:
                # Best-effort ping: failures (server restarting, transient
                # network errors) are deliberately ignored.
                pass
            await asyncio.sleep(300)

@app.on_event("startup")
async def _on_startup():
    """Launch the keep-alive loop once the event loop is running.

    The task reference is stored on app.state because asyncio keeps only a
    weak reference to tasks; a discarded create_task() result can be
    garbage-collected before it finishes.
    """
    app.state.keepalive_task = asyncio.create_task(_keepalive())