File size: 2,949 Bytes
c7cce82
ed44705
 
 
87cc891
ed44705
8c3edd8
 
 
 
 
c7cce82
a95c51f
8c3edd8
c7cce82
 
 
 
8c3edd8
 
87cc891
 
 
 
 
 
8c3edd8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ed44705
 
 
 
 
 
 
 
 
 
 
 
 
87cc891
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c7cce82
87cc891
 
c7cce82
87cc891
 
 
 
c7cce82
 
 
 
 
87cc891
 
c7cce82
87cc891
 
c7cce82
87cc891
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
from unittest import result
from slowapi import Limiter
from slowapi.util import get_remote_address
from slowapi.errors import RateLimitExceeded
from fastapi.responses import JSONResponse
from starlette.requests import Request
from fastapi import FastAPI, UploadFile, File
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from PIL import Image
import io
import time
from model.inference import predict

# Timestamp (time.time()) of the last prediction that cleared the confidence
# threshold; None until the first confident prediction arrives.
LAST_CONFIDENT_TS = None
# Seconds after which a stale confident reading should be forgotten.
DECAY_SECONDS = 2.0


app = FastAPI()

# Last emitted emotion state, shared across requests.
# "emotion" is None / "confidence" is 0.0 until a confident prediction lands.
CURRENT_STATE = {
    "emotion": None,
    "confidence": 0.0
}


# NOTE(review): wildcard allow_origins combined with allow_credentials=True is
# disallowed by the CORS spec; Starlette will not echo credentials with "*".
# Fine for local development — tighten origins before production. TODO confirm.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

@app.get("/")
def root():
    """Health-check endpoint: confirms the API process is running."""
    status_payload = {"status": "API running"}
    return status_payload

# NOTE(review): this handler was also registered at POST /api/predict,
# duplicating the rate-limited, stateful handler defined below. FastAPI
# matches routes in registration order, so this simpler handler shadowed the
# rate-limited one, which could never run. The route decorator is removed so
# the full handler below actually serves the endpoint; the function is kept
# as a plain coroutine helper for direct (non-HTTP) use.
async def predict_emotion(file: UploadFile = File(...)):
    """Decode an uploaded image and return the raw model prediction.

    Args:
        file: Uploaded image file; its bytes must be decodable by Pillow.

    Returns:
        dict with a single "emotion" key holding the model's prediction.
    """
    contents = await file.read()
    image = Image.open(io.BytesIO(contents)).convert("RGB")
    emotion = predict(image)
    return {"emotion": emotion}

# Rate limiter keyed on the client's remote address; slowapi requires the
# limiter to be attached to app.state for its decorators to find it.
limiter = Limiter(key_func=get_remote_address)
app.state.limiter = limiter

@app.exception_handler(RateLimitExceeded)
def rate_limit_handler(request: Request, exc: RateLimitExceeded):
    """Translate a slowapi RateLimitExceeded into a friendly 429 JSON reply."""
    payload = {"detail": "Too many requests, slow down 😅"}
    return JSONResponse(content=payload, status_code=429)

# Rolling window of recent confident predictions. These MUST live at module
# level: the original code assigned them inside the handler (after first use),
# which made them function locals — raising UnboundLocalError on any confident
# prediction — and would have reset the window on every request anyway.
RECENT_PREDICTIONS = []
WINDOW_SIZE = 5
# Minimum model confidence for a prediction to count toward the window.
CONFIDENCE_THRESHOLD = 0.6


@app.post("/api/predict")
@limiter.limit("5/minute")
async def predict_emotion(request: Request, file: UploadFile = File(...)):
    """Predict the emotion in an uploaded image, smoothed over recent calls.

    Confident predictions (>= CONFIDENCE_THRESHOLD) are pushed into a rolling
    window; the reported emotion is the window's most frequent entry. State
    older than DECAY_SECONDS is forgotten before answering.

    Args:
        request: Incoming request (required by the slowapi limiter decorator).
        file: Uploaded image; bytes must be decodable by Pillow.

    Returns:
        dict with keys "state", "emotion", "confidence", "is_confident"
        (or {"state": "error", "reason": "invalid_image"} for bad uploads).
    """
    # Rebinding the timestamp below must hit the module global; without this
    # declaration the assignment would create a dead local instead.
    global LAST_CONFIDENT_TS

    contents = await file.read()

    try:
        image = Image.open(io.BytesIO(contents)).convert("RGB")
    except Exception:
        # Best-effort decode: any Pillow failure is reported as a bad upload.
        return {
            "state": "error",
            "reason": "invalid_image"
        }

    result = predict(image)

    # Decay stale memory FIRST, so neither the "uncertain" reply below nor a
    # fresh update serves an emotion older than DECAY_SECONDS. (Originally
    # this check ran after the update, so a just-written timestamp could
    # never decay.)
    if LAST_CONFIDENT_TS is not None:
        if time.time() - LAST_CONFIDENT_TS > DECAY_SECONDS:
            CURRENT_STATE["emotion"] = None
            CURRENT_STATE["confidence"] = 0.0

    if result["confidence"] < CONFIDENCE_THRESHOLD:
        # Not confident enough: report the (possibly decayed) remembered
        # emotion along with the raw low confidence.
        return {
            "state": "uncertain",
            "emotion": CURRENT_STATE["emotion"],
            "confidence": result["confidence"],
            "is_confident": False
        }

    # Confident prediction: push into the rolling window, trimming the oldest
    # entry once the window exceeds WINDOW_SIZE.
    RECENT_PREDICTIONS.append(result["emotion"])
    if len(RECENT_PREDICTIONS) > WINDOW_SIZE:
        RECENT_PREDICTIONS.pop(0)

    # The reported emotion is the window's mode, which smooths out one-off
    # flickers between frames.
    dominant_emotion = max(
        set(RECENT_PREDICTIONS),
        key=RECENT_PREDICTIONS.count
    )
    CURRENT_STATE["emotion"] = dominant_emotion
    CURRENT_STATE["confidence"] = result["confidence"]
    LAST_CONFIDENT_TS = time.time()

    return {
        "state": "stable" if CURRENT_STATE["emotion"] else "unknown",
        "emotion": CURRENT_STATE["emotion"],
        "confidence": CURRENT_STATE["confidence"],
        "is_confident": CURRENT_STATE["emotion"] is not None
    }