"""FastAPI service that predicts the meter (bahr) of Arabic poetry lines.

Serves an ONNX-exported sequence-classification model via optimum/onnxruntime.
The model and tokenizer are loaded once at application startup.
"""

from typing import List, Union

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer
from optimum.onnxruntime import ORTModelForSequenceClassification
import torch

# Directory containing the exported ONNX model + tokenizer files.
MODEL_DIR = "onnx_model"

# Fallback label set (classical Arabic meters) used when the model config
# does not ship its own id2label mapping. Order must match training label ids.
CLASS_LABELS = [
    "الطويل", "البسيط", "الكامل", "الوافر", "الهزج",
    "الرجز", "الرمل", "السريع", "المنسرح", "الخفيف",
    "المضارع", "المقتضب", "المجتث", "المتقارب", "المحدث",
]

app = FastAPI(title="Arabic Poetry Meter Predictor (ONNX)")

# Populated by startup_event(); None until the model has been loaded.
tokenizer = None
model = None
id2label = None


class PredictRequest(BaseModel):
    # Accepts either a single poetry line or a list of lines.
    sentences: Union[str, List[str]]


@app.on_event("startup")
def startup_event():
    """Load tokenizer, ONNX model, and the id->label mapping once at startup.

    Prefers the id2label mapping stored in the model config (coercing JSON
    string keys back to ints); falls back to CLASS_LABELS otherwise.
    """
    global tokenizer, model, id2label
    tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
    model = ORTModelForSequenceClassification.from_pretrained(MODEL_DIR)

    config_labels = getattr(model.config, "id2label", None)
    if isinstance(config_labels, dict) and config_labels:
        fixed = {}
        for k, v in config_labels.items():
            # JSON round-tripping turns int keys into strings; restore ints
            # so lookups by tensor index succeed. Keep the original key if
            # it is not int-like rather than dropping the entry.
            try:
                fixed[int(k)] = v
            except (TypeError, ValueError):
                fixed[k] = v
        id2label = fixed
    else:
        id2label = {i: label for i, label in enumerate(CLASS_LABELS)}


def normalize_sentences(value: Union[str, List[str]]) -> List[str]:
    """Coerce the request payload into a clean list of non-empty lines.

    A single string becomes a one-element list (or empty if blank); list
    items are stringified and stripped, with blanks dropped. Any other
    input type yields an empty list.
    """
    if isinstance(value, str):
        value = value.strip()
        return [value] if value else []
    if isinstance(value, list):
        return [str(v).strip() for v in value if str(v).strip()]
    return []


@app.get("/")
def root():
    """Liveness/landing endpoint."""
    return {"message": "Arabic Poetry Meter Predictor API is running"}


@app.get("/health")
def health():
    """Health-check endpoint for orchestrators/load balancers."""
    return {"status": "ok"}


@app.get("/test")
def test():
    """Simple endpoint to verify Arabic text round-trips correctly."""
    return {"arabic": "مرحبا هذا اختبار"}


@app.post("/predict")
def predict(req: PredictRequest):
    """Classify each poetry line and return its top-1 meter prediction.

    Returns a list of {line, predictions: [{rank, label, score}]} objects.
    Raises 500 if the model is not yet loaded or inference fails, and 400
    if no usable lines were supplied.
    """
    # Guard all startup-populated globals, not just tokenizer/model.
    if tokenizer is None or model is None or id2label is None:
        raise HTTPException(status_code=500, detail="Model is not loaded")

    sentences = normalize_sentences(req.sentences)
    if not sentences:
        raise HTTPException(status_code=400, detail="No poetry lines provided")

    try:
        inputs = tokenizer(
            sentences,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=256,
        )
        outputs = model(**inputs)
        probs = torch.softmax(outputs.logits, dim=-1)
        top_scores, top_indices = torch.max(probs, dim=-1)
    except Exception as e:
        # Chain the original exception so server logs keep the root cause.
        raise HTTPException(
            status_code=500, detail=f"Prediction failed: {str(e)}"
        ) from e

    results = []
    for line, score, idx in zip(
        sentences, top_scores.tolist(), top_indices.tolist()
    ):
        label = id2label.get(int(idx), f"LABEL_{idx}")
        results.append({
            "line": line,
            "predictions": [
                {
                    "rank": 1,
                    "label": label,
                    "score": float(score),
                }
            ],
        })
    return results