File size: 3,060 Bytes
735c7d6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
from typing import List, Union
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer
from optimum.onnxruntime import ORTModelForSequenceClassification
import torch

# Directory containing the exported ONNX model weights and tokenizer files.
MODEL_DIR = "onnx_model"

# Fallback label set (15 Arabic poetic meter names), used when the model
# config carries no usable id2label mapping.
CLASS_LABELS = [
    "الطويل", "البسيط", "الكامل", "الوافر", "الهزج",
    "الرجز", "الرمل", "السريع", "المنسرح", "الخفيف",
    "المضارع", "المقتضب", "المجتث", "المتقارب", "المحدث"
]

app = FastAPI(title="Arabic Poetry Meter Predictor (ONNX)")

# Populated by the startup handler; all three remain None until it runs.
tokenizer = None
model = None
id2label = None


class PredictRequest(BaseModel):
    """Request body for /predict: a single poetry line or a list of lines."""

    sentences: Union[str, List[str]]


@app.on_event("startup")
def startup_event():
    """Load the tokenizer and ONNX model once, then build the id->label map.

    Prefers the mapping stored in the model config; falls back to the
    module-level CLASS_LABELS list when the config has no usable mapping.
    """
    global tokenizer, model, id2label

    tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
    model = ORTModelForSequenceClassification.from_pretrained(MODEL_DIR)

    config_labels = getattr(model.config, "id2label", None)
    if not (isinstance(config_labels, dict) and config_labels):
        # No usable mapping in the config: index the hard-coded label list.
        id2label = dict(enumerate(CLASS_LABELS))
        return

    # Config keys are often strings ("0", "1", ...); coerce to int where
    # possible, keeping the original key when coercion fails.
    mapping = {}
    for key, label in config_labels.items():
        try:
            mapping[int(key)] = label
        except Exception:
            mapping[key] = label
    id2label = mapping


def normalize_sentences(value: Union[str, List[str]]) -> List[str]:
    """Coerce a raw payload into a clean list of non-empty, stripped lines.

    A single string becomes a one-element list (or empty if blank); a list
    has each element stringified and stripped, dropping blanks; any other
    type yields an empty list.
    """
    if isinstance(value, str):
        stripped = value.strip()
        return [stripped] if stripped else []

    if isinstance(value, list):
        cleaned = []
        for item in value:
            text = str(item).strip()
            if text:
                cleaned.append(text)
        return cleaned

    return []


@app.get("/")
def root():
    """Landing endpoint confirming the API is up."""
    payload = {"message": "Arabic Poetry Meter Predictor API is running"}
    return payload


@app.get("/health")
def health():
    """Simple health-check endpoint for probes."""
    status = {"status": "ok"}
    return status



@app.get("/test")
def test():
    """Returns a fixed Arabic string to verify UTF-8 round-tripping."""
    sample = {"arabic": "مرحبا هذا اختبار"}
    return sample


@app.post("/predict")
def predict(req: PredictRequest):
    """Classify each submitted poetry line into one of the meter classes.

    Returns a list with one entry per input line:
        {"line": ..., "predictions": [{"rank": 1, "label": ..., "score": ...}]}

    Raises:
        HTTPException 500: model not loaded, or inference failed.
        HTTPException 400: no usable (non-empty) lines in the request.
    """
    if tokenizer is None or model is None:
        raise HTTPException(status_code=500, detail="Model is not loaded")

    sentences = normalize_sentences(req.sentences)
    if not sentences:
        raise HTTPException(status_code=400, detail="No poetry lines provided")

    try:
        inputs = tokenizer(
            sentences,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=256,
        )

        outputs = model(**inputs)
        # Softmax over the class dimension, then keep only the top-1 class.
        probs = torch.softmax(outputs.logits, dim=-1)
        top_scores, top_indices = torch.max(probs, dim=-1)

    except Exception as e:
        # Chain the cause so server logs retain the original traceback.
        raise HTTPException(status_code=500, detail=f"Prediction failed: {str(e)}") from e

    # Guard against id2label still being None (startup handler not run,
    # e.g. direct invocation in tests) — fall back to LABEL_<idx> names.
    labels = id2label or {}

    results = []
    for line, score, idx in zip(sentences, top_scores.tolist(), top_indices.tolist()):
        label = labels.get(int(idx), f"LABEL_{idx}")
        results.append({
            "line": line,
            "predictions": [
                {
                    "rank": 1,
                    "label": label,
                    "score": float(score),
                }
            ],
        })

    return results