rahmanansah commited on
Commit
303eae3
·
verified ·
1 Parent(s): 6bab504

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -13
app.py CHANGED
@@ -1,31 +1,38 @@
1
  from fastapi import FastAPI
2
  from pydantic import BaseModel
3
  from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, AutoModelForCausalLM
 
4
  import torch
5
  import uvicorn
6
 
7
- app = FastAPI()
8
 
9
  # ----------------------------
10
  # Load model Indonesia → Bugis
11
  # ----------------------------
12
- model_in2bg_name = "rahmanansah/in2bg" # ganti sesuai repo kamu
 
13
  tokenizer_in2bg = AutoTokenizer.from_pretrained(model_in2bg_name)
14
  model_in2bg = AutoModelForSeq2SeqLM.from_pretrained(model_in2bg_name)
15
 
16
  # ----------------------------
17
  # Load model Bugis → Indonesia
18
  # ----------------------------
19
- model_bg2id_name = "rahmanansah/bg2id" # ganti sesuai repo kamu
20
  tokenizer_bg2id = AutoTokenizer.from_pretrained(model_bg2id_name)
21
  model_bg2id = AutoModelForSeq2SeqLM.from_pretrained(model_bg2id_name)
22
 
23
  # ----------------------------
24
- # Load model Chat Qwen
25
  # ----------------------------
26
  model_qwen_name = "Qwen/Qwen2.5-1.5B-Instruct"
27
  tokenizer_qwen = AutoTokenizer.from_pretrained(model_qwen_name)
28
- model_qwen = AutoModelForCausalLM.from_pretrained(model_qwen_name, torch_dtype=torch.float16, device_map="auto")
 
 
 
 
 
29
 
30
  # ----------------------------
31
  # Request / Response Models
@@ -43,11 +50,26 @@ class ChatRequest(BaseModel):
43
  class ChatResponse(BaseModel):
44
  reply: str
45
 
 
 
 
 
 
 
 
 
 
 
 
46
  # ----------------------------
47
  # Translate Endpoint
48
  # ----------------------------
49
  @app.post("/translate", response_model=TranslateResponse)
50
  def translate(req: TranslateRequest):
 
 
 
 
51
  if req.model == "in2bg":
52
  tokenizer, model = tokenizer_in2bg, model_in2bg
53
  elif req.model == "bg2id":
@@ -55,9 +77,18 @@ def translate(req: TranslateRequest):
55
  else:
56
  return {"result": f"Model '{req.model}' tidak dikenali"}
57
 
58
- inputs = tokenizer(req.text, return_tensors="pt", padding=True, truncation=True)
 
 
 
59
  with torch.no_grad():
60
- outputs = model.generate(**inputs, max_length=128)
 
 
 
 
 
 
61
  result = tokenizer.decode(outputs[0], skip_special_tokens=True)
62
  return {"result": result}
63
 
@@ -66,8 +97,15 @@ def translate(req: TranslateRequest):
66
  # ----------------------------
67
  @app.post("/chat", response_model=ChatResponse)
68
  def chat(req: ChatRequest):
69
- prompt = f"User: {req.message}\nAssistant:"
70
- inputs = tokenizer_qwen(prompt, return_tensors="pt").to(model_qwen.device)
 
 
 
 
 
 
 
71
 
72
  with torch.no_grad():
73
  outputs = model_qwen.generate(
@@ -78,13 +116,14 @@ def chat(req: ChatRequest):
78
  do_sample=True
79
  )
80
 
81
- reply = tokenizer_qwen.decode(outputs[0], skip_special_tokens=True)
82
- # hapus prompt biar hasil lebih bersih
83
- reply = reply.replace(prompt, "").strip()
84
  return {"reply": reply}
85
 
86
  # ----------------------------
87
- # Run Local (kalau di test manual)
88
  # ----------------------------
89
  if __name__ == "__main__":
 
90
  uvicorn.run(app, host="0.0.0.0", port=7860)
 
1
  from fastapi import FastAPI
2
  from pydantic import BaseModel
3
  from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, AutoModelForCausalLM
4
+ from fastapi.middleware.cors import CORSMiddleware
5
  import torch
6
  import uvicorn
7
 
8
+ app = FastAPI(title="Bugis ↔ Indonesia API", version="1.0.0")
9
 
