File size: 1,929 Bytes
934ccda
28353db
 
934ccda
 
 
 
 
28353db
 
934ccda
 
 
 
 
 
 
 
6d5ca23
 
 
 
934ccda
 
 
6d5ca23
 
28353db
 
 
6d5ca23
28353db
accb6e8
 
6d5ca23
accb6e8
28353db
 
934ccda
6d5ca23
934ccda
 
 
 
 
 
 
 
 
6d5ca23
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import logging
import sys

# Configure the root logger to emit INFO-and-above records on stdout, and
# create this module's logger.
logging.basicConfig(level=logging.INFO, stream=sys.stdout)
logger = logging.getLogger(__name__)

# Application object; the route handlers and the startup hook in this module
# are registered on it via decorators.
app = FastAPI()
# Placeholders filled in by the startup handler. They remain None until loading
# succeeds, and the /translate handler checks for None before using them.
tokenizer = None
model = None

@app.on_event("startup")
async def load_model():
    """Load the tokenizer and seq2seq model into the module globals at startup.

    Registered as a FastAPI startup handler. Both artifacts are loaded into
    local variables first and only published to the module-level ``tokenizer``
    and ``model`` once *both* loads have succeeded, so the globals are never
    left half-initialised.

    Any exception raised while loading is logged with its traceback and then
    swallowed on purpose: the application still starts, the globals keep their
    previous values (``None`` on a fresh process), and the ``/translate``
    endpoint reports 503 until a model is available.
    """
    global tokenizer, model

    # Load from Hub (allows download on first run).
    # If you uploaded files to the Space, change repo_id to "."
    repo_id = "offiongbassey/efik-mt"

    try:
        logger.info("Loading tokenizer and model...")
        loaded_tokenizer = AutoTokenizer.from_pretrained(repo_id)
        loaded_model = AutoModelForSeq2SeqLM.from_pretrained(repo_id)
    except Exception as e:
        logger.error("❌ Failed to load model: %s", e, exc_info=True)
        # A failing model load is critical. You may want to raise here to fail fast.
        # For now, we let it be, and the /translate endpoint will check.
        return

    # Publish both objects together, only after everything loaded cleanly.
    tokenizer, model = loaded_tokenizer, loaded_model
    logger.info("✅ Model loaded successfully!")

# Request body schema for POST /translate.
class TranslateRequest(BaseModel):
    # Text to be translated.
    text: str
    # String placed in front of ``text`` (separated by a single space) when the
    # model input is built in translate(); presumably a language/direction tag
    # expected by the model -- TODO confirm against the model card.
    source: str

@app.get("/")
async def home():
    """Return the service banner together with a flag telling whether a model is loaded."""
    is_model_ready = model is not None
    payload = {
        "message": "Efik Translation API",
        "model_loaded": is_model_ready,
    }
    return payload

@app.post("/translate")
async def translate(req: TranslateRequest):
    """Translate ``req.text`` with the loaded seq2seq model.

    The model input is ``req.source`` and ``req.text`` joined by one space.
    Generation is capped at ``max_length=128`` and the first generated
    sequence is decoded with special tokens removed.

    Args:
        req: Request body carrying the text and the source prefix.

    Returns:
        ``{"translation": <decoded string>}``.

    Raises:
        HTTPException: 503 when the tokenizer or model is not loaded;
            500 when tokenisation, generation or decoding fails (the
            underlying error is logged with its traceback).
    """
    # Function-scope import so this handler does not depend on edits to the
    # module's import block; fastapi is already a dependency of this file.
    from fastapi.concurrency import run_in_threadpool

    if tokenizer is None or model is None:
        raise HTTPException(status_code=503, detail="Model is still loading or failed to load. Please try again in a moment.")

    try:
        input_text = f"{req.source} {req.text}"
        # Tokenise on the event-loop thread (cheap, and keeps the shared
        # tokenizer from being used by several worker threads at once).
        # truncation=True asks the tokenizer to cut over-long input down to its
        # configured maximum length instead of passing it to the model unbounded.
        inputs = tokenizer(input_text, return_tensors="pt", truncation=True)
        # generate() is blocking and CPU-heavy; running it directly inside this
        # coroutine would freeze the event loop (and every other request,
        # including "/") for the whole generation. Hand it to the thread pool.
        outputs = await run_in_threadpool(model.generate, **inputs, max_length=128)
        translation = tokenizer.decode(outputs[0], skip_special_tokens=True)
        return {"translation": translation}
    except Exception as e:
        logger.error("Translation error: %s", e, exc_info=True)
        # Keep the client-facing message generic; chain the cause for debugging.
        raise HTTPException(status_code=500, detail="Internal translation error.") from e