File size: 3,236 Bytes
85d43b2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e93a7d4
85d43b2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
"""
AraReview — FastAPI Backend
Serves the fine-tuned AraBERT model via REST API.
"""

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import pipeline
import torch
import os

# ─── APP ──────────────────────────────────────────────────────────────────────

app = FastAPI(
    title="AraReview API",
    description="Arabic sentiment analysis for product and hotel reviews",
    version="1.0.0",
)

# Browser origins allowed to call the API, read from the comma-separated
# CORS_ORIGINS environment variable (e.g. "https://a.example,https://b.example")
# so a deployment can lock the API down to its own front-end.
# When the variable is unset or blank, fall back to "*" (any origin), which is
# the original, fully open configuration.
_CORS_ORIGINS = [
    origin.strip()
    for origin in os.getenv("CORS_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_CORS_ORIGINS,
    allow_methods=["*"],
    allow_headers=["*"],
)

# ─── LOAD MODEL ───────────────────────────────────────────────────────────────

# Hugging Face model id (or local path) to serve; overridable via env var.
HF_MODEL = os.environ.get("HF_MODEL", "dralsarrani/arareview")

# transformers pipeline device convention: 0 = first CUDA GPU, -1 = CPU.
if torch.cuda.is_available():
    DEVICE = 0
else:
    DEVICE = -1

print(f"Loading model: {HF_MODEL}")
# Built once at import time; every request handler reuses this pipeline.
# The model and the tokenizer are loaded from the same identifier.
# truncation/max_length are tokenizer options, so inputs longer than 128
# tokens are cut rather than rejected.
classifier = pipeline(
    task="text-classification",
    model=HF_MODEL,
    tokenizer=HF_MODEL,
    device=DEVICE,
    max_length=128,
    truncation=True,
)
print("Model loaded and ready.")

# ─── SCHEMAS ──────────────────────────────────────────────────────────────────

class ReviewRequest(BaseModel):
    """Request body carrying one review to classify."""

    # Raw review text; the endpoints strip surrounding whitespace before use.
    text: str

class ReviewResponse(BaseModel):
    """Sentiment prediction returned for a single review."""

    # The (whitespace-stripped) review text that was classified.
    text: str
    # Label string exactly as emitted by the classifier.
    label: str
    # Classifier score for ``label``, rounded to 4 decimal places.
    confidence: float
    # Human-readable verdict: an emoji followed by the Arabic sentiment word.
    emoji: str

# ─── ENDPOINTS ────────────────────────────────────────────────────────────────

@app.get("/")
def root():
    """Return a banner confirming the API is up, plus the served model id."""
    payload = {"message": "AraReview API is running"}
    payload["model"] = HF_MODEL
    return payload

@app.get("/health")
def health():
    """Health-check endpoint; always reports ``{"status": "ok"}``."""
    status = "ok"
    return dict(status=status)

@app.post("/predict", response_model=ReviewResponse)
def predict(request: ReviewRequest):
    """Classify the sentiment of a single review.

    The input is stripped of surrounding whitespace before validation and
    inference, and the stripped text is echoed back in the response.

    Args:
        request: Body containing the review ``text``.

    Returns:
        ReviewResponse with the classifier's label, its score rounded to four
        decimals, and a human-readable Arabic verdict in ``emoji``.

    Raises:
        HTTPException: 400 if the stripped text is empty or shorter than
            3 characters.
    """
    text = request.text.strip()

    if not text:
        raise HTTPException(status_code=400, detail="Text cannot be empty")

    if len(text) < 3:
        raise HTTPException(status_code=400, detail="Text too short")

    # A single string input yields a one-element list of {"label", "score"}.
    result = classifier(text)[0]
    label = result["label"]
    confidence = round(result["score"], 4)
    # Compare case-insensitively so "positive"/"Positive"/"POSITIVE" all map
    # to the positive verdict.
    # NOTE(review): assumes the model's labels are "positive"/"negative";
    # any other label falls through to the negative verdict. Confirm against
    # the model's id2label config.
    emoji = "✅ إيجابي" if label.lower() == "positive" else "❌ سلبي"

    return ReviewResponse(
        text=text,
        label=label,
        confidence=confidence,
        emoji=emoji,
    )

@app.post("/predict/batch")
def predict_batch(reviews: list[ReviewRequest]):
    """Classify the sentiment of up to 50 reviews in a single model call.

    Each review's text is stripped of surrounding whitespace and validated
    with the same rules as ``/predict``.

    Args:
        reviews: Review bodies to classify.

    Returns:
        A list with one dict per review, in input order, with keys ``text``,
        ``label``, ``confidence`` (score rounded to four decimals) and
        ``emoji``. An empty input list yields an empty list.

    Raises:
        HTTPException: 400 if more than 50 reviews are sent, or if any
            stripped text is empty or shorter than 3 characters (the detail
            names the offending zero-based index).
    """
    max_batch_size = 50
    if len(reviews) > max_batch_size:
        raise HTTPException(
            status_code=400, detail=f"Max {max_batch_size} reviews per batch"
        )

    texts = [r.text.strip() for r in reviews]
    if not texts:
        # Nothing to classify; skip the model call entirely.
        return []

    # Enforce the same per-text rules as /predict so both endpoints accept
    # and reject exactly the same inputs.
    for index, text in enumerate(texts):
        if not text:
            raise HTTPException(
                status_code=400, detail=f"Review {index}: Text cannot be empty"
            )
        if len(text) < 3:
            raise HTTPException(
                status_code=400, detail=f"Review {index}: Text too short"
            )

    # A list input yields one {"label", "score"} dict per text, in order.
    results = classifier(texts)

    # NOTE(review): assumes the model's labels are "positive"/"negative"
    # (compared case-insensitively); any other label gets the negative verdict.
    return [
        {
            "text":       text,
            "label":      r["label"],
            "confidence": round(r["score"], 4),
            "emoji":      "✅ إيجابي" if r["label"].lower() == "positive" else "❌ سلبي",
        }
        for text, r in zip(texts, results)
    ]