# AraReviews-API / main.py
# Author: dralsarrani
# Last commit: e93a7d4 "Update main.py" (verified)
"""
AraReview — FastAPI Backend
Serves the fine-tuned AraBERT model via REST API.
"""
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import pipeline
import torch
import os
# ─── APP ──────────────────────────────────────────────────────────────────────
# Metadata shown in the auto-generated OpenAPI schema and the /docs page.
_API_METADATA = {
    "title": "AraReview API",
    "description": "Arabic sentiment analysis for product and hotel reviews",
    "version": "1.0.0",
}
app = FastAPI(**_API_METADATA)

# Wide-open CORS: any origin may call any method with any header.
# NOTE(review): allow_credentials is left at its default; narrow allow_origins
# if this API ever serves authenticated or private data.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)
# ─── LOAD MODEL ───────────────────────────────────────────────────────────────
# Hub repo id of the fine-tuned classifier. The HF_MODEL environment variable
# overrides the default so another checkpoint can be served without code edits.
HF_MODEL = os.getenv("HF_MODEL", "dralsarrani/arareview")

# Pipeline device convention: 0 selects the first CUDA GPU, -1 runs on CPU.
if torch.cuda.is_available():
    DEVICE = 0
else:
    DEVICE = -1

# Built once at import time; every request handler reuses this pipeline.
print("Loading model: {}".format(HF_MODEL))
classifier = pipeline(
    task="text-classification",
    model=HF_MODEL,
    tokenizer=HF_MODEL,
    device=DEVICE,
    # Tokenizer options: inputs longer than 128 tokens are truncated.
    truncation=True,
    max_length=128,
)
print("Model loaded and ready.")
# ─── SCHEMAS ──────────────────────────────────────────────────────────────────
class ReviewRequest(BaseModel):
    """Request body for ``/predict``; also the element type of a ``/predict/batch`` body."""

    # Raw review text; the endpoints strip surrounding whitespace before use.
    text: str
class ReviewResponse(BaseModel):
    """Response body of ``/predict``: the classifier's verdict for one review."""

    # The whitespace-stripped input text, echoed back to the caller.
    text: str
    # Label string exactly as emitted by the text-classification pipeline.
    label: str
    # Pipeline score for ``label``, rounded to 4 decimal places.
    confidence: float
    # Display string derived from ``label`` (positive vs. everything else).
    emoji: str
# ─── ENDPOINTS ────────────────────────────────────────────────────────────────
@app.get("/")
def root():
    """Report that the service is up and name the model being served."""
    banner = dict(message="AraReview API is running", model=HF_MODEL)
    return banner
@app.get("/health")
def health():
    """Lightweight health probe; always reports ``ok`` without touching the model."""
    return dict(status="ok")
# Display strings for the ``emoji`` field ("positive" / "negative" in Arabic).
_POSITIVE_EMOJI = "✅ إيجابي"
_NEGATIVE_EMOJI = "❌ سلبي"
# Shortest accepted review, in characters, after stripping whitespace.
_MIN_TEXT_LENGTH = 3
# Upper bound on the number of reviews accepted by one /predict/batch call.
_MAX_BATCH_SIZE = 50


def _clean_text(raw: str) -> str:
    """Return *raw* stripped of surrounding whitespace, after length validation.

    Raises:
        HTTPException: 400 if the stripped text is empty or shorter than
            ``_MIN_TEXT_LENGTH`` characters.
    """
    text = raw.strip()
    if not text:
        raise HTTPException(status_code=400, detail="Text cannot be empty")
    if len(text) < _MIN_TEXT_LENGTH:
        raise HTTPException(status_code=400, detail="Text too short")
    return text


def _to_response(text: str, result: dict) -> ReviewResponse:
    """Convert one pipeline result (a ``{"label", "score"}`` dict) into a ReviewResponse."""
    label = result["label"]
    # NOTE(review): only the exact label "positive" maps to the positive display
    # string; any other label (e.g. "LABEL_1", "neutral") is shown as negative.
    # Confirm against the model's id2label config.
    emoji = _POSITIVE_EMOJI if label == "positive" else _NEGATIVE_EMOJI
    return ReviewResponse(
        text=text,
        label=label,
        confidence=round(result["score"], 4),
        emoji=emoji,
    )


@app.post("/predict", response_model=ReviewResponse)
def predict(request: ReviewRequest):
    """Classify the sentiment of a single review.

    Raises:
        HTTPException: 400 if the text is empty or too short.
    """
    text = _clean_text(request.text)
    # A single string input yields a one-element list of {"label", "score"}.
    result = classifier(text)[0]
    return _to_response(text, result)


@app.post("/predict/batch", response_model=list[ReviewResponse])
def predict_batch(reviews: list[ReviewRequest]):
    """Classify up to ``_MAX_BATCH_SIZE`` reviews with a single pipeline call.

    Each review is validated exactly like ``/predict``. If any item fails, the
    whole batch is rejected, and the error detail names the offending index.

    Raises:
        HTTPException: 400 if the batch is too large or any text is invalid.
    """
    if len(reviews) > _MAX_BATCH_SIZE:
        raise HTTPException(
            status_code=400,
            detail=f"Max {_MAX_BATCH_SIZE} reviews per batch",
        )
    if not reviews:
        # Nothing to classify; skip invoking the pipeline on an empty list.
        return []

    texts = []
    for index, review in enumerate(reviews):
        try:
            texts.append(_clean_text(review.text))
        except HTTPException as exc:
            # Re-raise with the item's position so clients can locate it.
            raise HTTPException(
                status_code=exc.status_code,
                detail=f"Review {index}: {exc.detail}",
            ) from exc

    # A list input yields one {"label", "score"} dict per text, in input order.
    results = classifier(texts)
    return [_to_response(text, result) for text, result in zip(texts, results)]