AVAT7 committed
Commit fc713d0 · verified · 1 Parent(s): f9de5c3

Create app.py

Files changed (1)
  1. app.py +63 -0
app.py ADDED
@@ -0,0 +1,63 @@
+ from fastapi import FastAPI
+ from pydantic import BaseModel
+ import torch, re, asyncio, aiohttp, os
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+
+
+ MODEL_ID = os.getenv("MODEL_ID", "ai-forever/mGPT-1.3B-persian")
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ dtype = torch.float16 if device == "cuda" else torch.float32
+
+
+ # Keep resource usage low on CPU (single thread)
+ torch.set_num_threads(1)
+
+
+ app = FastAPI()
+
+
+ tok = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
+ model = AutoModelForCausalLM.from_pretrained(
+     MODEL_ID,
+     torch_dtype=dtype,
+     low_cpu_mem_usage=True
+ ).to(device).eval()
+
+
+ class Req(BaseModel):
+     prompt: str
+     max_tokens: int = 160
+     # Default system prompt (Persian): "You're a casual, quick Persian
+     # assistant; answers short, blunt, and funny (1-2 sentences)."
+     system: str = "تو یه دستیار فارسی خودمونی و سریع هستی؛ جواب‌ها کوتاه، رک و بامزه (۱–۲ جمله)."
+     temperature: float = 0.65
+
+
+ @app.get("/health")
+ def health():
+     return {"ok": True}
+
+
+ @app.get("/")
+ def root():
+     return {"ok": True, "use": "POST /generate"}
+
+
+ def _clean(txt: str) -> str:
+     # Strip role tags ([دستیار]=assistant, [سیستم]=system, [کاربر]=user)
+     # plus any other short bracketed "[tag]:" markers.
+     txt = txt.replace("[دستیار]:", "").replace("[سیستم]:", "").replace("[کاربر]:", "")
+     txt = re.sub(r"\[[^\]\n]{0,12}\]:", "", txt).strip()
+     # Keep at most the first two sentences, capped at 220 characters.
+     parts = re.split(r"(?<=[.!؟?])\s+", txt)
+     short = " ".join(parts[:2]).strip() or txt
+     return short[:220]
+
+
+ @app.post("/generate")
+ def generate(r: Req):
+     sys = (r.system or "")[:400]
+     user = r.prompt[:900]
+     # Persian role tags: [سیستم]=system, [کاربر]=user, [دستیار]=assistant
+     text_in = f"[سیستم]: {sys}\n[کاربر]: {user}\n[دستیار]:"
+     inputs = tok(text_in, return_tensors="pt").to(device)
+     with torch.no_grad():
+         out = model.generate(
+             **inputs,
+             max_new_tokens=min(200, r.max_tokens),
+             do_sample=True,
+             temperature=r.temperature,
+             pad_token_id=tok.eos_token_id,
+         )
+     # Decode only the newly generated tokens, then trim to 1-2 sentences.
+     gen = out[0][inputs["input_ids"].shape[1]:]
+     return {"text": _clean(tok.decode(gen, skip_special_tokens=True))}
+
+
+ # `_keepalive` is not defined in the visible hunk; a plausible sketch,
+ # given the otherwise-unused aiohttp import, is a self-ping loop that
+ # keeps the host awake. SELF_URL is an assumed env var, not from the commit.
+ async def _keepalive():
+     url = os.getenv("SELF_URL")
+     if not url:
+         return
+     async with aiohttp.ClientSession() as session:
+         while True:
+             try:
+                 async with session.get(f"{url}/health"):
+                     pass
+             except Exception:
+                 pass
+             await asyncio.sleep(300)  # ping every 5 minutes
+
+
+ @app.on_event("startup")
+ async def _schedule_keepalive():
+     # create_task needs a running event loop, so schedule the keepalive
+     # at startup rather than at import time, where it would crash.
+     asyncio.create_task(_keepalive())
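
For reference, a minimal client sketch for the new endpoint. The host/port (uvicorn's default localhost:8000) and the exact response shape are assumptions, not part of the commit:

import requests

# Assumes the app is running via e.g. `uvicorn app:app --port 8000`.
resp = requests.post(
    "http://localhost:8000/generate",
    json={
        "prompt": "سلام! یه جوک کوتاه بگو.",  # "Hi! Tell me a short joke."
        "max_tokens": 120,
        "temperature": 0.7,
    },
    timeout=120,
)
print(resp.json())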