botInfinity committed on
Commit
0660004
·
verified ·
1 Parent(s): aae3d84

Upload 4 files

Browse files
Files changed (4) hide show
  1. Dockerfile +19 -0
  2. complete_model_model.h5 +3 -0
  3. main.py +82 -0
  4. requirements.txt +56 -0
Dockerfile ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# Dockerfile
# Container for the FastAPI vehicle-type predictor (main.py), served by uvicorn.
FROM python:3.11-slim

WORKDIR /app

# Install system deps
# build-essential: lets pip compile any packages that lack prebuilt wheels.
# NOTE(review): libsndfile1 is an audio I/O library; this service only handles
# images — confirm it is actually needed before keeping it.
RUN apt-get update && apt-get install -y --no-install-recommends \
    build-essential \
    libsndfile1 \
    && rm -rf /var/lib/apt/lists/*

# Copy requirements first so the pip layer is cached until requirements change.
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

COPY . .

# 7860 is the port the CMD below binds to (Hugging Face Spaces convention).
EXPOSE 7860

CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860", "--workers", "1"]
complete_model_model.h5 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b2c1ab8e0fb0839b2447c57d6b9c3ae219a052bc3aceaafc3b176e899dcfe56f
3
+ size 27908672
main.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import io
3
+ import logging
4
+ from typing import Tuple
5
+
6
+ from fastapi import FastAPI, File, UploadFile, HTTPException
7
+ from fastapi.middleware.cors import CORSMiddleware
8
+ from pydantic import BaseModel
9
+ from PIL import Image
10
+ import numpy as np
11
+ import tensorflow as tf
12
+
# --- Module-level setup: logging, model constants, FastAPI app, model load ---

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("vehicle-predictor")

# Keras model file shipped next to this script (stored as a Git-LFS blob).
MODEL_FILENAME = "complete_model_model.h5"
MODEL_PATH = os.path.join(os.path.dirname(__file__), MODEL_FILENAME)
# Target resolution for model input; assumes the network was trained on
# 224x224 RGB images — TODO confirm against the training pipeline.
IMG_SIZE = (224, 224)

# Class labels, index-aligned with the model's output vector.
# NOTE(review): order must match the label encoding used at training time — confirm.
CLASS_NAMES = [
    'Ambulance', 'Bicycle', 'Boat', 'Bus', 'Car', 'Helicopter', 'Limousine',
    'Motorcycle', 'PickUp', 'Segway', 'Snowmobile', 'Tank', 'Taxi', 'Truck', 'Van'
]

app = FastAPI(title="Vehicle Type Predictor")

# Allow any origin; required for browser clients hosted on other domains.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # you can tighten this later if needed
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Load model at startup. On failure the app stays up with model = None so
# /predict can return 503 instead of the whole process crashing.
try:
    logger.info("🚀 Loading model...")
    model = tf.keras.models.load_model(MODEL_PATH)
    logger.info("✅ Model loaded successfully.")
except Exception as e:
    logger.exception("❌ Model load failed")
    model = None
43
+
44
+
45
class PredictionResponse(BaseModel):
    """JSON body returned by /predict: predicted class name and its score."""
    # One of CLASS_NAMES.
    label: str
    # Model score for the predicted class (presumably a softmax probability
    # in [0, 1] — confirm against the model's final layer).
    confidence: float
48
+
49
+
50
def preprocess_image_file(file_bytes: bytes) -> np.ndarray:
    """Decode raw upload bytes into a model-ready batch-of-one tensor.

    The image is forced to RGB, resized to IMG_SIZE, scaled to [0, 1], and
    given a leading batch axis, producing a (1, H, W, 3) float32 array.
    """
    picture = Image.open(io.BytesIO(file_bytes)).convert("RGB").resize(IMG_SIZE)
    pixels = np.asarray(picture, dtype="float32") / 255.0
    return pixels[np.newaxis, ...]
56
+
57
+
58
@app.post("/predict", response_model=PredictionResponse)
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded vehicle image.

    Returns:
        PredictionResponse with the top CLASS_NAMES label and its score.

    Raises:
        HTTPException 503: the model failed to load at startup.
        HTTPException 400: the upload is not an image.
        HTTPException 500: decoding or inference failed.
    """
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    # Fix: UploadFile.content_type is Optional — a request without a
    # Content-Type header made `.startswith(...)` raise AttributeError,
    # surfacing as an unexplained 500 instead of this 400.
    if not file.content_type or not file.content_type.startswith("image/"):
        raise HTTPException(status_code=400, detail="File must be an image")

    try:
        contents = await file.read()
        x = preprocess_image_file(contents)
        # verbose=0 stops Keras printing a progress bar on every request.
        # NOTE(review): predict() blocks the event loop; consider offloading
        # to a threadpool if the service sees concurrent traffic.
        preds = model.predict(x, verbose=0)
        idx = int(np.argmax(preds[0]))
        label = CLASS_NAMES[idx]
        confidence = float(preds[0][idx])
        # Lazy %-style args skip formatting when INFO logging is disabled.
        logger.info("Predicted %s (%.4f) for %s", label, confidence, file.filename)
        return PredictionResponse(label=label, confidence=confidence)
    except Exception:
        logger.exception("Prediction failed")
        raise HTTPException(status_code=500, detail="Prediction failed")
78
+
79
+
80
@app.get("/health")
def health():
    """Liveness probe: always 200, and reports whether the model loaded."""
    model_ready = model is not None
    return {"status": "ok", "model_loaded": model_ready}
requirements.txt ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ absl-py==2.3.1
2
+ annotated-types==0.7.0
3
+ anyio==4.11.0
4
+ astunparse==1.6.3
5
+ certifi==2025.10.5
6
+ charset-normalizer==3.4.4
7
+ click==8.3.0
8
+ fastapi==0.119.0
9
+ flatbuffers==25.9.23
10
+ gast==0.6.0
11
+ google-pasta==0.2.0
12
+ grpcio==1.75.1
13
+ h11==0.16.0
14
+ h5py==3.15.0
15
+ httptools==0.7.1
16
+ idna==3.11
17
+ keras==3.11.3
18
+ libclang==18.1.1
19
+ Markdown==3.9
20
+ markdown-it-py==4.0.0
21
+ MarkupSafe==3.0.3
22
+ mdurl==0.1.2
23
+ ml_dtypes==0.5.3
24
+ namex==0.1.0
25
+ numpy==2.3.3
26
+ opt_einsum==3.4.0
27
+ optree==0.17.0
28
+ packaging==25.0
29
+ pillow==11.3.0
30
+ protobuf==6.32.1
31
+ pydantic==2.12.2
32
+ pydantic_core==2.41.4
33
+ Pygments==2.19.2
34
+ python-dotenv==1.1.1
35
+ python-multipart==0.0.20
36
+ PyYAML==6.0.3
37
+ requests==2.32.5
38
+ rich==14.2.0
39
+ setuptools==80.9.0
40
+ six==1.17.0
41
+ sniffio==1.3.1
42
+ starlette==0.48.0
43
+ tensorboard==2.20.0
44
+ tensorboard-data-server==0.7.2
45
+ tensorflow==2.20.0
46
+ termcolor==3.1.0
47
+ typing-inspection==0.4.2
48
+ typing_extensions==4.15.0
49
+ urllib3==2.5.0
50
+ uvicorn==0.37.0
51
+ uvloop==0.21.0
52
+ watchfiles==1.1.1
53
+ websockets==15.0.1
54
+ Werkzeug==3.1.3
55
+ wheel==0.45.1
56
+ wrapt==1.17.3