AVAT7 committed on
Commit
9bd0be6
·
verified ·
1 Parent(s): dd43c02

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -29
app.py CHANGED
@@ -3,61 +3,74 @@ from pydantic import BaseModel
3
  import torch, re, asyncio, aiohttp, os
4
  from transformers import AutoTokenizer, AutoModelForCausalLM
5
 
6
-
7
  MODEL_ID = os.getenv("MODEL_ID", "ai-forever/mGPT-1.3B-persian")
8
  device = "cuda" if torch.cuda.is_available() else "cpu"
9
  dtype = torch.float16 if device == "cuda" else torch.float32
10
 
11
-
12
  # کم‌مصرف روی CPU
13
  torch.set_num_threads(1)
14
 
15
-
16
  app = FastAPI()
17
 
18
-
19
  tok = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
20
  model = AutoModelForCausalLM.from_pretrained(
21
- MODEL_ID,
22
- torch_dtype=dtype,
23
- low_cpu_mem_usage=True
24
  ).to(device).eval()
25
 
26
-
27
  class Req(BaseModel):
28
  prompt: str
29
  max_tokens: int = 160
30
  system: str = "تو یه دستیار فارسی خودمونی و سریع هستی؛ جواب‌ها کوتاه، رک و بامزه (۱–۲ جمله)."
31
  temperature: float = 0.65
32
 
33
-
34
  @app.get("/health")
35
  def health():
36
- return {"ok": True}
37
-
38
 
39
  @app.get("/")
40
  def root():
41
- return {"ok": True, "use": "POST /generate"}
42
-
43
 
44
  def _clean(txt: str) -> str:
45
- txt = txt.replace("[دستیار]:", "").replace("[سیستم]:", "").replace("[کاربر]:", "")
46
- txt = re.sub(r"\[[^\]\n]{0,12}\]:", "", txt).strip()
47
- parts = re.split(r"(?<=[.!؟?])\s+", txt)
48
- short = " ".join(parts[:2]).strip() or txt
49
- return short[:220]
50
-
51
 
52
  @app.post("/generate")
53
  def generate(r: Req):
54
- sys = (r.system or "")[:400]
55
- user = r.prompt[:900]
56
- text_in = f"[سیستم]: {sys}\n[کاربر]: {user}\n[دستیار]:"
57
- inputs = tok(text_in, return_tensors="pt").to(device)
58
- with torch.no_grad():
59
- out = model.generate(
60
- **inputs,
61
- max_new_tokens=min(200, r.max_tokens),
62
- do_sample=True,
63
- asyncio.create_task(_keepalive())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  import torch, re, asyncio, aiohttp, os
4
  from transformers import AutoTokenizer, AutoModelForCausalLM
5
 
 
6
  MODEL_ID = os.getenv("MODEL_ID", "ai-forever/mGPT-1.3B-persian")
7
  device = "cuda" if torch.cuda.is_available() else "cpu"
8
  dtype = torch.float16 if device == "cuda" else torch.float32
9
 
 
10
  # کم‌مصرف روی CPU
11
  torch.set_num_threads(1)
12
 
 
13
  app = FastAPI()
14
 
 
15
  tok = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
16
  model = AutoModelForCausalLM.from_pretrained(
17
+ MODEL_ID,
18
+ torch_dtype=dtype,
19
+ low_cpu_mem_usage=True
20
  ).to(device).eval()
21
 
 
22
  class Req(BaseModel):
23
  prompt: str
24
  max_tokens: int = 160
25
  system: str = "تو یه دستیار فارسی خودمونی و سریع هستی؛ جواب‌ها کوتاه، رک و بامزه (۱–۲ جمله)."
26
  temperature: float = 0.65
27
 
 
28
  @app.get("/health")
29
  def health():
30
+ return {"ok": True}
 
31
 
32
  @app.get("/")
33
  def root():
34
+ return {"ok": True, "use": "POST /generate"}
 
35
 
36
  def _clean(txt: str) -> str:
37
+ txt = txt.replace("[دستیار]:", "").replace("[سیستم]:", "").replace("[کاربر]:", "")
38
+ txt = re.sub(r"\[[^\]\n]{0,12}\]:", "", txt).strip()
39
+ parts = re.split(r"(?<=[.!؟?])\s+", txt)
40
+ short = " ".join(parts[:2]).strip() or txt
41
+ return short[:220]
 
42
 
43
  @app.post("/generate")
44
  def generate(r: Req):
45
+ sys = (r.system or "")[:400]
46
+ user = r.prompt[:900]
47
+ text_in = f"[سیستم]: {sys}\n[کاربر]: {user}\n[دستیار]:"
48
+ inputs = tok(text_in, return_tensors="pt").to(device)
49
+ with torch.no_grad():
50
+ out = model.generate(
51
+ **inputs,
52
+ max_new_tokens=min(200, r.max_tokens),
53
+ do_sample=True,
54
+ temperature=r.temperature,
55
+ top_p=0.9,
56
+ repetition_penalty=1.12,
57
+ eos_token_id=tok.eos_token_id or tok.pad_token_id,
58
+ pad_token_id=tok.eos_token_id or tok.pad_token_id,
59
+ )
60
+ raw = tok.decode(out[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
61
+ return {"text": _clean(raw)}
62
+
63
# Internal keep-alive (to keep the Space awake).
async def _keepalive():
    """Request the local ``/health`` endpoint forever, once every 300 seconds.

    Waits 5 seconds before the first request. Each request has a total
    timeout of 5 seconds. Failures are ignored so the loop keeps running.
    """
    await asyncio.sleep(5)
    request_timeout = aiohttp.ClientTimeout(total=5)
    async with aiohttp.ClientSession(timeout=request_timeout) as s:
        while True:
            try:
                # `async with` releases the response, so the connection goes
                # back to the session pool instead of lingering until GC.
                async with s.get("http://127.0.0.1:7860/health"):
                    pass
            except Exception:
                # Best effort: a failed ping must never stop the loop.
                pass
            await asyncio.sleep(300)
73
+
74
@app.on_event("startup")
async def _on_startup():
    """Start the keep-alive loop as a background task when the app starts."""
    # The event loop keeps only a weak reference to tasks; hold a strong
    # reference on the application state so the task is not garbage-collected
    # while it is still running.
    app.state.keepalive_task = asyncio.create_task(_keepalive())