| from fastapi import FastAPI |
| from fastapi.staticfiles import StaticFiles |
| from fastapi.responses import FileResponse |
| from pydantic import BaseModel |
| from transformers import AutoTokenizer, AutoModelForSeq2SeqLM |
| from fastapi.middleware.cors import CORSMiddleware |
|
|
|
|
# Application instance; the routes and the static mount in this module are
# registered on it.
app = FastAPI()

# CORS policy: accept cross-origin requests from any origin, with any method
# and any request header.
# Credentials are disabled on purpose: a wildcard origin must not be combined
# with credentialed requests (cookies / Authorization headers), and no handler
# in this file reads either. If credentialed cross-origin access is ever
# needed, replace the "*" entries with explicit lists and re-enable it.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
# Identifier handed to ``from_pretrained`` for both the tokenizer and the model
# (presumably an English -> Arabic translation checkpoint, judging by the name).
MODEL_NAME = "guymorlan/levanti_translate_en_ar"

# Loaded once at import time; the request handlers below reuse these objects.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
class TranslationRequest(BaseModel):
    """JSON body accepted by the ``POST /translate`` endpoint."""

    # Text to be translated; the only field in the request.
    text: str
|
|
# Expose the local "static" directory under the /static URL prefix.
# StaticFiles runs in HTML mode (html=True).
app.mount(
    "/static",
    StaticFiles(directory="static", html=True),
    name="static",
)
|
|
@app.get("/")
def index() -> FileResponse:
    """Return ``static/index.html`` as an HTML response for the root URL."""
    page_path = "static/index.html"
    return FileResponse(page_path, media_type="text/html")
|
|
def _run_translation(text: str) -> str:
    """Translate ``text`` with the module-level ``tokenizer`` and ``model``.

    Shared by the ``/translate`` and ``/healthcheck`` handlers so both run
    exactly the same tokenize -> generate -> decode pipeline.

    Args:
        text: The string to translate.

    Returns:
        The decoded output of the first generated sequence, with special
        tokens removed.
    """
    # Tokenize as a batch of one and request PyTorch tensors ("pt").
    encoded = tokenizer([text], return_tensors="pt", padding=True, truncation=True)
    generated = model.generate(**encoded)
    # The batch holds a single input, so only index 0 is decoded.
    return tokenizer.decode(generated[0], skip_special_tokens=True)


@app.post("/translate")
def translate(req: TranslationRequest):
    """Translate the text carried in the request body.

    Args:
        req: Parsed request body; ``req.text`` is the text to translate.

    Returns:
        A dict of the form ``{"translation": <str>}``. Blank or
        whitespace-only input yields an empty translation without invoking
        the model.
    """
    # Nothing to translate: skip the model call entirely.
    if not req.text.strip():
        return {"translation": ""}
    return {"translation": _run_translation(req.text)}


@app.get("/healthcheck")
def healthcheck():
    """Run a fixed test translation to confirm the model pipeline works.

    Returns:
        A dict of the form ``{"status": "ok", "test_translation": <str>}``,
        where the translation is the model output for the text ``"Hello"``.
        Any exception raised by the pipeline propagates to the framework.
    """
    test_text = "Hello"
    return {"status": "ok", "test_translation": _run_translation(test_text)}
|
|