10
  # ----------------------------
11
  # Load model Indonesia → Bugis
12
  # ----------------------------
13
+ # Pakai nama repo yang kamu sebutkan
14
+ model_in2bg_name = "rahmanansah/t5-id-bugis"
15
  tokenizer_in2bg = AutoTokenizer.from_pretrained(model_in2bg_name)
16
  model_in2bg = AutoModelForSeq2SeqLM.from_pretrained(model_in2bg_name)
17
 
18
  # ----------------------------
19
  # Load model Bugis → Indonesia
20
  # ----------------------------
21
+ model_bg2id_name = "rahmanansah/t5-bugis-id"
22
  tokenizer_bg2id = AutoTokenizer.from_pretrained(model_bg2id_name)
23
  model_bg2id = AutoModelForSeq2SeqLM.from_pretrained(model_bg2id_name)
24
 
25
  # ----------------------------
26
+ # Load model Chat (Qwen)
27
  # ----------------------------
28
  model_qwen_name = "Qwen/Qwen2.5-1.5B-Instruct"
29
  tokenizer_qwen = AutoTokenizer.from_pretrained(model_qwen_name)
30
+ # dtype="auto" + device_map="auto" agar aman di CPU/GPU
31
+ model_qwen = AutoModelForCausalLM.from_pretrained(
32
+ model_qwen_name,
33
+ torch_dtype="auto",
34
+ device_map="auto"
35
+ )
36
 
37
  # ----------------------------
38
  # Request / Response Models
 
50
  class ChatResponse(BaseModel):
51
  reply: str
52
 
53
+ # ----------------------------
54
+ # Health & root
55
+ # ----------------------------
56
+ @app.get("/")
57
+ def root():
58
+ return {"ok": True, "endpoints": ["/health", "/translate", "/chat"]}
59
+
60
+ @app.get("/health")
61
+ def health():
62
+ return {"ok": True}
63
+
64
  # ----------------------------
65
  # Translate Endpoint
66
  # ----------------------------
67
  @app.post("/translate", response_model=TranslateResponse)
68
  def translate(req: TranslateRequest):
69
+ text = (req.text or "").strip()
70
+ if not text:
71
+ return {"result": ""}
72
+
73
  if req.model == "in2bg":
74
  tokenizer, model = tokenizer_in2bg, model_in2bg
75
  elif req.model == "bg2id":
 
77
  else:
78
  return {"result": f"Model '{req.model}' tidak dikenali"}
79
 
80
+ inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
81
+ # pindahkan ke device model (aman kalau GPU)
82
+ inputs = {k: v.to(model.device) for k, v in inputs.items()}
83
+
84
  with torch.no_grad():
85
+ outputs = model.generate(
86
+ **inputs,
87
+ max_length=128,
88
+ num_beams=4, # sedikit improve kualitas
89
+ early_stopping=True
90
+ )
91
+
92
  result = tokenizer.decode(outputs[0], skip_special_tokens=True)
93
  return {"result": result}
94
 
 
97
  # ----------------------------
98
  @app.post("/chat", response_model=ChatResponse)
99
  def chat(req: ChatRequest):
100
+ user_msg = (req.message or "").strip()
101
+ if not user_msg:
102
+ return {"reply": ""}
103
+
104
+ # prompt sederhana & konsisten
105
+ prompt = f"User: {user_msg}\nAssistant:"
106
+ inputs = tokenizer_qwen(prompt, return_tensors="pt")
107
+ # ke device model qwen
108
+ inputs = {k: v.to(model_qwen.device) for k, v in inputs.items()}
109
 
110
  with torch.no_grad():
111
  outputs = model_qwen.generate(
 
116
  do_sample=True
117
  )
118
 
119
+ full = tokenizer_qwen.decode(outputs[0], skip_special_tokens=True)
120
+ # buang prompt agar balasan bersih
121
+ reply = full.replace(prompt, "").strip()
122
  return {"reply": reply}
123
 
124
  # ----------------------------
125
+ # Run local (opsional)
126
  # ----------------------------
127
  if __name__ == "__main__":
128
+ # Untuk test lokal. Di Spaces, launcher akan pakai objek `app` langsung.
129
  uvicorn.run(app, host="0.0.0.0", port=7860)