Nightfury16 commited on
Commit
7011a64
·
1 Parent(s): 3579e8e

Initial commit

Browse files
Dockerfile ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Slim Python base keeps the image small; the Space runs this container directly.
FROM python:3.9-slim

# Route all library caches to /data so they land on the persistent volume.
ENV TRANSFORMERS_CACHE=/data/.cache/transformers
ENV HF_HOME=/data/.cache/huggingface
ENV MPLCONFIGDIR=/data/.cache/matplotlib

WORKDIR /code

# Install dependencies first so Docker layer caching survives code-only changes.
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
COPY . .

# Hugging Face Spaces expects the app to listen on port 7860.
EXPOSE 7860
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
README.md CHANGED
@@ -1,10 +1,10 @@
1
  ---
2
- title: Clean Vs Messy
3
- emoji: 🦀
4
- colorFrom: green
5
- colorTo: pink
6
  sdk: docker
7
- pinned: false
8
  ---
9
 
10
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: Messy vs Clean Image Classifier
3
+ emoji: 🔥
4
+ colorFrom: indigo
5
+ colorTo: green
6
  sdk: docker
7
+ app_file: main.py
8
  ---
9
 
10
+ Check out the configuration reference at <https://huggingface.co/docs/hub/spaces-config-reference>
app.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Gradio demo entry point for the messy-vs-clean image classifier."""
import os

# Cache directories must be set before transformers/matplotlib are imported so
# the libraries pick them up; /data is the Space's writable persistent volume.
os.environ['TRANSFORMERS_CACHE'] = '/data/.cache/transformers'
os.environ['HF_HOME'] = '/data/.cache/huggingface'
os.environ['MPLCONFIGDIR'] = '/data/.cache/matplotlib'

import torch
import torch.nn as nn
import yaml
from torchvision import models, transforms
from PIL import Image
import gradio as gr
from transformers import ConvNextV2ForImageClassification
from typing import Dict, Tuple

# Display name -> checkpoint path.
# NOTE(review): the commit ships checkpoints/convnext_v2_atto_best.pth, not
# convnext_v2_tiny_best.pth — this entry will be skipped at load time; confirm
# the intended checkpoint filename.
MODEL_CHECKPOINTS = {
    "convnext_tiny_best": "checkpoints/convnext_v2_tiny_best.pth",
    "efficientnet_b0": "checkpoints/effnet_b0_best.pth",
    "efficientnet_b3": "checkpoints/effnet_b3_best.pth",
    "vit_b_16": "checkpoints/vit_b_16_best.pth"
}
DEFAULT_MODEL_NAME = "vit_b_16"

# Populated at import time by the loading loop below.
# NOTE(review): values actually stored are plain nn.Module objects, not tuples —
# this annotation looks stale; verify before relying on it.
MODELS: Dict[str, Tuple[nn.Module, Dict[int, str]]] = {}
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
26
+
27
class HFConvNeXtWrapper(nn.Module):
    """Adapts a HF ConvNeXtV2 classifier so it returns raw logits like the
    torchvision models used elsewhere in this file."""

    def __init__(self, model_name, num_labels):
        super().__init__()
        # ignore_mismatched_sizes lets us replace the pretrained head with a
        # freshly sized one for our label count.
        self.model = ConvNextV2ForImageClassification.from_pretrained(
            model_name,
            num_labels=num_labels,
            ignore_mismatched_sizes=True,
        )

    def forward(self, x):
        # HF models return an output object; expose only the logits tensor.
        return self.model(x).logits
34
+
35
def get_model(model_name: str, num_classes: int) -> nn.Module:
    """Build an architecture by name with a freshly initialised head of
    `num_classes` outputs.

    Raises:
        ValueError: if `model_name` matches none of the supported families.
    """
    if model_name == "efficientnet_b0":
        net = models.efficientnet_b0(weights=None)
        net.classifier[1] = nn.Linear(net.classifier[1].in_features, num_classes)
    elif model_name == "efficientnet_b3":
        net = models.efficientnet_b3(weights=None)
        net.classifier[1] = nn.Linear(net.classifier[1].in_features, num_classes)
    elif model_name == "vit_b_16":
        net = models.vit_b_16(weights=None)
        net.heads.head = nn.Linear(net.heads.head.in_features, num_classes)
    elif "convnextv2" in model_name:
        # HF checkpoints store a hub-style name; the wrapper exposes logits.
        net = HFConvNeXtWrapper(model_name, num_labels=num_classes)
    else:
        raise ValueError(f"Model '{model_name}' not supported.")
    return net
54
+
55
def load_checkpoint(checkpoint_path: str, device: torch.device) -> Tuple[nn.Module, Dict[int, str]]:
    """Load a trained model from `checkpoint_path` onto `device` in eval mode.

    Returns:
        (model, {}) — the label map is empty because labels are hardcoded
        downstream ("clean"/"messy").
    Raises:
        FileNotFoundError: if the checkpoint file is missing.
    """
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint file not found at: {checkpoint_path}")
    # NOTE(review): torch.load unpickles arbitrary objects — only load trusted
    # checkpoints (consider weights_only=True on newer torch versions).
    checkpoint = torch.load(checkpoint_path, map_location=device)
    net = get_model(checkpoint['model_name'], num_classes=1)
    net.load_state_dict(checkpoint['state_dict'])
    net.to(device)
    net.eval()
    # The idx_to_class mapping is no longer needed; labels are hardcoded.
    return net, {}
66
+
67
print("--- Loading all models into memory ---")
for display_name, ckpt_path in MODEL_CHECKPOINTS.items():
    if os.path.exists(ckpt_path):
        model, _ = load_checkpoint(ckpt_path, DEVICE)
        MODELS[display_name] = model
        print(f"Loaded '{display_name}' on {DEVICE}.")
    else:
        print(f"WARNING: Checkpoint for '{display_name}' not found. Skipping.")

if not MODELS:
    raise RuntimeError("No models were loaded. Please check your checkpoints directory.")

# Fix: this repo ships the training config as `cm.yaml`, while the code
# originally opened `cm_config.yaml` unconditionally and crashed at startup
# with FileNotFoundError. Prefer the original name for backward compatibility,
# then fall back to the file actually present in the repo.
_config_file = 'cm_config.yaml' if os.path.exists('cm_config.yaml') else 'cm.yaml'
with open(_config_file, 'r') as f:
    config = yaml.safe_load(f)

IMG_SIZE = config['data_params']['image_size']
# Standard ImageNet preprocessing; must mirror the transforms used in training.
inference_transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
87
+
88
def predict(pil_image, model_name: str):
    """Classify one PIL image with the selected model.

    Returns:
        A {label: probability} dict for gradio's Label output, or None when
        no image was provided.
    """
    if pil_image is None:
        return None

    chosen = MODELS[model_name]
    tensor = inference_transform(pil_image.convert("RGB")).unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        messy_prob = torch.sigmoid(chosen(tensor)).item()

    # Single-logit head: sigmoid gives P(messy); class 0 = clean, class 1 = messy.
    return {"clean": 1 - messy_prob, "messy": messy_prob}
101
+
102
# Gradio UI: image + model picker in, a two-class probability label out.
model_dropdown = gr.Dropdown(
    choices=list(MODELS.keys()),
    value=DEFAULT_MODEL_NAME,
    label="Select Model",
)

iface = gr.Interface(
    fn=predict,
    inputs=[gr.Image(type="pil", label="Upload Image"), model_dropdown],
    outputs=gr.Label(num_top_classes=2, label="Predictions"),
    title="Messy vs Clean Image Classifier",
    description="Upload an image and select a model to see its classification for 'messy' vs 'clean'.",
)

iface.launch()
checkpoints/convnext_v2_atto_best.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d04e828a64aa572a9b9ef741d8a083bf89be2e669a065cca8d3e49f9c69c6da3
3
+ size 111553930
checkpoints/effnet_b0_best.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2b6c672c601de9710c9aa39b93cce5fd3a3332748aadb5a0d3ac878e75602ae5
3
+ size 16336022
checkpoints/effnet_b3_best.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e8012eb905fe5ea97301a1737ad5f340bcac733aa036edf737af0ed4f677cfcb
3
+ size 43350212
checkpoints/vit_b_16_best.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ad719487e643001a11294878db6d7336cacfe4d7e61b31272c27919d4b896e3b
3
+ size 343259114
cm.yaml ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Training configuration for the messy-vs-clean classifier.
# NOTE(review): app.py / main.py look for `cm_config.yaml` by default —
# confirm this file is renamed or CONFIG_PATH is set accordingly.
data_params:
  data_path: "dataset"
  image_size: 224          # square resize used for training and inference

model_params:
  name: "efficientnet_b0"
  pretrained: True

train_params:
  epochs: 25
  batch_size: 64
  optimizer: "AdamW"
  learning_rate: 0.001
  unfreeze_epoch: 5        # epoch at which the backbone is unfrozen

ddp_params:
  master_port: '12355'

output_params:
  save_dir: "runs/staging_classifier"
  checkpoint_name: "best_model.pth"
convnext_config.json ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "ConvNextV2ForImageClassification"
4
+ ],
5
+ "depths": [
6
+ 3,
7
+ 3,
8
+ 9,
9
+ 3
10
+ ],
11
+ "drop_path_rate": 0.1,
12
+ "hidden_act": "gelu",
13
+ "hidden_sizes": [
14
+ 96,
15
+ 192,
16
+ 384,
17
+ 768
18
+ ],
19
+ "image_size": 224,
20
+ "initializer_range": 0.02,
21
+ "layer_norm_eps": 1e-06,
22
+ "model_type": "convnextv2",
23
+ "num_channels": 3,
24
+ "num_stages": 4,
25
+ "patch_size": 4,
26
+ "torch_dtype": "float32",
27
+ "transformers_version": "4.35.2"
28
+ }
main.py ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""FastAPI + Gradio service exposing the messy-vs-clean classifiers."""
import os
import json

