| from typing import List, Union |
| from fastapi import FastAPI, HTTPException |
| from pydantic import BaseModel |
| from transformers import AutoTokenizer |
| from optimum.onnxruntime import ORTModelForSequenceClassification |
| import torch |
|
|
# Directory containing the exported ONNX model plus its tokenizer files.
MODEL_DIR = "onnx_model"


# The 15 classical Arabic poetry meters (buhur), in model output-index order.
# Used as a fallback label map when the model config does not supply id2label.
CLASS_LABELS = [
    "الطويل", "البسيط", "الكامل", "الوافر", "الهزج",
    "الرجز", "الرمل", "السريع", "المنسرح", "الخفيف",
    "المضارع", "المقتضب", "المجتث", "المتقارب", "المحدث"
]


app = FastAPI(title="Arabic Poetry Meter Predictor (ONNX)")


# Populated once by the startup handler; None until the model has loaded.
tokenizer = None
model = None
id2label = None
|
class PredictRequest(BaseModel):
    """Request body for /predict: a single poetry line or a list of lines."""

    sentences: Union[str, List[str]]
|
|
|
|
@app.on_event("startup")
def startup_event():
    """Load the tokenizer and ONNX model once, and build the id->label map.

    The label map is taken from the model config when it carries meaningful
    names; otherwise (missing, empty, or generic ``LABEL_<n>`` placeholders
    left by an export without named labels) it falls back to CLASS_LABELS.
    """
    global tokenizer, model, id2label
    tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
    model = ORTModelForSequenceClassification.from_pretrained(MODEL_DIR)

    config_labels = getattr(model.config, "id2label", None)
    if isinstance(config_labels, dict) and config_labels:
        fixed = {}
        for k, v in config_labels.items():
            # JSON configs store keys as strings; normalize them to ints.
            try:
                fixed[int(k)] = v
            except (TypeError, ValueError):
                fixed[k] = v
        # Configs exported without named labels carry "LABEL_<n>" placeholders;
        # those are useless to clients, so prefer the real Arabic meter names.
        if all(isinstance(v, str) and v.startswith("LABEL_") for v in fixed.values()):
            id2label = {i: label for i, label in enumerate(CLASS_LABELS)}
        else:
            id2label = fixed
    else:
        id2label = {i: label for i, label in enumerate(CLASS_LABELS)}
|
|
|
|
def normalize_sentences(value: Union[str, List[str]]) -> List[str]:
    """Coerce the request payload into a clean list of non-empty lines.

    A single string becomes a one-element list (or empty if blank); a list
    has each element stringified and stripped, dropping blanks. Any other
    type yields an empty list.
    """
    if isinstance(value, str):
        stripped = value.strip()
        return [stripped] if stripped else []

    if isinstance(value, list):
        cleaned = []
        for item in value:
            text = str(item).strip()
            if text:
                cleaned.append(text)
        return cleaned

    return []
|
|
|
|
@app.get("/")
def root():
    """Landing endpoint confirming the service is up."""
    payload = {"message": "Arabic Poetry Meter Predictor API is running"}
    return payload
|
|
|
|
@app.get("/health")
def health():
    """Liveness probe; always reports an OK status."""
    status = {"status": "ok"}
    return status
|
|
|
|
|
|
@app.get("/test")
def test():
    """Smoke-test endpoint returning a fixed Arabic string (encoding check)."""
    sample = {"arabic": "مرحبا هذا اختبار"}
    return sample
|
|
|
|
@app.post("/predict")
def predict(req: PredictRequest):
    """Classify each poetry line and return its top-1 meter with a score.

    Returns a list of ``{"line", "predictions": [{"rank", "label", "score"}]}``
    objects, one per non-empty input line.

    Raises:
        HTTPException(500): model not yet loaded, or inference failed.
        HTTPException(400): no usable poetry lines in the request.
    """
    if tokenizer is None or model is None:
        raise HTTPException(status_code=500, detail="Model is not loaded")

    lines = normalize_sentences(req.sentences)
    if not lines:
        raise HTTPException(status_code=400, detail="No poetry lines provided")

    try:
        encoded = tokenizer(
            lines,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=256,
        )
        logits = model(**encoded).logits
        probabilities = torch.softmax(logits, dim=-1)
        best_scores, best_indices = torch.max(probabilities, dim=-1)
    except Exception as e:
        # NOTE(review): the raw exception text is echoed to the client here —
        # confirm this is acceptable for the deployment environment.
        raise HTTPException(status_code=500, detail=f"Prediction failed: {str(e)}")

    results = []
    for text, score, idx in zip(lines, best_scores.tolist(), best_indices.tolist()):
        prediction = {
            "rank": 1,
            "label": id2label.get(int(idx), f"LABEL_{idx}"),
            "score": float(score),
        }
        results.append({"line": text, "predictions": [prediction]})

    return results