n0v33n committed on
Commit
c80fed4
·
1 Parent(s): 013a29f

Add FastAPI Docker app and model files

Browse files
Files changed (5) hide show
  1. Dockerfile +30 -0
  2. app.py +94 -0
  3. metadata.json +37 -0
  4. mobilenetv3_gender_weights.pth +3 -0
  5. requirements.txt +6 -0
Dockerfile ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
FROM python:3.10-slim

# Don't write .pyc files into the image; flush stdout/stderr immediately
# so container logs appear in real time.
ENV PYTHONDONTWRITEBYTECODE=1 \
    PYTHONUNBUFFERED=1

# System packages. git is installed here presumably for pip VCS installs —
# none appear in requirements.txt; confirm before removing. Clearing the
# apt cache in the same layer keeps the image small.
RUN apt-get update && apt-get install -y \
    git \
    && rm -rf /var/lib/apt/lists/*

WORKDIR /app

# Install Python dependencies before copying the app so this expensive
# layer stays cached while the application code changes.
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Application code and model artifacts.
COPY app.py .
COPY mobilenetv3_gender_weights.pth .
COPY metadata.json .

# Hugging Face Spaces expects the service on port 7860.
EXPOSE 7860

# Launch the FastAPI app defined in app.py.
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
app.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import timm
4
+ import json
5
+ import io
6
+
7
+ from fastapi import FastAPI, File, UploadFile
8
+ from PIL import Image
9
+ from torchvision import transforms
10
+
11
# --------------------------------------------------
# Load model ONCE at startup
# --------------------------------------------------

# Inference device: GPU when available, otherwise CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Artifact files copied next to this module (see the Dockerfile COPY steps).
WEIGHTS_PATH = "mobilenetv3_gender_weights.pth"
METADATA_PATH = "metadata.json"

# Load metadata (model name, class names, classifier layout, ...).
with open(METADATA_PATH, "r") as f:
    metadata = json.load(f)

# Build the timm backbone. pretrained=False because the trained weights
# are loaded explicitly below — no need to download ImageNet weights.
model = timm.create_model(
    metadata["model_name"],
    pretrained=False,
    num_classes=metadata["num_classes"]
)

# Rebuild the custom classifier head (Linear -> ReLU -> Dropout -> Linear)
# so its shape matches the saved state dict. in_features presumably equals
# the backbone's feature dimension — recorded in metadata.json as 1280.
config = metadata["classifier_config"]
model.classifier = nn.Sequential(
    nn.Linear(config["in_features"], config["hidden_dim"]),
    nn.ReLU(),
    nn.Dropout(config["dropout"]),
    nn.Linear(config["hidden_dim"], metadata["num_classes"])
)

# Load weights safely: weights_only=True restricts torch.load to tensors,
# refusing arbitrary pickled code.
state_dict = torch.load(WEIGHTS_PATH, map_location=DEVICE, weights_only=True)
model.load_state_dict(state_dict)
model.to(DEVICE)
model.eval()  # inference mode: disables dropout

# Image preprocessing: resize to 224x224 and normalize with the standard
# ImageNet mean/std (assumes the model was trained with this pipeline —
# TODO confirm against the training code).
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])
55
+
56
# --------------------------------------------------
# FastAPI app
# --------------------------------------------------

# Module-level ASGI application; uvicorn serves it as "app:app"
# (see the Dockerfile CMD).
app = FastAPI(
    title="Gender Classification API",
    description="MobileNetV3 Gender Prediction",
    version="1.0"
)
65
+
66
@app.get("/")
def root():
    """Liveness/info endpoint.

    Reports that the service is up, plus the model name and class labels
    taken from the metadata loaded at startup.
    """
    payload = {
        "message": "Gender Classification API is running 🚀",
        "model": metadata["model_name"],
        "classes": metadata["class_names"],
    }
    return payload
73
+
74
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify the gender in an uploaded image.

    Reads the uploaded file, decodes it as an RGB image, runs it through
    the model loaded at startup, and returns the predicted class with
    per-class probabilities (percentages rounded to 2 decimals).
    """
    image_bytes = await file.read()
    image = Image.open(io.BytesIO(image_bytes)).convert("RGB")

    # Add a batch dimension and move to the inference device.
    image_tensor = transform(image).unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        outputs = model(image_tensor)
        probs = torch.softmax(outputs, dim=1)

    confidence, predicted = torch.max(probs, 1)

    class_names = metadata["class_names"]
    return {
        "predicted_class": class_names[predicted.item()],
        "confidence": round(confidence.item() * 100, 2),
        # Generalized: iterate over all classes from metadata instead of
        # hardcoding indices 0 and 1 — identical output for the current
        # two-class model, but stays correct if num_classes changes.
        "probabilities": {
            name: round(probs[0][i].item() * 100, 2)
            for i, name in enumerate(class_names)
        },
    }
metadata.json ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model_name": "mobilenetv3_large_100",
3
+ "num_classes": 2,
4
+ "class_names": [
5
+ "Female",
6
+ "Male"
7
+ ],
8
+ "accuracy": 0.9285714285714286,
9
+ "confusion_matrix": [
10
+ [
11
+ 104,
12
+ 7
13
+ ],
14
+ [
15
+ 9,
16
+ 104
17
+ ]
18
+ ],
19
+ "epochs_trained": 10,
20
+ "loss_history": [
21
+ 0.6070853605352599,
22
+ 0.4011639428549799,
23
+ 0.3184526763085661,
24
+ 0.27616333987178476,
25
+ 0.24627566992722708,
26
+ 0.21110220202084246,
27
+ 0.13941147615169658,
28
+ 0.1258032820348082,
29
+ 0.13784673257634558,
30
+ 0.12205909436632847
31
+ ],
32
+ "classifier_config": {
33
+ "in_features": 1280,
34
+ "hidden_dim": 128,
35
+ "dropout": 0.4
36
+ }
37
+ }
mobilenetv3_gender_weights.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c039486323ad7c626f511d63e854cbc632804070448a721930016608b934af1b
3
+ size 17676312
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ fastapi
2
+ uvicorn
3
+ torch
4
+ torchvision
5
+ timm
6
+ pillow