# Cache locations must be configured before transformers/matplotlib import.
os.environ['HF_HOME'] = './hf_cache'
os.environ['MPLCONFIGDIR'] = './mpl_cache'

import torch
import torch.nn as nn
import yaml
from torchvision import models, transforms
from PIL import Image
import gradio as gr
import base64
import io
import time
import threading
from typing import List, Dict, Union, Tuple, Optional

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import ConvNextV2Config, ConvNextV2ForImageClassification

# Display name -> checkpoint path.
# NOTE(review): the commit contains checkpoints/convnext_v2_atto_best.pth, not
# convnext_v2_tiny_best.pth — this entry will never load; confirm the filename.
MODEL_CHECKPOINTS = {
    "convnext_tiny_best": "checkpoints/convnext_v2_tiny_best.pth",
    "efficientnet_b0": "checkpoints/effnet_b0_best.pth",
    "efficientnet_b3": "checkpoints/effnet_b3_best.pth",
    "vit_b_16": "checkpoints/vit_b_16_best.pth"
}
DEFAULT_MODEL_NAME = "vit_b_16"

# Architecture config for building ConvNeXtV2 locally (no hub download).
CONVNEXT_CONFIG_PATH = "convnext_config.json"

# One model instance per device so CPU and GPU requests never share weights.
GPU_MODELS: Dict[str, nn.Module] = {}
CPU_MODELS: Dict[str, nn.Module] = {}
CONFIG_PATH: str = os.getenv('CONFIG_PATH', 'cm_config.yaml')
# Serializes inference across requests; the models are not assumed thread-safe.
model_lock: threading.Lock = threading.Lock()
38
def get_model(model_name: str, num_classes: int) -> nn.Module:
    """Instantiate an architecture by name with a `num_classes`-way head.

    ConvNeXtV2 is built from the local JSON config so no network access is
    required at startup.

    Raises:
        ValueError: if `model_name` matches none of the supported families.
    """
    if model_name == "efficientnet_b0":
        net = models.efficientnet_b0(weights=None)
        net.classifier[1] = nn.Linear(net.classifier[1].in_features, num_classes)
    elif model_name == "efficientnet_b3":
        net = models.efficientnet_b3(weights=None)
        net.classifier[1] = nn.Linear(net.classifier[1].in_features, num_classes)
    elif model_name == "vit_b_16":
        net = models.vit_b_16(weights=None)
        net.heads.head = nn.Linear(net.heads.head.in_features, num_classes)
    elif "convnextv2" in model_name:
        cfg = ConvNextV2Config.from_json_file(CONVNEXT_CONFIG_PATH)
        cfg.num_labels = num_classes
        net = ConvNextV2ForImageClassification(cfg)
    else:
        raise ValueError(f"Model '{model_name}' not supported.")
    return net
