from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch

# Models used for each translation direction
MODELS = {
    "in2bg": "rahmanansah/t5-id-bugis",
    "bg2id": "rahmanansah/t5-bugis-id"
}

# Cache of tokenizers & models that have already been loaded
loaded_models = {}

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def load_model(model_id):
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_id).to(device)
    return tokenizer, model

# Preload all models at startup
for key, model_id in MODELS.items():
    print(f"🔄 Loading {key} -> {model_id}")
    loaded_models[key] = load_model(model_id)

print("✅ All models loaded")

app = FastAPI()

class InputText(BaseModel):
    text: str
    model: str  # "in2bg" or "bg2id"

@app.post("/translate")
def translate(input: InputText):
    if input.model not in loaded_models:
        return {"error": f"Model '{input.model}' is not available. Options: {list(loaded_models.keys())}"}

    tokenizer, model = loaded_models[input.model]

    if not input.text.strip():
        return {"result": ""}

    text = input.text.strip()

    # Add the task prefix matching the translation direction
    if input.model == "in2bg":
        prefixed_text = f"translate id2bg: {text}"
    elif input.model == "bg2id":
        prefixed_text = f"translate bg2id: {text}"
    else:
        prefixed_text = text

    inputs = tokenizer(prefixed_text, return_tensors="pt").to(device)
    # Inference only, so disable gradient tracking
    with torch.no_grad():
        outputs = model.generate(**inputs, max_length=64)
    decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)

    return {"result": decoded}

if __name__ == "__main__":
    import uvicorn
    uvicorn.run("app:app", host="0.0.0.0", port=7860)
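
# Illustrative usage sketch (not part of the app itself): assuming the server is
# running locally on port 7860 as configured above and the `requests` package is
# installed, a client call could look like this:
#
#   import requests
#
#   resp = requests.post(
#       "http://localhost:7860/translate",
#       json={"text": "Selamat pagi", "model": "in2bg"},
#   )
#   print(resp.json())  # e.g. {"result": "..."}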