SpyC0der77 commited on
Commit
7168a17
·
verified ·
1 Parent(s): 4a63e1b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +97 -217
app.py CHANGED
@@ -1,36 +1,60 @@
1
- import gradio as gr
2
  import torch
3
- from PIL import Image
4
- import torchvision.transforms as transforms
5
  import torch.nn as nn
6
- from torchvision import models
7
- from typing import Dict, Tuple
8
- import os
9
-
10
-
11
- class MultiOutputModel(nn.Module):
12
- """Multi-output model for artifact classification (matches UI)"""
13
-
14
- def __init__(self, num_object_classes, num_material_classes, hidden_size=512):
15
- super(MultiOutputModel, self).__init__()
16
-
17
- # Use a pre-trained ResNet as backbone
18
- self.backbone = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
19
- # Remove the final classification layer
20
- self.backbone = nn.Sequential(*list(self.backbone.children())[:-1])
21
-
22
- # Freeze early layers for transfer learning
23
- for param in list(self.backbone.parameters())[:-4]: # Unfreeze more layers for better fine-tuning
24
- param.requires_grad = False
25
-
26
- # Classification heads for each attribute
27
- self.object_classifier = nn.Linear(2048, num_object_classes)
28
- self.material_classifier = nn.Linear(2048, num_material_classes)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
 
30
  def forward(self, x):
31
  # Extract features using backbone
32
  features = self.backbone(x)
33
- features = features.view(features.size(0), -1)
 
 
 
34
 
35
  # Get predictions for each attribute
36
  object_pred = self.object_classifier(features)
@@ -41,211 +65,67 @@ class MultiOutputModel(nn.Module):
41
  'material': material_pred,
42
  }
43
 
44
-
45
-
46
-
47
- def load_model(model_path: str) -> Tuple[torch.nn.Module, Dict[str, Dict[int, str]]]:
48
- """Load the model from checkpoint and return model and label mappings."""
49
- print(f"Loading model from {model_path}...")
50
- checkpoint = torch.load(model_path, map_location="cpu")
51
-
52
- # Get label mappings to determine number of classes
53
- label_mappings = checkpoint.get('label_mappings', {})
54
- num_object_classes = len(label_mappings.get('object_name', {}))
55
- num_material_classes = len(label_mappings.get('material', {}))
56
-
57
- if num_object_classes == 0:
58
- print("Warning: No label mappings found, using fallback class counts")
59
- num_object_classes, num_material_classes = 1018, 192
60
-
61
- # Check model type based on state_dict keys to determine which architecture to use
62
- model_state_dict = checkpoint.get('model_state_dict', {})
63
- state_dict_keys = set(model_state_dict.keys())
64
-
65
- # Only support v1 model (MultiOutputModel) with ResNet backbone
66
- print(f"Loading v1 model (MultiOutputModel) with ResNet backbone")
67
- model = MultiOutputModel(num_object_classes, num_material_classes)
68
-
69
- # Load state dict
70
- if 'model_state_dict' in checkpoint:
71
- model.load_state_dict(checkpoint['model_state_dict'])
72
- else:
73
- print("Warning: No model_state_dict found in checkpoint")
74
-
75
- # Create reverse mappings (id2label)
76
- reverse_mappings = {}
77
- for attr, mapping in label_mappings.items():
78
- reverse_mappings[attr] = {int(v): str(k) for k, v in mapping.items()}
79
- print(f"Loaded {attr} mappings: {len(reverse_mappings[attr])} classes")
80
-
81
- return model, reverse_mappings
82
-
83
-
84
- def run_inference(model: torch.nn.Module, pixel_values: torch.Tensor, device: str) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
85
- """Run inference on pixel_values and return predictions and confidences for both object_name and material."""
86
- model.eval()
87
- model.to(device)
88
- pixel_values = pixel_values.to(device)
89
-
90
- with torch.no_grad():
91
- outputs = model(pixel_values)
92
-
93
- # Handle different output formats
94
- if isinstance(outputs, dict):
95
- # Multi-output model format
96
- if 'object_name' in outputs and 'material' in outputs:
97
- logits_obj = outputs['object_name']
98
- logits_mat = outputs['material']
99
- else:
100
- raise ValueError("Expected 'object_name' and 'material' in model outputs")
101
- else:
102
- raise ValueError("Expected dict output with 'object_name' and 'material' keys")
103
-
104
- preds_obj = torch.argmax(logits_obj, dim=-1)
105
- probs_obj = torch.softmax(logits_obj, dim=-1)
106
- max_probs_obj = torch.max(probs_obj, dim=-1)[0]
107
-
108
- preds_mat = torch.argmax(logits_mat, dim=-1)
109
- probs_mat = torch.softmax(logits_mat, dim=-1)
110
- max_probs_mat = torch.max(probs_mat, dim=-1)[0]
111
-
112
- return preds_obj.cpu(), max_probs_obj.cpu(), preds_mat.cpu(), max_probs_mat.cpu()
113
-
114
-
115
- # Global variables for model and label mappings
116
- model = None
117
- label_mappings = None
118
- device = None
119
-
120
- def preprocess_image(image: Image.Image) -> torch.Tensor:
121
- """Preprocess image for model inference."""
122
- # Define transforms
123
- transform = transforms.Compose([
124
  transforms.Resize(256),
125
  transforms.CenterCrop(224),
126
  transforms.ToTensor(),
127
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
128
  ])
