Spaces:
Configuration error
Configuration error
File size: 2,949 Bytes
c7cce82 ed44705 87cc891 ed44705 8c3edd8 c7cce82 a95c51f 8c3edd8 c7cce82 8c3edd8 87cc891 8c3edd8 ed44705 87cc891 c7cce82 87cc891 c7cce82 87cc891 c7cce82 87cc891 c7cce82 87cc891 c7cce82 87cc891 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 | from unittest import result
from slowapi import Limiter
from slowapi.util import get_remote_address
from slowapi.errors import RateLimitExceeded
from fastapi.responses import JSONResponse
from starlette.requests import Request
from fastapi import FastAPI, UploadFile, File
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from PIL import Image
import io
import time
from model.inference import predict
# Seconds without a confident prediction after which the remembered emotion
# is reset to "nothing known".
DECAY_SECONDS = 2.0
# time.time() of the most recent confident prediction; None until one occurs.
LAST_CONFIDENT_TS = None

app = FastAPI()

# Smoothed emotion reported to callers. This is module-level, process-wide
# state: every client of this server reads and updates the same dict.
CURRENT_STATE = dict(emotion=None, confidence=0.0)
# Permit browser front-ends hosted on any origin to call this API.
# NOTE(review): FastAPI's CORS docs say wildcard origins/methods/headers should
# not be combined with allow_credentials=True (Starlette then echoes the
# caller's Origin for credentialed requests). None of the endpoints in this
# file read cookies or auth headers, so allow_credentials=False is probably
# safe. Confirm the front-end does not use `credentials: "include"` before
# changing it.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)
@app.get("/")
def root():
    """Liveness endpoint: report that the API process is serving requests."""
    status_payload = {"status": "API running"}
    return status_payload
# An earlier, unthrottled POST /api/predict handler used to be registered here.
# It was removed because Starlette matches routes in registration order: that
# duplicate route answered every request, so the rate-limited, smoothed
# `predict_emotion` endpoint registered later in this file was never reached.
# (It also nested predict()'s result dict inside another {"emotion": ...}.)
# Per-client rate limiter, published on app.state as slowapi expects.
# NOTE(review): get_remote_address keys on the direct peer address. If this
# runs behind a reverse proxy without trusted forwarded headers, all users
# may share a single rate-limit bucket. Verify the deployment.
app.state.limiter = limiter = Limiter(key_func=get_remote_address)
@app.exception_handler(RateLimitExceeded)
def rate_limit_handler(request: Request, exc: RateLimitExceeded):
    """Convert slowapi's RateLimitExceeded into an HTTP 429 JSON response."""
    payload = {"detail": "Too many requests, slow down 😅"}
    return JSONResponse(content=payload, status_code=429)
# Sliding window of the most recent confident emotion labels, used for
# majority-vote smoothing. It must be bound at module level before the
# endpoint runs. Binding it inside the function would make it a local name,
# and the first append would fail.
RECENT_PREDICTIONS = []
# Maximum number of confident labels kept in RECENT_PREDICTIONS.
WINDOW_SIZE = 5
# Frames whose predicted confidence is below this value are "uncertain".
CONFIDENCE_THRESHOLD = 0.6


def _expire_stale_state(now):
    """Forget the remembered emotion when no confident frame arrived recently.

    If more than DECAY_SECONDS have passed since LAST_CONFIDENT_TS, reset
    CURRENT_STATE and empty the voting window, so stale labels neither get
    echoed back nor outvote fresh ones.
    """
    if LAST_CONFIDENT_TS is None:
        return
    if now - LAST_CONFIDENT_TS > DECAY_SECONDS:
        CURRENT_STATE["emotion"] = None
        CURRENT_STATE["confidence"] = 0.0
        RECENT_PREDICTIONS.clear()


@app.post("/api/predict")
@limiter.limit("5/minute")
async def predict_emotion(request: Request, file: UploadFile = File(...)):
    """Classify an uploaded image and return a temporally smoothed emotion.

    ``request`` is not read here. It is required by slowapi's ``limit``
    decorator to identify the caller.

    Returns one of:

    * ``{"state": "error", "reason": "invalid_image"}`` if PIL cannot decode
      the upload;
    * ``{"state": "uncertain", ...}`` if this frame's confidence is below
      CONFIDENCE_THRESHOLD. The last remembered emotion (if not yet decayed)
      is echoed back, alongside this frame's confidence;
    * ``{"state": "stable" | "unknown", ...}`` otherwise. Here ``emotion`` is
      the majority vote over the last WINDOW_SIZE confident frames.
    """
    # Without this declaration, the assignment at the end would make
    # LAST_CONFIDENT_TS local to this function and the module value would
    # never be updated.
    global LAST_CONFIDENT_TS

    contents = await file.read()
    try:
        image = Image.open(io.BytesIO(contents)).convert("RGB")
    except Exception:
        # Request boundary: any decode failure becomes an error payload.
        return {"state": "error", "reason": "invalid_image"}

    # NOTE(review): predict() runs synchronously inside an async handler and
    # blocks the event loop while inference runs.
    # assumes predict() returns a dict with "emotion" and "confidence" keys
    result = predict(image)
    now = time.time()

    # Decay is checked on every request. Previously it ran only right after
    # LAST_CONFIDENT_TS had been refreshed, so it could never trigger and
    # uncertain responses echoed a stale emotion indefinitely.
    _expire_stale_state(now)

    if result["confidence"] < CONFIDENCE_THRESHOLD:
        return {
            "state": "uncertain",
            "emotion": CURRENT_STATE["emotion"],
            "confidence": result["confidence"],
            "is_confident": False,
        }

    RECENT_PREDICTIONS.append(result["emotion"])
    if len(RECENT_PREDICTIONS) > WINDOW_SIZE:
        del RECENT_PREDICTIONS[0]  # drop the oldest label

    # Majority vote. max() over the list (not a set) returns the first
    # maximal item, so ties resolve deterministically to the oldest label.
    CURRENT_STATE["emotion"] = max(RECENT_PREDICTIONS, key=RECENT_PREDICTIONS.count)
    CURRENT_STATE["confidence"] = result["confidence"]
    LAST_CONFIDENT_TS = now

    # NOTE(review): CURRENT_STATE and RECENT_PREDICTIONS are shared by all
    # clients, so frames from different users are smoothed together.
    return {
        "state": "stable" if CURRENT_STATE["emotion"] else "unknown",
        "emotion": CURRENT_STATE["emotion"],
        "confidence": CURRENT_STATE["confidence"],
        "is_confident": CURRENT_STATE["emotion"] is not None,
    }
|