ErdemAtak committed on
Commit
a9b298e
·
verified ·
1 Parent(s): d5d8e59

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +164 -234
app.py CHANGED
@@ -5,20 +5,19 @@ import torch.nn.functional as F
5
  from torchvision import models, transforms
6
  from PIL import Image
7
  import gradio as gr
8
- import pandas as pd
9
  import numpy as np
10
- from pathlib import Path
11
 
12
- # Check for GPU availability
13
- if torch.backends.mps.is_available():
14
- DEVICE = torch.device("mps")
15
- print(f"Using Metal GPU: {DEVICE}")
16
  else:
17
- DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
18
- print(f"Metal GPU not available, using: {DEVICE}")
19
 
20
- # Model path
21
- MODEL_PATH = "models/model_final.pth"
 
22
 
23
  # Art styles (sorted alphabetically for class index consistency)
24
  ART_STYLES = [
@@ -32,7 +31,6 @@ ART_STYLES = [
32
 
33
  # Image preprocessing
34
  def preprocess_image(image):
35
- # Define the transformation
36
  transform = transforms.Compose([
37
  transforms.Resize(256),
38
  transforms.CenterCrop(224),
@@ -40,86 +38,101 @@ def preprocess_image(image):
40
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
41
  ])
42
 
43
- # Apply transformation
44
  image_tensor = transform(image).unsqueeze(0)
45
  return image_tensor
46
 
47
- # Load model
48
  def load_model():
49
- # Create ResNet34 model
50
- model = models.resnet34(weights=None)
51
- # Adjust the final layer for our classes
52
- model.fc = nn.Linear(512, len(ART_STYLES))
53
-
54
- # Load the state dictionary
55
- state_dict = torch.load(MODEL_PATH, map_location=DEVICE)
56
- model.load_state_dict(state_dict)
57
-
58
- # Move model to device and set to evaluation mode
59
- model = model.to(DEVICE)
60
- model.eval()
61
-
62
- return model
 
 
 
 
 
 
 
 
63
 
64
  # Function to predict art style
65
  def predict_art_style(image, model):
66
- # Preprocess the image
67
- input_tensor = preprocess_image(image).to(DEVICE)
68
-
69
- # Make prediction
70
- with torch.no_grad():
71
- outputs = model(input_tensor)
72
- probabilities = F.softmax(outputs, dim=1)[0]
73
-
74
- # Get top 5 predictions
75
- top5_prob, top5_indices = torch.topk(probabilities, 5)
76
-
77
- # Create results
78
- results = []
79
- for i, (prob, idx) in enumerate(zip(top5_prob.cpu().numpy(), top5_indices.cpu().numpy())):
80
- style = ART_STYLES[idx]
81
- # Format style name for better display
82
- display_style = style.replace('_', ' ')
83
- results.append((display_style, float(prob), i == 0))
84
-
85
- return results
 
 
 
 
86
 
87
  # Main prediction function for Gradio
88
  def classify_image(image):
89
  if image is None:
90
- return None
91
 
92
- # Convert from BGR to RGB (if needed)
93
- if isinstance(image, np.ndarray):
94
- image = Image.fromarray(image)
95
-
96
- # Get model predictions
97
- predictions = predict_art_style(image, model)
98
-
99
- # Format predictions for display
100
- result_html = "<div style='font-size: 1.2rem; background-color: #f0f9ff; padding: 1rem; border-radius: 8px; box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.1);'>"
101
- result_html += "<h3 style='margin-bottom: 15px; color: #1e40af;'>Top 5 Predicted Art Styles:</h3>"
102
-
103
- # Add prediction bars
104
- for i, (style, prob, _) in enumerate(predictions):
105
- percentage = prob * 100
106
- bar_color = "#3b82f6" if i == 0 else "#93c5fd"
107
- result_html += f"<div style='margin-bottom: 10px;'>"
108
- result_html += f"<div style='display: flex; align-items: center; margin-bottom: 5px;'>"
109
- result_html += f"<span style='font-weight: {'bold' if i==0 else 'normal'}; width: 200px; font-size: 1.1rem;'>{style}</span>"
110
- result_html += f"<span style='margin-left: 10px; font-weight: {'bold' if i==0 else 'normal'}; width: 60px; text-align: right;'>{percentage:.1f}%</span>"
111
- result_html += "</div>"
112
- result_html += f"<div style='height: 10px; width: 100%; background-color: #e5e7eb; border-radius: 5px;'>"
113
- result_html += f"<div style='height: 100%; width: {percentage}%; background-color: {bar_color}; border-radius: 5px;'></div>"
114
- result_html += "</div>"
 
 
 
115
  result_html += "</div>"
116
-
117
- result_html += "</div>"
118
-
119
- # Get top prediction for style info
120
- top_style = predictions[0][0]
121
-
122
- return result_html, top_style
 
123
 
124
  # Interpretation function that adds information about the style
125
  def interpret_prediction(top_style):
@@ -165,164 +178,81 @@ def interpret_prediction(top_style):
165
  else:
166
  return f"Information about {top_style} is not available."
167
 
168
- # Load the model once at startup
169
- model = load_model()
170
-
171
- # Custom CSS for styling
172
- custom_css = """
173
- .gradio-container {
174
- font-family: 'Source Sans Pro', sans-serif;
175
- max-width: 1200px !important;
176
- margin: auto;
177
- }
178
- .analyze-btn {
179
- height: 60px !important;
180
- font-size: 1.4rem !important;
181
- font-weight: 600 !important;
182
- background-color: #2563EB !important;
183
- }
184
- .title {
185
- font-size: 2.4rem !important;
186
- font-weight: 700 !important;
187
- text-align: center;
188
- margin-bottom: 1rem;
189
- background: linear-gradient(90deg, #2563EB 0%, #4F46E5 100%);
190
- -webkit-background-clip: text;
191
- -webkit-text-fill-color: transparent;
192
- }
193
- .subtitle {
194
- font-size: 1.3rem !important;
195
- text-align: center;
196
- margin-bottom: 2rem;
197
- }
198
- .image-display {
199
- min-height: 400px;
200
- border-radius: 12px;
201
- border: 2px solid #E5E7EB;
202
- box-shadow: 0 10px 15px -3px rgba(0, 0, 0, 0.1);
203
- }
204
- .info-output {
205
- font-size: 1.2rem !important;
206
- line-height: 1.6 !important;
207
- background-color: #F9FAFB;
208
- border-radius: 12px;
209
- padding: 1rem;
210
- box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.1);
211
- }
212
- .examples-info {
213
- font-size: 1.3rem !important;
214
- line-height: 1.6 !important;
215
- }
216
- .examples-info h3 {
217
- font-size: 1.6rem !important;
218
- color: #1e40af;
219
- margin-bottom: 15px;
220
- }
221
- .examples-info li {
222
- margin-bottom: 10px;
223
- font-size: 1.3rem !important;
224
- }
225
- .gradio-container .examples-parent .examples-header {
226
- font-size: 1.5rem !important;
227
- font-weight: 600 !important;
228
- margin-bottom: 10px;
229
- }
230
- .gradio-container label,
231
- .gradio-container .label-wrap span,
232
- .gradio-container .examples-parent > div > p,
233
- .gradio-container .examples h4 {
234
- font-size: 1.6rem !important;
235
- font-weight: 600 !important;
236
- }
237
- .how-it-works {
238
- font-size: 1.3rem !important;
239
- line-height: 1.6 !important;
240
- }
241
- .how-it-works h3 {
242
- font-size: 1.6rem !important;
243
- color: #1e40af;
244
- margin-bottom: 15px;
245
- }
246
- .how-it-works ul {
247
- margin-top: 15px;
248
- margin-bottom: 15px;
249
- }
250
- .how-it-works li {
251
- margin-left: 20px;
252
- margin-bottom: 10px;
253
- }
254
- """
255
 
256
  # Set up the Gradio interface
257
- def launch_app():
258
- with gr.Blocks(css=custom_css) as app:
259
- gr.HTML("""
260
- <div>
261
- <h1 class="title">Art Style Classifier</h1>
262
- <p class="subtitle">Upload any artwork to identify its artistic style using AI</p>
263
- </div>
264
- """)
265
-
266
- with gr.Row():
267
- with gr.Column(scale=5):
268
- # Image input
269
- input_image = gr.Image(label="Upload Artwork", type="pil", elem_classes="image-display")
270
-
271
- # Analyze button
272
- analyze_btn = gr.Button("Analyze Artwork", elem_classes="analyze-btn")
273
-
274
- # Example images
275
- examples = gr.Examples(
276
- examples=[
277
- "examples/starry_night.jpg",
278
- "examples/mona_lisa.jpg",
279
- "examples/picasso.jpg",
280
- "examples/monet_water_lilies.jpg",
281
- "examples/kandinsky.jpg"
282
- ],
283
- inputs=input_image,
284
- label="Example Artworks",
285
- examples_per_page=5
286
- )
287
-
288
- # "How it works" section
289
- gr.HTML("""
290
- <div class="how-it-works">
291
- <h3>How It Works:</h3>
292
- <p>This application uses a deep learning model (ResNet34) trained on a dataset of art from various periods and styles.
293
- The model analyzes the visual characteristics of the uploaded image to identify its artistic style.</p>
294
- <ul>
295
- <li>The model was trained on over 8,000 paintings across 27 different artistic styles</li>
296
- <li>It achieves approximately 80% accuracy in classifying art styles</li>
297
- <li>Works best with complete paintings rather than details or cropped sections</li>
298
- </ul>
299
- </div>
300
- """)
301
 
302
- with gr.Column(scale=5):
303
- # Outputs
304
- prediction_output = gr.HTML(label="Prediction Results")
305
- style_info = gr.Markdown(label="Style Information")
306
-
307
- # Set up the prediction flow
308
- analyze_btn.click(
309
- fn=classify_image,
310
- inputs=[input_image],
311
- outputs=[prediction_output, style_info],
312
- ).then(
313
- fn=interpret_prediction,
314
- inputs=[style_info],
315
- outputs=[style_info]
316
- )
317
-
318
- input_image.change(
319
- fn=lambda: (None, None),
320
- inputs=[],
321
- outputs=[prediction_output, style_info]
322
- )
 
 
 
 
 
 
 
 
 
323
 
324
- # Launch the app
325
- app.launch(share=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
326
 
327
- if __name__ == "__main__":
328
- launch_app()
 
5
  from torchvision import models, transforms
6
  from PIL import Image
7
  import gradio as gr
 
8
  import numpy as np
 
9
 
10
# Path to the trained weights.
MODEL_PATH = "model_final.pth"  # Model should be in root directory

# Startup diagnostics: report whether the weights file is reachable from the
# current working directory (assumes `os` is imported above — TODO confirm).
if not os.path.exists(MODEL_PATH):
    print(f"Warning: Model not found at {MODEL_PATH}, current directory: {os.getcwd()}")
    print(f"Files in current directory: {os.listdir('.')}")
else:
    print(f"Model found at {MODEL_PATH}")

# Device configuration: prefer CUDA when available, otherwise CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE}")
21
 
22
  # Art styles (sorted alphabetically for class index consistency)
23
  ART_STYLES = [
 
31
 
32
  # Image preprocessing
33
  def preprocess_image(image):
 
34
  transform = transforms.Compose([
35
  transforms.Resize(256),
36
  transforms.CenterCrop(224),
 
38
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
39
  ])
40
 
 
41
  image_tensor = transform(image).unsqueeze(0)
42
  return image_tensor
43
 
44
# Load model with error handling
def load_model():
    """Build a ResNet34 classifier and load the trained weights.

    Returns:
        torch.nn.Module: the model moved to DEVICE and set to eval mode.

    Raises:
        Exception: re-raised when the weights file is missing or its
            state dict does not match the architecture.
    """
    # ResNet34 backbone without pretrained weights; the checkpoint provides them.
    model = models.resnet34(weights=None)
    # Replace the final layer for our classes (512 = ResNet34 feature width).
    model.fc = nn.Linear(512, len(ART_STYLES))

    # Only the disk I/O / state-dict load can legitimately fail here, so the
    # handler wraps just those lines instead of the whole construction.
    try:
        state_dict = torch.load(MODEL_PATH, map_location=DEVICE)
        model.load_state_dict(state_dict)
        print("Model loaded successfully")
    except Exception as e:
        # Log once for the Space console, then propagate so the caller
        # (module-level loader) can fall back to model = None.
        print(f"Error loading model state dict: {e}")
        raise

    model = model.to(DEVICE)
    model.eval()

    return model
68
 
69
# Function to predict art style
def predict_art_style(image, model):
    """Return the top-5 predictions for *image* as (display_name, prob, is_top) tuples.

    On any failure the error is logged and a single placeholder tuple is
    returned so the UI still has something to render.
    """
    try:
        # Preprocess and move onto the inference device.
        tensor = preprocess_image(image).to(DEVICE)

        # Forward pass without gradient tracking.
        with torch.no_grad():
            logits = model(tensor)
            probs = F.softmax(logits, dim=1)[0]

        # Top-5 class probabilities and indices.
        top_probs, top_idx = torch.topk(probs, 5)

        # Human-readable style names; rank 0 is flagged as the top prediction.
        return [
            (ART_STYLES[idx].replace('_', ' '), float(p), rank == 0)
            for rank, (p, idx) in enumerate(zip(top_probs.cpu().numpy(), top_idx.cpu().numpy()))
        ]
    except Exception as e:
        print(f"Error in prediction: {e}")
        return [("Error in prediction", 1.0, True)]
95
 
96
# Main prediction function for Gradio
def classify_image(image):
    """Gradio handler: return (result HTML, top style name) for an uploaded image.

    Returns a plain prompt string (and empty style) when no image is given,
    and an inline red error <div> when anything fails mid-processing.
    """
    if image is None:
        return "Please upload an image to analyze.", ""

    try:
        # Gradio may hand us a numpy array; normalize to a PIL image.
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)

        # Model inference: list of (style, probability, is_top) tuples.
        predictions = predict_art_style(image, model)

        # Assemble the result card as a list of fragments, joined once.
        pieces = [
            "<div style='font-size: 1.2rem; background-color: #f0f9ff; padding: 1rem; border-radius: 8px; box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.1);'>",
            "<h3 style='margin-bottom: 15px; color: #1e40af;'>Top 5 Predicted Art Styles:</h3>",
        ]

        # One labelled percentage bar per prediction; the top hit is bolded
        # and drawn in a darker blue.
        for rank, (style, prob, _) in enumerate(predictions):
            percentage = prob * 100
            weight = 'bold' if rank == 0 else 'normal'
            bar_color = "#3b82f6" if rank == 0 else "#93c5fd"
            pieces.extend([
                "<div style='margin-bottom: 10px;'>",
                "<div style='display: flex; align-items: center; margin-bottom: 5px;'>",
                f"<span style='font-weight: {weight}; width: 200px; font-size: 1.1rem;'>{style}</span>",
                f"<span style='margin-left: 10px; font-weight: {weight}; width: 60px; text-align: right;'>{percentage:.1f}%</span>",
                "</div>",
                "<div style='height: 10px; width: 100%; background-color: #e5e7eb; border-radius: 5px;'>",
                f"<div style='height: 100%; width: {percentage}%; background-color: {bar_color}; border-radius: 5px;'></div>",
                "</div>",
                "</div>",
            ])

        pieces.append("</div>")

        # The first tuple's style name feeds interpret_prediction downstream.
        return "".join(pieces), predictions[0][0]
    except Exception as e:
        print(f"Error in classify_image: {e}")
        return f"<div style='color: red;'>Error processing image: {str(e)}</div>", ""
136
 
137
  # Interpretation function that adds information about the style
138
  def interpret_prediction(top_style):
 
178
  else:
179
  return f"Information about {top_style} is not available."
180
 
181
# Try to load the model once at import time so every request reuses it.
try:
    print("Loading model...")
    model = load_model()
    print("Model loaded successfully")
except Exception as e:
    # Keep the app importable even without weights; classify_image will then
    # fail per-request and surface an error message in the UI instead.
    print(f"Failed to load model: {e}")
    model = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
189
 
190
# Set up the Gradio interface
with gr.Blocks() as app:
    # Page header: title with gradient text plus a one-line tagline.
    gr.HTML("""
    <div style="text-align: center; margin-bottom: 1rem;">
    <h1 style="font-size: 2.4rem; font-weight: 700; background: linear-gradient(90deg, #2563EB 0%, #4F46E5 100%); -webkit-background-clip: text; -webkit-text-fill-color: transparent;">Art Style Classifier</h1>
    <p style="font-size: 1.3rem;">Upload any artwork to identify its artistic style using AI</p>
    </div>
    """)

    with gr.Row():
        # Left column: input image, trigger button, example gallery, explainer.
        with gr.Column(scale=5):
            # Image input
            input_image = gr.Image(label="Upload Artwork", type="pil")

            # Analyze button
            analyze_btn = gr.Button("Analyze Artwork", variant="primary")

            # Example images — paths are relative to the app's working
            # directory (presumably shipped with the Space; verify they exist).
            examples = gr.Examples(
                examples=[
                    "examples/starry_night.jpg",
                    "examples/mona_lisa.jpg",
                    "examples/picasso.jpg",
                    "examples/monet_water_lilies.jpg",
                    "examples/kandinsky.jpg"
                ],
                inputs=input_image,
                label="Example Artworks",
                examples_per_page=5
            )

            # "How it works" section
            gr.HTML("""
            <div style="font-size: 1.1rem; line-height: 1.6; margin-top: 2rem;">
            <h3 style="font-size: 1.4rem; color: #1e40af; margin-bottom: 0.8rem;">How It Works:</h3>
            <p>This application uses a deep learning model (ResNet34) trained on a dataset of art from various periods and styles.
            The model analyzes the visual characteristics of the uploaded image to identify its artistic style.</p>
            <ul>
            <li>The model was trained on over 50,000 paintings across 27 different artistic styles</li>
            <li>It achieves approximately 74% accuracy in classifying art styles</li>
            <li>Works best with complete paintings rather than details or cropped sections</li>
            </ul>
            </div>
            """)

        # Right column: HTML bar chart of predictions + Markdown style write-up.
        with gr.Column(scale=5):
            # Outputs
            prediction_output = gr.HTML(label="Prediction Results")
            style_info = gr.Markdown(label="Style Information")

    # Set up the prediction flow: classify_image fills both outputs, then
    # interpret_prediction rewrites style_info using the top style name that
    # classify_image placed there.
    analyze_btn.click(
        fn=classify_image,
        inputs=[input_image],
        outputs=[prediction_output, style_info],
    ).then(
        fn=interpret_prediction,
        inputs=[style_info],
        outputs=[style_info]
    )

    # Clear any stale results whenever a new image is selected or uploaded.
    input_image.change(
        fn=lambda: (None, None),
        inputs=[],
        outputs=[prediction_output, style_info]
    )

# Launch the app. NOTE(review): runs at import time with no __main__ guard —
# the usual pattern for a Hugging Face Space's app.py; confirm intended.
app.launch()