from typing import List, Union
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer
from optimum.onnxruntime import ORTModelForSequenceClassification
import torch
# Directory containing the exported ONNX model plus its tokenizer files.
MODEL_DIR = "onnx_model"
# Fallback label set: the 15 classical Arabic poetry meters, in index order.
# Used only when the model config does not ship its own id2label mapping.
CLASS_LABELS = [
"الطويل", "البسيط", "الكامل", "الوافر", "الهزج",
"الرجز", "الرمل", "السريع", "المنسرح", "الخفيف",
"المضارع", "المقتضب", "المجتث", "المتقارب", "المحدث"
]
app = FastAPI(title="Arabic Poetry Meter Predictor (ONNX)")
# Lazy-loaded globals, populated by the startup handler so that importing
# this module stays cheap and the model loads once per process.
tokenizer = None
model = None
id2label = None
class PredictRequest(BaseModel):
    """Request body for POST /predict."""
    # Either a single verse line or a list of lines to classify.
    sentences: Union[str, List[str]]
@app.on_event("startup")
def startup_event():
    """Load the tokenizer, ONNX model, and label mapping once at boot."""
    global tokenizer, model, id2label
    tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
    model = ORTModelForSequenceClassification.from_pretrained(MODEL_DIR)
    # Prefer the id2label mapping shipped in the model config; fall back to
    # the hard-coded meter names when the config has none.
    config_labels = getattr(model.config, "id2label", None)
    if not (isinstance(config_labels, dict) and config_labels):
        id2label = dict(enumerate(CLASS_LABELS))
        return
    normalized = {}
    for key, name in config_labels.items():
        # Config keys are often serialized as strings; index by int when possible.
        try:
            normalized[int(key)] = name
        except Exception:
            normalized[key] = name
    id2label = normalized
def normalize_sentences(value: Union[str, List[str]]) -> List[str]:
    """Coerce a raw request payload into a clean list of non-empty lines.

    A bare string becomes a one-element list (or empty if blank); a list is
    stringified and stripped element-wise, dropping blanks; anything else
    yields an empty list.
    """
    if isinstance(value, str):
        stripped = value.strip()
        return [stripped] if stripped else []
    if not isinstance(value, list):
        return []
    cleaned = []
    for item in value:
        text = str(item).strip()
        if text:
            cleaned.append(text)
    return cleaned
@app.get("/")
def root():
    """Landing endpoint confirming the service is reachable."""
    payload = {"message": "Arabic Poetry Meter Predictor API is running"}
    return payload
@app.get("/health")
def health():
    """Liveness probe used by orchestrators/load balancers."""
    status_payload = {"status": "ok"}
    return status_payload
@app.get("/test")
def test():
    """Return a fixed Arabic string so clients can verify UTF-8 round-trips."""
    sample = {"arabic": "مرحبا هذا اختبار"}
    return sample
@app.post("/predict")
def predict(req: PredictRequest):
    """Classify each submitted poetry line into one Arabic meter.

    Accepts a single line or a list of lines, runs the ONNX classifier in a
    single batch, and returns one result per line containing the top-1 label
    and its softmax score.

    Raises:
        HTTPException 500: model not loaded, or inference failed.
        HTTPException 400: no non-empty lines after normalization.
    """
    # Guard id2label too: startup sets all three together, but a partial or
    # failed startup must surface as a clean 500, not an AttributeError.
    if tokenizer is None or model is None or id2label is None:
        raise HTTPException(status_code=500, detail="Model is not loaded")
    sentences = normalize_sentences(req.sentences)
    if not sentences:
        raise HTTPException(status_code=400, detail="No poetry lines provided")
    try:
        inputs = tokenizer(
            sentences,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=256,
        )
        outputs = model(**inputs)
        # Softmax over the label axis; keep only the best class per line.
        probs = torch.softmax(outputs.logits, dim=-1)
        top_scores, top_indices = torch.max(probs, dim=-1)
    except Exception as e:
        # Chain the original exception so server logs keep the root cause.
        raise HTTPException(status_code=500, detail=f"Prediction failed: {str(e)}") from e
    return [
        {
            "line": line,
            "predictions": [
                {
                    "rank": 1,
                    # Unknown indices fall back to a synthetic label name.
                    "label": id2label.get(int(idx), f"LABEL_{idx}"),
                    "score": float(score),
                }
            ],
        }
        for line, score, idx in zip(sentences, top_scores.tolist(), top_indices.tolist())
    ]