59
+
60
def load_checkpoint(checkpoint_path: str, device: torch.device) -> nn.Module:
    """Load a trained model checkpoint onto `device` and return it in eval mode.

    Checkpoints saved from a wrapper module may have every key prefixed with
    "model."; that prefix is stripped before loading.

    Raises:
        FileNotFoundError: if the checkpoint file is missing.
    """
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint file not found at: {checkpoint_path}")
    # NOTE(review): torch.load unpickles arbitrary objects — only load trusted
    # checkpoint files.
    checkpoint: dict = torch.load(checkpoint_path, map_location=device)
    model_name_from_ckpt: str = checkpoint['model_name']
    state_dict = checkpoint['state_dict']

    if any(key.startswith("model.") for key in state_dict):
        print(f" > Unwrapping state_dict for {model_name_from_ckpt}...")
        # Fix: strip only the LEADING "model." prefix. The original used
        # str.replace, which also mangled any key containing "model." later in
        # the string (e.g. "backbone.model.weight" -> "backbone.weight").
        state_dict = {k.removeprefix("model."): v for k, v in state_dict.items()}

    model: nn.Module = get_model(model_name_from_ckpt, num_classes=1)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model
76
+
77
print("--- Loading all models into memory ---")
cpu_device = torch.device("cpu")
gpu_device = torch.device("cuda") if torch.cuda.is_available() else None

