Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, HTTPException | |
| from pydantic import BaseModel | |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
| import torch | |
| import uvicorn | |
# The ASGI application object; title/description/version populate the
# generated OpenAPI schema and docs page.
app = FastAPI(
    version="1.0.0",
    title="MAKINI API",
    description="Gender bias detection for Swahili and African French",
)
# Hugging Face Hub id of the fine-tuned sequence-classification checkpoint.
MODEL_ID = "Daudipdg/makini-v1"
# Language codes accepted by predict(); anything else is rejected with HTTP 400.
SUPPORTED_LANGUAGES = set(["sw", "fr"])
# Class names in model-output order: predict() reads index i as LABELS[i].
LABELS = [
    "neutral",
    "stereotype",
    "counter-stereotype",
    "derogation",
]
# Load tokenizer and classifier once at import time; every request reuses them.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID)
# Evaluation mode: disables training-only behaviour such as dropout.
model.eval()

# predict() interprets output index i as LABELS[i]. If the checkpoint's
# classification head has a different number of classes, every request
# would fail with an IndexError, so refuse to start instead.
# NOTE(review): this checks the count only, not that the order of
# model.config.id2label matches LABELS; confirm the order against the training config.
if model.config.num_labels != len(LABELS):
    raise RuntimeError(
        f"Model {MODEL_ID!r} has {model.config.num_labels} output classes, "
        f"but LABELS defines {len(LABELS)}: {LABELS}"
    )
class InferenceRequest(BaseModel):
    """Body of a classification request: the passage and its language code."""

    text: str  # passage to classify; predict() rejects blank/whitespace-only text
    language: str  # language code; predict() checks it against SUPPORTED_LANGUAGES
class InferenceResponse(BaseModel):
    """Outcome of one classification as produced by predict()."""

    label: str  # the entry of LABELS with the highest probability
    confidence: float  # probability of `label`, rounded to 4 decimals
    scores: dict  # every name in LABELS mapped to its rounded probability
    language: str  # language code echoed back from the request
# Without a route decorator FastAPI never registers this handler.
# NOTE(review): path "/" chosen from the handler name; confirm against clients.
@app.get("/")
def root():
    """Return static service metadata.

    Returns:
        dict with the model name, owning company, contact address, the
        accepted language codes and the list of output labels.
    """
    return {
        "model": "MAKINI v1",
        "company": "Iroh Intelligence Labs",
        "contact": "david@makini.tech",
        # Derived from SUPPORTED_LANGUAGES (the set predict() validates
        # against) so this listing cannot drift; sorted() gives a stable order.
        "supported_languages": sorted(SUPPORTED_LANGUAGES),
        "labels": LABELS,
    }
# Without a route decorator FastAPI never registers this handler.
# NOTE(review): path "/health" chosen from the handler name; confirm against
# whatever probes this service.
@app.get("/health")
def health():
    """Report service status.

    Returns:
        dict with ``status`` (always ``"ok"``; the handler performs no
        checks) and the id of the loaded model.
    """
    return {"status": "ok", "model": MODEL_ID}
# Without a route decorator FastAPI never registers this handler.
# NOTE(review): path "/predict" chosen from the handler name; confirm against clients.
@app.post("/predict", response_model=InferenceResponse)
def predict(request: InferenceRequest):
    """Classify a passage for gender bias.

    Args:
        request: The passage to classify and its language code.

    Returns:
        InferenceResponse containing the top label, its probability and the
        probability of every label in LABELS (all rounded to 4 decimals),
        plus the echoed language code.

    Raises:
        HTTPException: 400 if ``request.language`` is not in
            SUPPORTED_LANGUAGES, or if ``request.text`` is empty or
            whitespace-only.
    """
    if request.language not in SUPPORTED_LANGUAGES:
        # Formatting the set directly prints its members in hash-dependent
        # order; sorted() makes the message stable across processes.
        raise HTTPException(
            status_code=400,
            detail=f"Unsupported language. Supported: {sorted(SUPPORTED_LANGUAGES)}",
        )
    if not request.text.strip():
        raise HTTPException(status_code=400, detail="Text cannot be empty")

    # Passages longer than 128 tokens are truncated, not rejected.
    inputs = tokenizer(
        request.text,
        return_tensors="pt",
        truncation=True,
        max_length=128,
    )
    with torch.no_grad():
        outputs = model(**inputs)

    # Single-item batch: row 0 holds the class probabilities.
    probs = torch.softmax(outputs.logits, dim=-1)[0]
    top_idx = probs.argmax().item()
    # Convert to Python floats once rather than calling .item() per label.
    prob_list = probs.tolist()
    # Indexing (not zip) keeps the loud IndexError if the model emits fewer
    # classes than LABELS, instead of silently returning partial scores.
    scores = {label: round(prob_list[i], 4) for i, label in enumerate(LABELS)}

    return InferenceResponse(
        label=LABELS[top_idx],
        confidence=round(prob_list[top_idx], 4),
        scores=scores,
        language=request.language,
    )
if __name__ == "__main__":
    # Serve on all interfaces, port 7860 (presumably the Hugging Face Spaces
    # convention, given where this file is hosted — confirm).
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=7860,
    )