faststager committed on
Commit 749b427 · verified · 1 Parent(s): 257b53f

Create main.py

Files changed (1)
  1. main.py +175 -0
main.py ADDED
@@ -0,0 +1,175 @@
+ import torch
+ import torch.nn as nn
+ import yaml
+ from torchvision import models, transforms
+ from PIL import Image
+ import gradio as gr
+ import os
+ import base64
+ import io
+ import time
+ import threading
+ from typing import List, Dict, Union, Tuple, Optional
+
+ from fastapi import FastAPI, HTTPException
+ from fastapi.middleware.cors import CORSMiddleware
+ from pydantic import BaseModel
+
+ # Request schemas for the JSON endpoints.
+ class Base64Image(BaseModel):
+     image_data: str
+
+ class BatchBase64Images(BaseModel):
+     image_data_list: List[str]
+     use_gpu: bool = True
+
+ CONFIG_PATH: str = os.getenv('CONFIG_PATH', 'staging_config.yaml')
+ CHECKPOINT_FILENAME: str = os.getenv('CHECKPOINT_PATH', 'model.pt')
+
+ # Serializes access to the models so concurrent requests don't interleave.
+ model_lock: threading.Lock = threading.Lock()
+
+ # Builds the architecture; only efficientnet_b0 is currently supported.
+ def get_model(model_name: str, num_classes: int) -> nn.Module:
+     model: Optional[nn.Module] = None
+     if model_name == "efficientnet_b0":
+         model = models.efficientnet_b0(weights=None)
+         num_ftrs: int = model.classifier[1].in_features
+         model.classifier[1] = nn.Linear(num_ftrs, num_classes)
+     else:
+         raise ValueError(f"Model '{model_name}' not supported.")
+     return model
+
+ # Restores a model from a checkpoint holding 'model_name', 'class_to_idx',
+ # and 'state_dict'. Binary tasks use a single-logit head.
+ def load_checkpoint(checkpoint_path: str, device: torch.device) -> Tuple[nn.Module, Dict[int, str]]:
+     if not os.path.exists(checkpoint_path):
+         raise FileNotFoundError(f"Checkpoint file not found at: {checkpoint_path}")
+     checkpoint: dict = torch.load(checkpoint_path, map_location=device)
+     model_name: str = checkpoint['model_name']
+     class_to_idx: Dict[str, int] = checkpoint['class_to_idx']
+     model_output_size: int = 1 if len(class_to_idx) == 2 else len(class_to_idx)
+     model: nn.Module = get_model(model_name, num_classes=model_output_size)
+     model.load_state_dict(checkpoint['state_dict'])
+     model.to(device)
+     model.eval()
+     idx_to_class: Dict[int, str] = {v: k for k, v in class_to_idx.items()}
+     return model, idx_to_class
+
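The checkpoint layout is assumed rather than shown in this commit. As a rough sketch, a compatible model.pt for a binary classifier could be produced on the training side like this; the class names here are made-up placeholders, so substitute the mapping your own training script uses:

    import torch
    import torch.nn as nn
    from torchvision import models

    # Hypothetical training-side save; only the three keys read by
    # load_checkpoint() above are required.
    model = models.efficientnet_b0(weights=None)
    model.classifier[1] = nn.Linear(model.classifier[1].in_features, 1)  # 1 logit for 2 classes
    # ... fine-tune the model here ...
    torch.save({
        "model_name": "efficientnet_b0",
        "class_to_idx": {"not_staged": 0, "staged": 1},  # placeholder class names
        "state_dict": model.state_dict(),
    }, "model.pt")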
+ try:
+     with open(CONFIG_PATH, 'r') as f:
+         config: dict = yaml.safe_load(f)
+ except FileNotFoundError:
+     raise RuntimeError(f"ERROR: Config file not found at '{CONFIG_PATH}'. Make sure it's uploaded to the Space.")
+
+ # Load one copy of the model per device so each request can choose between them.
+ gpu_device: Optional[torch.device] = None
+ gpu_model: Optional[nn.Module] = None
+ if torch.cuda.is_available():
+     gpu_device = torch.device("cuda")
+     gpu_model, IDX_TO_CLASS = load_checkpoint(CHECKPOINT_FILENAME, gpu_device)
+     print(f"GPU model loaded successfully on {gpu_device}")
+ else:
+     print("No GPU available")
+
+ cpu_device: torch.device = torch.device("cpu")
+ cpu_model: nn.Module
+ IDX_TO_CLASS: Dict[int, str]
+ cpu_model, IDX_TO_CLASS = load_checkpoint(CHECKPOINT_FILENAME, cpu_device)
+ print("CPU model loaded successfully")
+ print(f"Class mapping: {IDX_TO_CLASS}")
+
+ # Preprocessing pipeline with standard ImageNet normalization.
+ IMG_SIZE: int = config['data_params']['image_size']
+ inference_transform: transforms.Compose = transforms.Compose([
+     transforms.Resize((IMG_SIZE, IMG_SIZE)),
+     transforms.ToTensor(),
+     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
+ ])
+
+ def base64_to_pil(base64_str: str) -> Image.Image:
+     try:
+         # Accept both raw base64 and data-URL strings ("data:image/...;base64,").
+         if "base64," in base64_str:
+             base64_str = base64_str.split("base64,")[1]
+         image_data: bytes = base64.b64decode(base64_str)
+         return Image.open(io.BytesIO(image_data))
+     except Exception as e:
+         raise ValueError(f"Invalid base64 string: {e}")
+
+ def predict_batch(pil_images: List[Image.Image], use_gpu: bool) -> List[Dict[str, Union[dict, float]]]:
+     # Fall back to CPU when no GPU model is available.
+     device: torch.device = gpu_device if (use_gpu and gpu_device is not None) else cpu_device
+     model: nn.Module = gpu_model if (use_gpu and gpu_model is not None) else cpu_model
+     image_tensors: List[torch.Tensor] = []
+     for img in pil_images:
+         if img.mode != "RGB":
+             img = img.convert("RGB")
+         image_tensors.append(inference_transform(img))
+     batch_tensor: torch.Tensor = torch.stack(image_tensors).to(device)
+     with model_lock, torch.no_grad():
+         start_time: float = time.time()
+         output: torch.Tensor = model(batch_tensor)
+         batch_time: float = time.time() - start_time
+     results: List[Dict[str, Union[dict, float]]] = []
+     # Single-logit binary head: sigmoid gives P(class 1); squeeze() returns
+     # a bare float for a batch of one, so normalize it to a list.
+     probs: Union[List[float], float] = torch.sigmoid(output).squeeze().tolist()
+     class_0_name: str = IDX_TO_CLASS.get(0, "Class 0")
+     class_1_name: str = IDX_TO_CLASS.get(1, "Class 1")
+     if isinstance(probs, float):
+         probs = [probs]
+     for i, prob in enumerate(probs):
+         prediction: Dict[str, float] = {
+             class_0_name: 1 - prob,
+             class_1_name: prob
+         }
+         metadata: Dict[str, Union[str, float]] = {
+             "device": "gpu" if use_gpu and gpu_device is not None else "cpu",
+             "inference_ms": batch_time * 1000 / len(pil_images),
+             "image_size": f"{pil_images[i].width}x{pil_images[i].height}"
+         }
+         results.append({
+             "prediction": prediction,
+             "metadata": metadata
+         })
+     return results
+
+ app: FastAPI = FastAPI(
+     title="Image Classifier API",
+     description="A FastAPI server with a Gradio UI for image classification. Supports batch processing.",
+ )
+
+ # Permissive CORS so the API can be called from any origin.
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"],
+     allow_credentials=True,
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )
+
+ @app.post("/predict", response_model=dict)
+ async def predict_api(request: Base64Image, use_gpu: bool = True) -> dict:
+     try:
+         pil_image: Image.Image = base64_to_pil(request.image_data)
+         result: Dict[str, Union[dict, float]] = predict_batch([pil_image], use_gpu)[0]
+         return result
+     except Exception as e:
+         raise HTTPException(status_code=400, detail=str(e))
+
+ @app.post("/batch_predict", response_model=List[dict])
+ async def batch_predict_api(request: BatchBase64Images) -> List[dict]:
+     try:
+         pil_images: List[Image.Image] = []
+         for base64_str in request.image_data_list:
+             pil_images.append(base64_to_pil(base64_str))
+         results: List[Dict[str, Union[dict, float]]] = predict_batch(pil_images, request.use_gpu)
+         return results
+     except Exception as e:
+         raise HTTPException(status_code=400, detail=str(e))
+
+ def predict_from_pil(pil_image: Image.Image) -> Optional[dict]:
+     if pil_image is None:
+         return None
+     result: Dict[str, Union[dict, float]] = predict_batch([pil_image], use_gpu=True)[0]
+     return result["prediction"]
+
+ # Gradio UI, mounted at the root path of the FastAPI app.
+ gradio_iface: gr.Interface = gr.Interface(
+     fn=predict_from_pil,
+     inputs=gr.Image(type="pil", label="Upload Image"),
+     outputs=gr.Label(num_top_classes=2, label="Predictions"),
+     title="Image Classifier",
+     description="Upload an image to see its classification. The API is available at the /docs endpoint.",
+     allow_flagging="never"
+ )
+
+ app = gr.mount_gradio_app(app, gradio_iface, path="/")
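Once the Space is running, the JSON endpoints can be exercised with any HTTP client. A minimal sketch using the requests library, assuming a hypothetical base URL of http://localhost:7860 and a local test.jpg; adjust both for your deployment:

    import base64
    import requests

    URL = "http://localhost:7860"  # hypothetical; replace with your Space URL

    # Encode an image as base64, matching the Base64Image schema.
    with open("test.jpg", "rb") as f:
        b64 = base64.b64encode(f.read()).decode("utf-8")

    # Single image; use_gpu is a query parameter on /predict.
    r = requests.post(f"{URL}/predict", params={"use_gpu": True},
                      json={"image_data": b64})
    print(r.json())  # {"prediction": {...}, "metadata": {...}}

    # Batch; use_gpu travels in the request body on /batch_predict.
    r = requests.post(f"{URL}/batch_predict",
                      json={"image_data_list": [b64, b64], "use_gpu": False})
    print(r.json())  # one result per image

The interactive documentation FastAPI generates is also available at /docs, as noted in the Gradio description.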