129
 
130
- # Apply transforms
131
- image = image.convert('RGB')
132
- tensor = transform(image).unsqueeze(0) # Add batch dimension
133
-
134
- return tensor
135
-
136
- def predict_artifact(image: Image.Image) -> tuple[str, float, str, float]:
137
- """Predict object and material from image."""
138
- global model, label_mappings, device
139
-
140
- if model is None:
141
- raise ValueError("Model not loaded. Please restart the application.")
142
-
143
- # Preprocess image
144
- pixel_values = preprocess_image(image)
145
-
146
- # Run inference
147
- preds_obj, confs_obj, preds_mat, confs_mat = run_inference(model, pixel_values, device)
148
-
149
- # Get predictions
150
- object_pred_id = preds_obj[0].item()
151
- material_pred_id = preds_mat[0].item()
152
- object_conf = confs_obj[0].item()
153
- material_conf = confs_mat[0].item()
154
-
155
- # Convert IDs to labels
156
- object_name = label_mappings['object_name'].get(object_pred_id, f"class_{object_pred_id}")
157
- material_name = label_mappings['material'].get(material_pred_id, f"class_{material_pred_id}")
158
 
159
- return object_name, object_conf, material_name, material_conf
 
 
 
 
 
160
 
161
- def gradio_predict(image):
162
- """Gradio interface function."""
163
  if image is None:
164
- return "Please upload an image", "", "", ""
165
-
166
- try:
167
- object_name, object_conf, material_name, material_conf = predict_artifact(image)
168
 
169
- # Format results
170
- object_result = f"**{object_name}** ({object_conf:.1%} confidence)"
171
- material_result = f"**{material_name}** ({material_conf:.1%} confidence)"
 
172
 
173
- return object_result, material_result, f"{object_conf:.3f}", f"{material_conf:.3f}"
 
 
 
174
 
175
- except Exception as e:
176
- return f"Error: {str(e)}", "", "", ""
 
177
 
178
- def load_model_on_startup():
179
- """Load model when the application starts."""
180
- global model, label_mappings, device
181
 