for display_name, ckpt_path in MODEL_CHECKPOINTS.items():
    # Guard clause: skip missing checkpoints with a warning instead of nesting.
    if not os.path.exists(ckpt_path):
        print(f"WARNING: Checkpoint for '{display_name}' not found at {ckpt_path}. It will not be available.")
        continue
    print(f"Loading '{display_name}'...")
    try:
        # Always keep a CPU copy; add a GPU copy only when CUDA is present.
        CPU_MODELS[display_name] = load_checkpoint(ckpt_path, cpu_device)
        print(f" > Loaded '{display_name}' for CPU.")
        if gpu_device:
            GPU_MODELS[display_name] = load_checkpoint(ckpt_path, gpu_device)
            print(f" > Loaded '{display_name}' for GPU.")
    except Exception as e:
        # A single bad checkpoint should not take down the whole service.
        print(f" > FAILED to load '{display_name}'. Error: {e}")

if not CPU_MODELS:
    raise RuntimeError("No models were loaded. Please check the `checkpoints` directory.")
99
+
100
# Fix: the repo ships the training config as `cm.yaml`, but CONFIG_PATH
# defaults to `cm_config.yaml`, so startup failed with FileNotFoundError.
# Try the configured path first (backward compatible), then fall back to the
# file actually present in the repo.
config: dict = {}
for _path in (CONFIG_PATH, 'cm.yaml'):
    if os.path.exists(_path):
        with open(_path, 'r') as f:
            config = yaml.safe_load(f)
        break
else:
    raise RuntimeError(f"ERROR: Config file not found at '{CONFIG_PATH}'.")

IMG_SIZE: int = config['data_params']['image_size']
# ImageNet-standard preprocessing; must mirror the training-time transforms.
inference_transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
111
+
112
def base64_to_pil(base64_str: str) -> Image.Image:
    """Decode a base64 string (optionally a full data-URL) into a PIL image.

    Raises:
        ValueError: when the payload is not valid base64 / decodable image data.
    """
    try:
        # Accept full data URLs such as "data:image/png;base64,....".
        if "base64," in base64_str:
            base64_str = base64_str.split("base64,")[1]
        raw: bytes = base64.b64decode(base64_str)
        return Image.open(io.BytesIO(raw))
    except Exception as e:
        raise ValueError(f"Invalid base64 string: {e}")
119
+
120
class Base64Image(BaseModel):
    # Single base64-encoded image payload (optionally a data-URL).
    image_data: str


class BatchBase64Images(BaseModel):
    # Base64 payloads to classify in one batched forward pass.
    image_data_list: List[str]
    model_name: str = DEFAULT_MODEL_NAME
    use_gpu: bool = True
