Spaces:
Sleeping
Sleeping
File size: 1,785 Bytes
43e395e 0a4b796 43e395e 4599ca4 0a4b796 43e395e 0a4b796 6bab504 0a4b796 6bab504 0a4b796 6bab504 0a4b796 43e395e 0a4b796 4599ca4 0a4b796 46c1dc4 0a4b796 303eae3 0a4b796 303eae3 0a4b796 303eae3 0a4b796 6bab504 0a4b796 6bab504 0a4b796 6bab504 0a4b796 b2d120c 0a4b796 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 |
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
# Translation checkpoints served by this app, keyed by the direction name that
# clients send in the request's "model" field.
# NOTE(review): the first key is spelled "in2bg" while the prompt prefix used
# for it in the /translate handler is "id2bg" — looks like a typo, but the key
# is part of the request API, so it is left untouched here.
MODELS = {
    "in2bg": "rahmanansah/t5-id-bugis",
    "bg2id": "rahmanansah/t5-bugis-id",
}

# Cache of already-loaded (tokenizer, model) pairs, keyed like MODELS.
loaded_models = {}

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
def load_model(model_id):
    """Load the tokenizer and seq2seq model for one checkpoint.

    Args:
        model_id: Identifier handed unchanged to both ``from_pretrained``
            calls.

    Returns:
        A ``(tokenizer, model)`` tuple; the model has been moved to the
        module-level ``device``.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    seq2seq = AutoModelForSeq2SeqLM.from_pretrained(model_id)
    return tokenizer, seq2seq.to(device)
# Preload every configured model at import time, so the request handler only
# ever reads from `loaded_models`. Progress is echoed to stdout.
for key in MODELS:
    model_id = MODELS[key]
    print(f"🔄 Loading {key} -> {model_id}")
    loaded_models[key] = load_model(model_id)
print("✅ Semua model sudah diload")
# Application object: the /translate route is registered on it via @app.post,
# and the __main__ block hands it to uvicorn as "app:app".
app = FastAPI()
# Request body accepted by the /translate endpoint.
# (Documented with comments rather than a class docstring: pydantic would
# publish a docstring as the model's schema description.)
class InputText(BaseModel):
    # Text to translate.
    text: str
    # Translation direction: "in2bg" or "bg2id".
    model: str
# POST /translate — translate `input.text` with the model named by `input.model`.
#
# Responses (all returned as plain dicts):
#   {"error": ...}   when `input.model` is not a key of `loaded_models`
#   {"result": ""}   when the text is empty or whitespace-only
#   {"result": ...}  the decoded generation otherwise
#
# Kept as comments rather than a docstring: FastAPI publishes a handler's
# docstring as the operation description in the generated API schema.
@app.post("/translate")
def translate(input: InputText):
    direction = input.model
    if direction not in loaded_models:
        return {"error": f"Model '{input.model}' tidak tersedia. Pilihan: {list(loaded_models.keys())}"}

    text = input.text.strip()
    if not text:
        # Nothing to translate; skip tokenization and generation entirely.
        return {"result": ""}

    # Task prefix prepended to the text for each direction; a direction
    # without an entry here is sent to the model with no prefix.
    task_prefixes = {
        "in2bg": "translate id2bg: ",
        "bg2id": "translate bg2id: ",
    }
    prefixed_text = task_prefixes.get(direction, "") + text

    tokenizer, model = loaded_models[direction]
    encoded = tokenizer(prefixed_text, return_tensors="pt").to(device)
    generated = model.generate(**encoded, max_length=64)
    # Only the first returned sequence is decoded.
    translation = tokenizer.decode(generated[0], skip_special_tokens=True)
    return {"result": translation}
# Script entry point: serve the app with uvicorn when this file is executed
# directly instead of being imported.
if __name__ == "__main__":
    # Imported here so uvicorn is only needed when running as a script.
    import uvicorn

    # "app:app" is an import string (module "app", attribute "app"); this
    # presumably relies on the file being named app.py — confirm if renamed.
    uvicorn.run(
        "app:app",
        host="0.0.0.0",
        port=7860,
    )
|