rahmanansah committed on
Commit
0a4b796
·
verified ·
1 Parent(s): a74b68b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -109
app.py CHANGED
@@ -1,129 +1,62 @@
1
  from fastapi import FastAPI
2
  from pydantic import BaseModel
3
- from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, AutoModelForCausalLM
4
- from fastapi.middleware.cors import CORSMiddleware
5
  import torch
6
- import uvicorn
7
 
8
# Application object; the route decorators in this module attach to it.
app = FastAPI(title="Bugis Indonesia API", version="1.0.0")

# ----------------------------
# Load model Indonesia → Bugis
# ----------------------------
# Uses the Hub repository name given below.
model_in2bg_name = "rahmanansah/t5-id-bugis"
tokenizer_in2bg = AutoTokenizer.from_pretrained(model_in2bg_name)
model_in2bg = AutoModelForSeq2SeqLM.from_pretrained(model_in2bg_name)

# ----------------------------
# Load model Bugis → Indonesia
# ----------------------------
model_bg2id_name = "rahmanansah/t5-bugis-id"
tokenizer_bg2id = AutoTokenizer.from_pretrained(model_bg2id_name)
model_bg2id = AutoModelForSeq2SeqLM.from_pretrained(model_bg2id_name)

# ----------------------------
# Load chat model (Qwen)
# ----------------------------
model_qwen_name = "Qwen/Qwen2.5-1.5B-Instruct"
tokenizer_qwen = AutoTokenizer.from_pretrained(model_qwen_name)
# dtype="auto" + device_map="auto" so that loading is safe on CPU or GPU.
model_qwen = AutoModelForCausalLM.from_pretrained(
    model_qwen_name,
    torch_dtype="auto",
    device_map="auto",
)
36
 
37
- # ----------------------------
38
- # Request / Response Models
39
- # ----------------------------
40
class TranslateRequest(BaseModel):
    """Request body of the ``/translate`` endpoint."""

    # Text to translate.
    text: str
    # Which translation model to use: "in2bg" or "bg2id".
    model: str
43
-
44
class TranslateResponse(BaseModel):
    """Response body of the ``/translate`` endpoint."""

    # Translated text (or a message when the model key is not recognised).
    result: str
46
 
47
class ChatRequest(BaseModel):
    """Request body of the ``/chat`` endpoint."""

    # The user's message.
    message: str
49
 
50
class ChatResponse(BaseModel):
    """Response body of the ``/chat`` endpoint."""

    # The generated reply text.
    reply: str
 
52
 
53
- # ----------------------------
54
- # Health & root
55
- # ----------------------------
56
@app.get("/")
def root():
    """Return an ``ok`` flag together with the list of endpoint paths."""
    return {"ok": True, "endpoints": ["/health", "/translate", "/chat"]}
59
 
60
@app.get("/health")
def health():
    """Return a constant ``{"ok": True}`` payload for health checks."""
    return {"ok": True}
63
 
64
- # ----------------------------
65
- # Translate Endpoint
66
- # ----------------------------
67
@app.post("/translate", response_model=TranslateResponse)
def translate(req: TranslateRequest):
    """Translate ``req.text`` with the model selected by ``req.model``.

    Returns a dict with a single ``result`` key:

    * blank (or whitespace-only) text yields an empty string;
    * an unrecognised model key yields a message in ``result`` (no HTTP
      error is raised);
    * otherwise ``result`` is the decoded beam-search output.
    """
    text = (req.text or "").strip()
    if not text:
        return {"result": ""}

    # Model key -> (tokenizer, model) pair loaded at module import.
    pipelines = {
        "in2bg": (tokenizer_in2bg, model_in2bg),
        "bg2id": (tokenizer_bg2id, model_bg2id),
    }
    selected = pipelines.get(req.model)
    if selected is None:
        return {"result": f"Model '{req.model}' tidak dikenali"}
    tokenizer, model = selected

    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
    # Move the encoded tensors to the model's device (matters when on GPU).
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_length=128,
            num_beams=4,  # slightly improves quality
            early_stopping=True,
        )

    result = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return {"result": result}
