Digambar29 commited on
Commit
87cc891
·
1 Parent(s): ed44705

Production stage 1

Browse files
model/Image.jpg ADDED
model/__pycache__/api.cpython-312.pyc ADDED
Binary file (3.76 kB). View file
 
model/__pycache__/inference.cpython-312.pyc ADDED
Binary file (2.34 kB). View file
 
model/api.py CHANGED
@@ -1,7 +1,7 @@
1
  from slowapi import Limiter
2
  from slowapi.util import get_remote_address
3
  from slowapi.errors import RateLimitExceeded
4
- from slowapi.responses import JSONResponse
5
  from starlette.requests import Request
6
  from fastapi import FastAPI, UploadFile, File
7
  from fastapi.middleware.cors import CORSMiddleware
@@ -13,6 +13,12 @@ from model.inference import predict
13
 
14
  app = FastAPI()
15
 
 
 
 
 
 
 
16
  app.add_middleware(
17
  CORSMiddleware,
18
  allow_origins=["*"],
@@ -21,8 +27,6 @@ app.add_middleware(
21
  allow_headers=["*"],
22
  )
23
 
24
-
25
-
26
  @app.get("/")
27
  def root():
28
  return {"status": "API running"}
@@ -46,5 +50,50 @@ def rate_limit_handler(request: Request, exc: RateLimitExceeded):
46
 
47
  @app.post("/api/predict")
48
  @limiter.limit("5/minute")
49
- async def predict_emotion(file: UploadFile = File(...)):
50
- ...
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from slowapi import Limiter
2
  from slowapi.util import get_remote_address
3
  from slowapi.errors import RateLimitExceeded
4
+ from fastapi.responses import JSONResponse
5
  from starlette.requests import Request
6
  from fastapi import FastAPI, UploadFile, File
7
  from fastapi.middleware.cors import CORSMiddleware
 
13
 
14
  app = FastAPI()
15
 
16
+ CURRENT_STATE = {
17
+ "emotion": None,
18
+ "confidence": 0.0
19
+ }
20
+
21
+
22
  app.add_middleware(
23
  CORSMiddleware,
24
  allow_origins=["*"],
 
27
  allow_headers=["*"],
28
  )
29
 
 
 
30
  @app.get("/")
31
  def root():
32
  return {"status": "API running"}
 
50
 
51
  @app.post("/api/predict")
52
  @limiter.limit("5/minute")
53
+ async def predict_emotion(request: Request, file: UploadFile = File(...)):
54
+ contents = await file.read()
55
+
56
+ try:
57
+ image = Image.open(io.BytesIO(contents)).convert("RGB")
58
+ except Exception:
59
+ return {
60
+ "state": "error",
61
+ "reason": "invalid_image"
62
+ }
63
+
64
+ result = predict(image)
65
+
66
+ if result["confidence"] >= 0.6:
67
+ RECENT_PREDICTIONS.append(result["emotion"])
68
+ if len(RECENT_PREDICTIONS) > WINDOW_SIZE:
69
+ RECENT_PREDICTIONS.pop(0)
70
+
71
+
72
+ if result["confidence"] < 0.6:
73
+ return {
74
+ "state": "uncertain",
75
+ "emotion": CURRENT_STATE["emotion"],
76
+ "confidence": result["confidence"],
77
+ "is_confident": False
78
+ }
79
+
80
+ # update memory
81
+ if RECENT_PREDICTIONS:
82
+ dominant_emotion = max(
83
+ set(RECENT_PREDICTIONS),
84
+ key=RECENT_PREDICTIONS.count
85
+ )
86
+
87
+ CURRENT_STATE["emotion"] = dominant_emotion
88
+ CURRENT_STATE["confidence"] = result["confidence"]
89
+
90
+ RECENT_PREDICTIONS = []
91
+ WINDOW_SIZE = 5
92
+
93
+
94
+ return {
95
+ "state": "stable",
96
+ "emotion": CURRENT_STATE["emotion"],
97
+ "confidence": CURRENT_STATE["confidence"],
98
+ "is_confident": True
99
+ }
model/inference.py CHANGED
@@ -34,9 +34,16 @@ transform = transforms.Compose([
34
  )
35
  ])
36
 
 
37
  @torch.no_grad()
38
  def predict(pil_image: Image.Image):
39
  x = transform(pil_image).unsqueeze(0).to(device)
40
  logits = model(x)
41
- idx = logits.argmax(dim=1).item()
42
- return classes[idx]
 
 
 
 
 
 
 
34
  )
35
  ])
36
 
37
+
38
  @torch.no_grad()
39
  def predict(pil_image: Image.Image):
40
  x = transform(pil_image).unsqueeze(0).to(device)
41
  logits = model(x)
42
+
43
+ probs = torch.softmax(logits, dim=1)
44
+ conf, idx = probs.max(dim=1)
45
+
46
+ return {
47
+ "emotion": classes[idx.item()],
48
+ "confidence": conf.item()
49
+ }