182
- # Set device
183
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
184
-
185
- # Load model from model.pth
186
- model_path = "model.pth"
187
- if not os.path.exists(model_path):
188
- print(f"Warning: Model file not found at {model_path}")
189
- print("Please ensure the model.pth file exists in the current directory before running the application.")
190
- return
191
-
192
- try:
193
- model, label_mappings = load_model(model_path)
194
- print("Model loaded successfully!")
195
- print(f"Object classes: {len(label_mappings.get('object_name', {}))}")
196
- print(f"Material classes: {len(label_mappings.get('material', {}))}")
197
- except Exception as e:
198
- print(f"Error loading model: {e}")
199
-
200
- # Load model on startup
201
- load_model_on_startup()
202
-
203
- # Create Gradio interface
204
- with gr.Blocks(title="Artifact Classification v1", theme=gr.themes.Soft()) as demo:
205
- gr.Markdown("# 🏺 Artifact Classification Model v1")
206
- gr.Markdown("Upload an image of an artifact to classify its **object type** and **material composition**.")
207
 
208
  with gr.Row():
209
- with gr.Column():
210
- image_input = gr.Image(label="Upload Artifact Image", type="pil")
211
- submit_btn = gr.Button("🔍 Classify Artifact", variant="primary")
212
-
213
- with gr.Column():
214
- gr.Markdown("### 📊 Classification Results")
215
-
216
- object_output = gr.Markdown(label="**Object Type**")
217
- material_output = gr.Markdown(label="**Material**")
218
-
219
- with gr.Accordion("📈 Confidence Scores", open=False):
220
- object_conf = gr.Textbox(label="Object Confidence", interactive=False)
221
- material_conf = gr.Textbox(label="Material Confidence", interactive=False)
222
-
223
- # Connect the interface
224
- submit_btn.click(
225
- fn=gradio_predict,
226
- inputs=image_input,
227
- outputs=[object_output, material_output, object_conf, material_conf]
228
- )
229
-
230
- # Example images
231
- gr.Examples(
232
- examples=[
233
- # You can add example image paths here if available
234
- ],
235
- inputs=image_input,
236
- outputs=[object_output, material_output, object_conf, material_conf],
237
- fn=gradio_predict,
238
- cache_examples=False
239
- )
240
-
241
- gr.Markdown("""
242
- ### ℹ️ About
243
- This model uses a ResNet-50 backbone to classify museum artifacts into object types (vase, statue, pottery, etc.)
244
- and material compositions (ceramic, bronze, stone, etc.).
245
-
246
- **Model**: MultiOutputModel with ResNet-50 backbone
247
- **Training Data**: Oriental Museum artifacts dataset
248
- """)
249
 
250
  if __name__ == "__main__":
251
- demo.launch()
 
 
1
  import torch
 
 
2
  import torch.nn as nn
3
+ from torchvision import transforms
4
+ from PIL import Image
5
+ import gradio as gr
6
+ import timm
7
+
8
+ class ImprovedMultiOutputModel(nn.Module):
9
+ """Improved multi-output model with EfficientNet backbone."""
10
+ def __init__(self, num_object_classes, num_material_classes, backbone='efficientnet_b0'):
11
+ super(ImprovedMultiOutputModel, self).__init__()
12
+
13
+ # Use EfficientNet backbone
14
+ self.backbone = timm.create_model(backbone, pretrained=True, num_classes=0)
15
+ backbone_out_features = self.backbone.num_features
16
+
17
+ # Add attention mechanism
18
+ self.attention = nn.Sequential(
19
+ nn.Linear(backbone_out_features, 512),
20
+ nn.ReLU(),
21
+ nn.Dropout(0.1),
22
+ nn.Linear(512, backbone_out_features),
23
+ nn.Sigmoid()
24
+ )
25
+
26
+ # Improved classification heads with dropout and batch norm
27
+ self.object_classifier = nn.Sequential(
28
+ nn.Linear(backbone_out_features, 1024),
29
+ nn.BatchNorm1d(1024),
30
+ nn.ReLU(),
31
+ nn.Dropout(0.3),
32
+ nn.Linear(1024, 512),
33
+ nn.BatchNorm1d(512),
34
+ nn.ReLU(),
35
+ nn.Dropout(0.2),
36
+ nn.Linear(512, num_object_classes)
37
+ )
38
+
39
+ self.material_classifier = nn.Sequential(
40
+ nn.Linear(backbone_out_features, 1024),
41
+ nn.BatchNorm1d(1024),
42
+ nn.ReLU(),
43
+ nn.Dropout(0.3),
44
+ nn.Linear(1024, 512),
45
+ nn.BatchNorm1d(512),
46
+ nn.ReLU(),
47
+ nn.Dropout(0.2),
48
+ nn.Linear(512, num_material_classes)
49
+ )
50
 
