Nightfury16 committed on
Commit 8317439 · 1 Parent(s): 15f0b59

Initial commit

.gitattributes ADDED
@@ -0,0 +1,36 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
+ checkpoints/*.pth filter=lfs diff=lfs merge=lfs -text
Dockerfile ADDED
@@ -0,0 +1,18 @@
+ FROM python:3.9-slim
+
+ ENV TRANSFORMERS_CACHE=/data/.cache/transformers
+ ENV HF_HOME=/data/.cache/huggingface
+ ENV MPLCONFIGDIR=/data/.cache/matplotlib
+
+ WORKDIR /code
+
+ COPY requirements.txt .
+ RUN pip install --no-cache-dir -r requirements.txt
+
+ COPY . .
+
+ RUN mkdir -p checkpoints
+
+ EXPOSE 7860
+
+ CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
README.md ADDED
@@ -0,0 +1,9 @@
+ ---
+ title: Room Type Classifier
+ emoji: 🏠
+ colorFrom: blue
+ colorTo: orange
+ sdk: docker
+ app_file: main.py
+ ---
+
app.py ADDED
@@ -0,0 +1,81 @@
+ import os
+ import torch
+ import torch.nn as nn
+ import yaml
+ from torchvision import models, transforms
+ from PIL import Image
+ import gradio as gr
+ from transformers import ConvNextV2ForImageClassification
+
+ CHECKPOINT_PATH = "checkpoints/room_classifier_best.pth"
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ class HFConvNeXtWrapper(nn.Module):
+     def __init__(self, model_name, num_labels):
+         super(HFConvNeXtWrapper, self).__init__()
+         self.model = ConvNextV2ForImageClassification.from_pretrained(
+             model_name, num_labels=num_labels, ignore_mismatched_sizes=True)
+     def forward(self, x):
+         return self.model(x).logits
+
+ def get_model(model_name, num_classes):
+     if model_name.startswith("efficientnet"):
+         model = models.efficientnet_b0(weights=None) if "b0" in model_name else models.efficientnet_b3(weights=None)
+         num_ftrs = model.classifier[1].in_features
+         model.classifier[1] = nn.Linear(num_ftrs, num_classes)
+     elif "convnextv2" in model_name:
+         model = HFConvNeXtWrapper(model_name, num_labels=num_classes)
+     elif model_name == "vit_b_16":
+         model = models.vit_b_16(weights=None)
+         model.heads.head = nn.Linear(model.heads.head.in_features, num_classes)
+     else:
+         raise ValueError(f"Unknown model: {model_name}")
+     return model
+
+ if not os.path.exists(CHECKPOINT_PATH):
+     raise FileNotFoundError(f"Checkpoint not found at {CHECKPOINT_PATH}")
+
+ print(f"Loading model from {CHECKPOINT_PATH}...")
+ checkpoint = torch.load(CHECKPOINT_PATH, map_location=DEVICE)
+ model_name = checkpoint['model_name']
+ num_classes = checkpoint.get('num_classes', 5)
+
+ class_to_idx = checkpoint.get('class_to_idx', None)
+ if class_to_idx:
+     idx_to_class = {v: k for k, v in class_to_idx.items()}
+ else:
+     print("Warning: class_to_idx not found in checkpoint. Using default 5 classes.")
+     idx_to_class = {0: 'Bathroom', 1: 'Bedroom', 2: 'Dining', 3: 'Kitchen', 4: 'Living'}
+
+ model = get_model(model_name, num_classes)
+ model.load_state_dict(checkpoint['state_dict'])
+ model.to(DEVICE)
+ model.eval()
+
+ inference_transform = transforms.Compose([
+     transforms.Resize((224, 224)),
+     transforms.ToTensor(),
+     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
+ ])
+
+ def predict(pil_image):
+     if pil_image is None: return None
+     pil_image = pil_image.convert("RGB")
+     tensor = inference_transform(pil_image).unsqueeze(0).to(DEVICE)
+
+     with torch.no_grad():
+         logits = model(tensor)
+         probs = torch.softmax(logits, dim=1).squeeze()
+
+     return {idx_to_class[i]: float(probs[i]) for i in range(len(probs))}
+
+ iface = gr.Interface(
+     fn=predict,
+     inputs=gr.Image(type="pil", label="Upload Room Image"),
+     outputs=gr.Label(num_top_classes=5, label="Predictions"),
+     title="Room Type Classifier 🏠",
+     description=f"Classifies images into: {', '.join(idx_to_class.values())}",
+ )
+
+ if __name__ == "__main__":
+     iface.launch()
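
For a quick smoke test outside the Gradio UI, `predict` can presumably be called directly with a PIL image. Note that importing `app` loads the checkpoint at import time and expects `checkpoints/room_classifier_best.pth` to exist (the commit ships `room_efficientnet_b0_best.pth`, so the file may need renaming); the sample filename below is hypothetical:

```python
# Minimal sketch: call app.predict() directly with a PIL image.
# "sample_room.jpg" is a hypothetical test file, not part of this repo.
from PIL import Image

from app import predict  # importing app.py loads the checkpoint immediately

scores = predict(Image.open("sample_room.jpg"))
print(max(scores, key=scores.get), scores)  # top label plus all class probabilities
```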
checkpoints/room_efficientnet_b0_best.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d5ef02cf69916538affef9db11123ae3eecdb2175478284ad5341bfed055ebe5
+ size 16360826
cm_config.yaml ADDED
@@ -0,0 +1,10 @@
+ data_params:
+   image_size: 224
+
+ model_params:
+   name: "efficientnet_b0"
+   num_classes: 5
+
+ output_params:
+   save_dir: "checkpoints"
+   checkpoint_name: "room_efficientnet_b0_best.pth"
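
Both scripts import `yaml` (and `main.py` defines `CONFIG_PATH = "cm_config.yaml"`), but the code shown never actually reads this config. A minimal sketch of how it could be consumed with PyYAML, under that assumption (variable names are illustrative):

```python
import os
import yaml

# Load the config shipped with the repo.
with open("cm_config.yaml") as f:
    cfg = yaml.safe_load(f)

image_size = cfg["data_params"]["image_size"]   # 224, matches the Resize((224, 224)) transform
model_name = cfg["model_params"]["name"]        # "efficientnet_b0"
ckpt_path = os.path.join(cfg["output_params"]["save_dir"],
                         cfg["output_params"]["checkpoint_name"])
```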
main.py ADDED
@@ -0,0 +1,147 @@
+ import os
+ import torch
+ import torch.nn as nn
+ import yaml
+ from torchvision import models, transforms
+ from PIL import Image
+ import gradio as gr
+ import base64
+ import io
+ import time
+ import threading
+ from typing import List, Dict, Union, Optional
+ from fastapi import FastAPI, HTTPException
+ from fastapi.middleware.cors import CORSMiddleware
+ from pydantic import BaseModel
+ from transformers import ConvNextV2ForImageClassification
+
+ CHECKPOINT_DIR = "checkpoints"
+ CONFIG_PATH = "cm_config.yaml"
+
+ MODELS = {}
+ LABELS = {}
+
+ class HFConvNeXtWrapper(nn.Module):
+     def __init__(self, model_name, num_labels):
+         super(HFConvNeXtWrapper, self).__init__()
+         self.model = ConvNextV2ForImageClassification.from_pretrained(
+             model_name, num_labels=num_labels, ignore_mismatched_sizes=True)
+     def forward(self, x):
+         return self.model(x).logits
+
+ def get_model(model_name, num_classes):
+     if model_name.startswith("efficientnet"):
+         model = models.efficientnet_b0(weights=None) if "b0" in model_name else models.efficientnet_b3(weights=None)
+         num_ftrs = model.classifier[1].in_features
+         model.classifier[1] = nn.Linear(num_ftrs, num_classes)
+     elif "convnextv2" in model_name:
+         model = HFConvNeXtWrapper(model_name, num_labels=num_classes)
+     elif model_name == "vit_b_16":
+         model = models.vit_b_16(weights=None)
+         model.heads.head = nn.Linear(model.heads.head.in_features, num_classes)
+     else:
+         raise ValueError(f"Unknown model: {model_name}")
+     return model
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ if not os.path.exists(CHECKPOINT_DIR):
+     os.makedirs(CHECKPOINT_DIR)
+
+ model_files = [f for f in os.listdir(CHECKPOINT_DIR) if f.endswith('.pth')]
+ default_model_name = None
+
+ print(f"--- Loading models from {CHECKPOINT_DIR} ---")
+ for filename in model_files:
+     path = os.path.join(CHECKPOINT_DIR, filename)
+     try:
+         ckpt = torch.load(path, map_location=device)
+         m_name = ckpt.get('model_name', 'efficientnet_b0')
+         n_classes = ckpt.get('num_classes', 5)
+
+         model = get_model(m_name, n_classes)
+         model.load_state_dict(ckpt['state_dict'])
+         model.to(device)
+         model.eval()
+
+         display_name = filename.replace('.pth', '')
+         MODELS[display_name] = model
+
+         if 'class_to_idx' in ckpt:
+             LABELS[display_name] = {v: k for k, v in ckpt['class_to_idx'].items()}
+         else:
+             LABELS[display_name] = {0:'Bat', 1:'Bed', 2:'Din', 3:'Kit', 4:'Liv'}
+
+         if default_model_name is None: default_model_name = display_name
+         print(f"Loaded: {display_name}")
+
+     except Exception as e:
+         print(f"Failed to load {filename}: {e}")
+
+ if not MODELS:
+     print("WARNING: No models loaded. Using Dummy for build.")
+     default_model_name = "dummy"
+
+ inference_transform = transforms.Compose([
+     transforms.Resize((224, 224)),
+     transforms.ToTensor(),
+     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
+ ])
+
+ class Base64Image(BaseModel):
+     image_data: str
+     model_name: Optional[str] = default_model_name
+
+ def base64_to_pil(base64_str: str) -> Image.Image:
+     if "base64," in base64_str: base64_str = base64_str.split("base64,")[1]
+     return Image.open(io.BytesIO(base64.b64decode(base64_str)))
+
+ def run_inference(pil_image, model_key):
+     if model_key not in MODELS:
+         raise ValueError("Model not found")
+
+     model = MODELS[model_key]
+     idx_map = LABELS[model_key]
+
+     img_tensor = inference_transform(pil_image.convert("RGB")).unsqueeze(0).to(device)
+
+     with torch.no_grad():
+         logits = model(img_tensor)
+         probs = torch.softmax(logits, dim=1).squeeze().tolist()
+
+     return {idx_map[i]: float(probs[i]) for i in range(len(probs))}
+
+ app = FastAPI(title="Room Type Classifier API")
+ app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])
+
+ @app.get("/")
+ def home():
+     return {"message": "Room Classifier API is running", "models": list(MODELS.keys())}
+
+ @app.post("/predict")
+ def predict_api(payload: Base64Image):
+     m_name = payload.model_name if payload.model_name else default_model_name
+     try:
+         img = base64_to_pil(payload.image_data)
+         result = run_inference(img, m_name)
+         return {"model": m_name, "predictions": result}
+     except Exception as e:
+         raise HTTPException(status_code=500, detail=str(e))
+
+ def predict_gradio(img, model_choice):
+     if img is None: return None
+     return run_inference(img, model_choice)
+
+ if MODELS:
+     gradio_iface = gr.Interface(
+         fn=predict_gradio,
+         inputs=[
+             gr.Image(type="pil", label="Image"),
+             gr.Dropdown(choices=list(MODELS.keys()), value=default_model_name, label="Model")
+         ],
+         outputs=gr.Label(num_top_classes=5),
+         title="Room Type Classifier",
+         description="Detects: Bathroom, Bedroom, Dining, Kitchen, Living",
+         allow_flagging="never"
+     )
+     app = gr.mount_gradio_app(app, gradio_iface, path="/gradio")
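
A minimal client sketch for the `/predict` endpoint above, assuming the server is reachable on localhost:7860 (the port the Dockerfile exposes) and that `requests` is installed client-side; the image path is hypothetical, and the model name is the shipped checkpoint filename minus its `.pth` suffix, per the `display_name` logic:

```python
import base64
import requests  # client-side dependency; not in requirements.txt

# Encode a local test image ("room.jpg" is a hypothetical file).
with open("room.jpg", "rb") as f:
    payload = {
        "image_data": base64.b64encode(f.read()).decode(),
        "model_name": "room_efficientnet_b0_best",  # checkpoint filename without .pth
    }

resp = requests.post("http://localhost:7860/predict", json=payload, timeout=30)
resp.raise_for_status()
print(resp.json())  # {"model": ..., "predictions": {label: probability, ...}}
```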
requirements.txt ADDED
@@ -0,0 +1,9 @@
+ torch
+ torchvision
+ fastapi
+ uvicorn
+ gradio
+ PyYAML
+ python-multipart
+ pydantic
+ transformers