# darija-aicc-api / main.py — by MohammedMediani (commit 933f6a1)
"""
Darija NLU API - Professional REST API for Moroccan Arabic Sentiment/Intent Classification.
Powered by MARBERTv2 fine-tuned on Darija.
"""
import logging
import os
from contextlib import asynccontextmanager
from typing import Any, Dict

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
from transformers import pipeline
# --- Configuration ---
# Hugging Face Hub repo id of the fine-tuned MARBERTv2 Darija checkpoint.
MODEL_ID = "MohammedMediani/marbert-fine-tuned-darija-aicc"
# Global pipeline variable
# Populated once by lifespan() at startup; None means "not ready / load failed".
nlu_pipeline = None
@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    Lifespan context manager: load the classification pipeline once at startup
    and release it at shutdown.

    On load failure the exception is logged with its traceback and
    ``nlu_pipeline`` stays ``None`` so the process can still start;
    /health and /predict then report 503 until the problem is fixed.
    """
    global nlu_pipeline
    logger = logging.getLogger(__name__)
    try:
        logger.info("Loading model from HuggingFace Hub: %s...", MODEL_ID)
        # torch is imported lazily so importing this module stays cheap when
        # the model is never loaded (e.g. docs generation, unit tests).
        import torch

        # device=0 selects the first GPU if available, -1 forces CPU; an
        # explicit integer is safer for pipelines than auto-detection.
        device = 0 if torch.cuda.is_available() else -1
        nlu_pipeline = pipeline(
            "text-classification",
            model=MODEL_ID,
            tokenizer=MODEL_ID,
            device=device,
        )
        logger.info("Model loaded successfully!")
    except Exception:
        # Deliberate best-effort: keep the API alive so /health can report
        # the failure instead of crash-looping the container. logger.exception
        # records the full traceback (the original print() dropped it).
        logger.exception("CRITICAL: Failed to load model")
        nlu_pipeline = None
    yield
    # Shutdown: drop the pipeline reference so its memory can be reclaimed.
    nlu_pipeline = None
# --- FastAPI App Definition ---
# Single application instance; `lifespan` loads the model before the first
# request is served. Interactive docs live at /docs (Swagger) and /redoc.
app = FastAPI(
    title="Darija NLU API",
    description="Professional API for intent classification in Moroccan Darija (Arabic Dialect).",
    version="1.0.0",
    lifespan=lifespan,
    docs_url="/docs",
    redoc_url="/redoc"
)
# --- Data Models ---
class TextInput(BaseModel):
    """Request model for text classification."""
    # Non-empty Darija utterance to classify. NOTE(review): the bare
    # `example=` kwarg is deprecated in Pydantic v2 (use `examples=[...]`) —
    # confirm which Pydantic major version is pinned before changing it.
    text: str = Field(..., description="The text in Darija to analyze", min_length=1, example="3afak bghit nchouf solde")
class PredictionResponse(BaseModel):
    """Response model containing the predicted intent and confidence score."""
    # Label string produced by the text-classification pipeline.
    intent: str = Field(..., description="Predicted intent label")
    # Pipeline softmax score for the top label.
    confidence: float = Field(..., description="Confidence score between 0.0 and 1.0")
# --- Routes ---
@app.get("/", tags=["General"])
def read_root() -> Dict[str, str]:
    """Landing endpoint: greets the caller and points at the prediction route."""
    welcome = "Welcome to the Darija NLU API. Use POST /predict to analyze text."
    return {"message": welcome}
@app.get("/health", tags=["General"])
def health_check() -> Dict[str, str]:
    """Report service readiness: 200 once the model is loaded, 503 before."""
    if nlu_pipeline is not None:
        return {"status": "ok", "model_status": "loaded"}
    raise HTTPException(status_code=503, detail="Service initializing or model failed to load.")
@app.post("/predict", response_model=PredictionResponse, tags=["Inference"])
async def predict_intent(request: TextInput) -> PredictionResponse:
    """
    Predict the intent of the provided Darija text.

    Returns:
        PredictionResponse with the top-1 label and its confidence score.

    Raises:
        HTTPException 503: the model is not loaded yet (or failed to load).
        HTTPException 500: inference raised an unexpected error.
    """
    if nlu_pipeline is None:
        raise HTTPException(status_code=503, detail="Model not initialized.")
    try:
        # Pipeline returns a list of dicts: [{'label': 'intent_name', 'score': 0.99}];
        # top_k=1 keeps only the best-scoring label. Keep the try body minimal
        # so only inference errors are mapped to 500.
        prediction = nlu_pipeline(request.text, top_k=1)[0]
    except Exception as exc:
        # Log the full traceback server-side (print() dropped it); never leak
        # internal details to the client.
        logging.getLogger(__name__).exception("Inference error")
        raise HTTPException(status_code=500, detail="Internal processing error") from exc
    return PredictionResponse(
        intent=prediction["label"],
        confidence=prediction["score"],
    )
# Local/dev entrypoint; in production the container can run this directly.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)  # 7860 is the default port for HF Spaces