51
  def forward(self, x):
52
  # Extract features using backbone
53
  features = self.backbone(x)
54
+
55
+ # Apply attention mechanism
56
+ attention_weights = self.attention(features)
57
+ features = features * attention_weights
58
 
59
  # Get predictions for each attribute
60
  object_pred = self.object_classifier(features)
 
65
  'material': material_pred,
66
  }
67
 
68
+ def get_val_transforms():
69
+ """Get transforms for validation."""
70
+ return transforms.Compose([
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
  transforms.Resize(256),
72
  transforms.CenterCrop(224),
73
  transforms.ToTensor(),
74
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
75
  ])
76
 
77
+ def load_model(model_path):
78
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
79
+ checkpoint = torch.load(model_path, map_location=device)
80
+ label_mappings = checkpoint['label_mappings']
81
+ num_object_classes = len(label_mappings['object_name'])
82
+ num_material_classes = len(label_mappings['material'])
83
+ backbone = 'efficientnet_b0'
84
+ model = ImprovedMultiOutputModel(num_object_classes, num_material_classes, backbone)
85
+ model.load_state_dict(checkpoint['model_state_dict'], strict=False)
86
+ model.to(device)
87
+ model.eval()
88
+ return model, label_mappings
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
 
90
+ # Load models
91
+ models = {}
92
+ models['modelv1.pth'], label_mappings_v1 = load_model('modelv1.pth')
93
+ models['modelv2.pth'], label_mappings_v2 = load_model('modelv2.pth')
94
+ # Assume label_mappings are the same for both, use v1
95
+ label_mappings = label_mappings_v1
96
 
97
+ def predict(image, model_choice):
 
98
  if image is None:
99
+ return "Please upload an image."
 
 
 
100
 
101
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
102
+ model = models[model_choice]
103
+ transform = get_val_transforms()
104
+ image_tensor = transform(image).unsqueeze(0).to(device)
105
 
106
+ with torch.no_grad():
107
+ outputs = model(image_tensor)
108
+ pred_obj = torch.argmax(outputs['object_name'], dim=1).item()
109
+ pred_mat = torch.argmax(outputs['material'], dim=1).item()
110
 
111
+ # Map IDs back to names
112
+ obj_name = [k for k, v in label_mappings['object_name'].items() if v == pred_obj][0]
113
+ mat_name = [k for k, v in label_mappings['material'].items() if v == pred_mat][0]
114
 
115
+ return f"Predicted Object: {obj_name}\nPredicted Material: {mat_name}"
 
 
116
 
117
+ # Create Gradio interface using Blocks
118
+ with gr.Blocks(title="Artifact Classification Model") as demo:
119
+ gr.Markdown("# Artifact Classification Model")
120
+ gr.Markdown("Upload an image to classify the object name and material.")
121
+ model_selector = gr.Dropdown(choices=['modelv1.pth', 'modelv2.pth'], label="Select Model", value='modelv1.pth')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
122
 
123
  with gr.Row():
124
+ input_image = gr.Image(type="pil", label="Upload an Image")
125
+ output_text = gr.Textbox(label="Predictions")
126
+
127
+ predict_btn = gr.Button("Predict")
128
+ predict_btn.click(fn=predict, inputs=[input_image, model_selector], outputs=output_text)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
129
 
130
  if __name__ == "__main__":
131
+ demo.launch()