94
-
95
- # ----------------------------
96
- # Chat Endpoint
97
- # ----------------------------
98
@app.post("/chat", response_model=ChatResponse)
def chat(req: ChatRequest):
    """Generate a sampled reply to ``req.message`` with the Qwen chat model.

    Returns ``{"reply": ...}``; a blank (or whitespace-only) message yields
    an empty reply without running the model.
    """
    user_msg = (req.message or "").strip()
    if not user_msg:
        return {"reply": ""}

    # Simple, fixed prompt format.
    prompt = f"User: {user_msg}\nAssistant:"
    inputs = tokenizer_qwen(prompt, return_tensors="pt")
    # Move the encoded tensors to the Qwen model's device.
    inputs = {k: v.to(model_qwen.device) for k, v in inputs.items()}
    prompt_len = inputs["input_ids"].shape[1]

    with torch.no_grad():
        outputs = model_qwen.generate(
            **inputs,
            max_new_tokens=200,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
        )

    # For a decoder-only model the generated sequence begins with the prompt
    # tokens. Drop them by position instead of string-replacing the decoded
    # prompt: the replace silently fails when decoding does not reproduce
    # the prompt byte-for-byte, and it also deletes any later repetition of
    # the prompt inside the model's answer.
    new_tokens = outputs[0][prompt_len:]
    reply = tokenizer_qwen.decode(new_tokens, skip_special_tokens=True).strip()
    return {"reply": reply}
123
 
124
- # ----------------------------
125
- # Run local (opsional)
126
- # ----------------------------
127
if __name__ == "__main__":
    # For local testing. On Spaces the launcher uses the `app` object directly.
    uvicorn.run(app, host="0.0.0.0", port=7860)
 
1
  from fastapi import FastAPI
2
  from pydantic import BaseModel
3
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
 
4
  import torch
 
5
 
6
# Models served by this app: request model key -> Hub repository id.
MODELS = {
    "in2bg": "rahmanansah/t5-id-bugis",
    "bg2id": "rahmanansah/t5-bugis-id",
}

# Holds the tokenizer & model that have already been loaded, keyed like MODELS.
loaded_models = {}

# Use CUDA when torch reports it as available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 
 
 
 
16
 
17
def load_model(model_id: str):
    """Load the tokenizer and seq2seq model for ``model_id``.

    The model is moved to the module-level ``device`` before it is returned.

    Args:
        model_id: Identifier passed unchanged to ``from_pretrained``.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_id).to(device)
    return tokenizer, model
 
 
 
 
 
 
 
21
 
22
# Preload every model listed in MODELS when the module is imported, so that
# requests never wait on a first-use load.
for key, model_id in MODELS.items():
    print(f"🔄 Loading {key} -> {model_id}")
    loaded_models[key] = load_model(model_id)
print("✅ Semua model sudah diload")

# Application object; the route decorators in this module attach to it.
app = FastAPI()
 
29
 
30
class InputText(BaseModel):
    """Request body of the ``/translate`` endpoint."""

    # Text to translate.
    text: str
    # Which translation model to use: "in2bg" or "bg2id".
    model: str
33
 
34
# Task prefix prepended to the input text for each model key. The strings
# must stay exactly as they are: they are part of the model input.
_TASK_PREFIXES = {
    "in2bg": "translate id2bg: ",
    "bg2id": "translate bg2id: ",
}


@app.post("/translate")
def translate(input: InputText):
    """Translate ``input.text`` with the preloaded model named ``input.model``.

    Returns:
        ``{"error": ...}`` when ``input.model`` is not a key of
        ``loaded_models``; ``{"result": ""}`` when the text is blank or
        whitespace-only; otherwise ``{"result": <decoded model output>}``.
    """
    if input.model not in loaded_models:
        return {"error": f"Model '{input.model}' tidak tersedia. Pilihan: {list(loaded_models.keys())}"}

    tokenizer, model = loaded_models[input.model]

    text = input.text.strip()
    if not text:
        return {"result": ""}

    # Add the prefix that matches the model's direction. A key that has no
    # registered prefix is passed through unchanged.
    prefixed_text = _TASK_PREFIXES.get(input.model, "") + text

    # truncation=True caps oversized request text at the tokenizer's
    # configured maximum length instead of encoding it in full.
    inputs = tokenizer(prefixed_text, return_tensors="pt", truncation=True).to(device)
    outputs = model.generate(**inputs, max_length=64)
    decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)

    return {"result": decoded}
 
 
 
59
 
 
 
 
60
if __name__ == "__main__":
    # Imported here so that uvicorn is only needed when the file is run as a
    # script.
    import uvicorn

    # "app:app" is the import string for the `app` object in module `app`.
    uvicorn.run("app:app", host="0.0.0.0", port=7860)