125
+
126
def predict_batch(pil_images: List[Image.Image], use_gpu: bool, model_name: str) -> List[Dict[str, Union[dict, float]]]:
    """Run one batched forward pass and return per-image predictions.

    Each result carries the clean/messy probabilities plus device/timing
    metadata.

    Raises:
        ValueError: when the requested model is not loaded for the target device.
    """
    on_gpu = bool(use_gpu and gpu_device)
    pool = GPU_MODELS if on_gpu else CPU_MODELS
    if model_name not in pool:
        raise ValueError(f"Model '{model_name}' not loaded or not available. Available: {list(pool.keys())}")

    net = pool[model_name]
    device = gpu_device if on_gpu else cpu_device

    batch = torch.stack(
        [inference_transform(img.convert("RGB")) for img in pil_images]
    ).to(device)

    # Models are not assumed thread-safe, so inference is serialized.
    with model_lock, torch.no_grad():
        started = time.time()
        raw_out = net(batch)
        elapsed = time.time() - started

    # HF models wrap logits in an output object; torchvision returns a tensor.
    logits = raw_out.logits if hasattr(raw_out, 'logits') else raw_out

    probs = torch.sigmoid(logits).squeeze().tolist()
    if not isinstance(probs, list):
        probs = [probs]  # a single image squeezes down to a bare scalar

    per_image_ms = (elapsed * 1000) / len(pil_images)
    return [
        {
            "prediction": {"clean": 1 - p, "messy": p},
            "metadata": {"device": str(device), "inference_ms": per_image_ms},
        }
        for p in probs
    ]
157
+
158
# REST surface; CORS is wide open because this is a public demo Space.
app = FastAPI(title="Messy vs Clean Image Classifier API")
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.post("/predict", response_model=dict)
async def predict_api(request: Base64Image, model_name: str = DEFAULT_MODEL_NAME, use_gpu: bool = True):
    """Classify a single base64-encoded image."""
    try:
        image = base64_to_pil(request.image_data)
        return predict_batch([image], use_gpu, model_name)[0]
    except Exception as e:
        # Surface decode/model errors to the client as a 400.
        raise HTTPException(status_code=400, detail=str(e))


@app.post("/batch_predict", response_model=List[dict])
async def batch_predict_api(request: BatchBase64Images):
    """Classify a list of base64-encoded images in one batch."""
    try:
        images = [base64_to_pil(b64) for b64 in request.image_data_list]
        return predict_batch(images, request.use_gpu, request.model_name)
    except Exception as e:
        raise HTTPException(status_code=400, detail=str(e))


@app.get("/models", response_model=List[str])
async def get_available_models():
    """List the model names that loaded successfully."""
    return list(CPU_MODELS.keys())
180
+
181
def predict_gradio(pil_image: Image.Image, model_name: str) -> Optional[dict]:
    """Gradio callback: classify one image, preferring the GPU copy when present."""
    if pil_image is None:
        return None
    # Reuse the batched path with a batch of one and surface only the scores.
    return predict_batch([pil_image], use_gpu=True, model_name=model_name)[0]["prediction"]
185
+
186
# Gradio UI mounted at "/" on top of the FastAPI app.
gradio_iface = gr.Interface(
    fn=predict_gradio,
    inputs=[
        gr.Image(type="pil", label="Input Image", sources=["upload", "webcam", "clipboard"]),
        gr.Dropdown(
            choices=list(CPU_MODELS.keys()),
            value=DEFAULT_MODEL_NAME,
            label="Select Model",
        ),
    ],
    outputs=gr.Label(num_top_classes=2, label="Predictions"),
    title="Messy vs Clean Image Classifier",
    description="Upload an image and select a model to see its classification for 'messy' vs 'clean'. The API is available at the /docs endpoint.",
    allow_flagging="never",
)

app = gr.mount_gradio_app(app, gradio_iface, path="/")
requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ fastapi==0.104.1
4
+ uvicorn==0.24.0
5
+ gradio==3.50.2
6
+ gradio-client==0.6.1
7
+ PyYAML==6.0.1
8
+ python-multipart==0.0.6
9
+ pydantic==2.5.2
10
+ transformers