import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models, transforms
from PIL import Image
import gradio as gr
import numpy as np
# Report up front whether the weights file can be seen; the directory
# listing is printed only to help debug a missing file.
MODEL_PATH = "model_final.pth"  # resolved relative to the working directory
if not os.path.exists(MODEL_PATH):
    print(f"Warning: Model not found at {MODEL_PATH}, current directory: {os.getcwd()}")
    print(f"Files in current directory: {os.listdir('.')}")
else:
    print(f"Model found at {MODEL_PATH}")
# Pick the compute device once: CUDA when torch reports it, CPU otherwise.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
print(f"Using device: {DEVICE}")
# Class labels in alphabetical order. A label's position in this list is the
# class index used to decode model output, so the order must not change.
ART_STYLES = [
    "Abstract_Expressionism",
    "Action_painting",
    "Analytical_Cubism",
    "Art_Nouveau_Modern",
    "Baroque",
    "Color_Field_Painting",
    "Contemporary_Realism",
    "Cubism",
    "Early_Renaissance",
    "Expressionism",
    "Fauvism",
    "High_Renaissance",
    "Impressionism",
    "Mannerism_Late_Renaissance",
    "Minimalism",
    "Naive_Art_Primitivism",
    "New_Realism",
    "Northern_Renaissance",
    "Pointillism",
    "Pop_Art",
    "Post_Impressionism",
    "Realism",
    "Rococo",
    "Romanticism",
    "Symbolism",
    "Synthetic_Cubism",
    "Ukiyo_e",
]
# Image preprocessing
def preprocess_image(image):
    """Turn an image into a normalized, batched tensor for the classifier.

    The image is resized so its shorter side is 256 px, center-cropped to
    224x224, converted to a tensor and normalized per channel with the
    commonly used ImageNet mean/std values.

    Args:
        image: the picture to prepare. PIL images in any mode are accepted;
            non-RGB ones are converted to RGB first.

    Returns:
        A tensor with a leading batch dimension of size 1; for a PIL input
        its shape is ``(1, 3, 224, 224)``.
    """
    # Normalize below is configured for exactly three channels, so RGBA,
    # grayscale or palette images must be brought to RGB before ToTensor.
    if isinstance(image, Image.Image) and image.mode != "RGB":
        image = image.convert("RGB")

    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    # unsqueeze(0) adds the batch axis the model's forward pass expects.
    return transform(image).unsqueeze(0)
# Load model with error handling
def load_model():
    """Build the ResNet34 classifier and load its trained weights.

    The network is created without pretrained weights, its final fully
    connected layer is replaced by one with ``len(ART_STYLES)`` outputs, and
    the state dict stored at ``MODEL_PATH`` is loaded into it.

    Returns:
        The model, moved to ``DEVICE`` and switched to evaluation mode.

    Raises:
        Exception: any error raised while reading or applying the state
            dict is logged once and re-raised unchanged; errors from
            building the network or moving it to the device propagate as-is.
    """
    # Architecture only; the trained weights come from MODEL_PATH below.
    model = models.resnet34(weights=None)
    # Size the new head from the layer it replaces instead of hard-coding
    # the feature width.
    model.fc = nn.Linear(model.fc.in_features, len(ART_STYLES))

    try:
        # NOTE(review): torch.load unpickles the file, which can execute
        # arbitrary code. Only load checkpoints from a trusted source, and
        # consider weights_only=True if the installed torch supports it.
        state_dict = torch.load(MODEL_PATH, map_location=DEVICE)
        model.load_state_dict(state_dict)
    except Exception as e:
        # Log here and let the caller decide how to continue.
        print(f"Error loading model state dict: {e}")
        raise
    print("Model loaded successfully")

    model = model.to(DEVICE)
    model.eval()  # inference mode: fixes batch-norm statistics
    return model
# Function to predict art style
def predict_art_style(image, model):
    """Return the five most probable art styles for an image.

    Args:
        image: the picture to classify; passed to ``preprocess_image``.
        model: the classifier; called on the preprocessed batch.

    Returns:
        A list of ``(display_name, probability, is_top)`` tuples, most
        probable first. ``display_name`` is the label with underscores
        replaced by spaces and ``is_top`` is True only for the first entry.
        If anything fails, the error is printed and the single placeholder
        entry ``("Error in prediction", 1.0, True)`` is returned instead.
    """
    try:
        batch = preprocess_image(image).to(DEVICE)

        # Inference only, so skip gradient tracking.
        with torch.no_grad():
            logits = model(batch)
            # Softmax over the class axis, then drop the batch axis.
            scores = F.softmax(logits, dim=1)[0]

        best_scores, best_classes = torch.topk(scores, 5)
        best_scores = best_scores.cpu().numpy()
        best_classes = best_classes.cpu().numpy()

        return [
            (ART_STYLES[class_idx].replace('_', ' '), float(score), rank == 0)
            for rank, (score, class_idx) in enumerate(zip(best_scores, best_classes))
        ]
    except Exception as e:
        print(f"Error in prediction: {e}")
        return [("Error in prediction", 1.0, True)]
# HTML rendering helper and main prediction function for Gradio
def _render_predictions_html(predictions):
    """Render prediction tuples as an HTML panel with one bar per style.

    Args:
        predictions: iterable of ``(style, probability, is_top)`` tuples in
            display order. The third element is ignored; the first row is
            always the highlighted one.

    Returns:
        The markup as a single string.
    """
    parts = [
        "<div style='font-size: 1.2rem; background-color: #f0f9ff; padding: 1rem; border-radius: 8px; box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.1);'>",
        "<h3 style='margin-bottom: 15px; color: #1e40af;'>Top 5 Predicted Art Styles:</h3>",
    ]
    for rank, (style, prob, _) in enumerate(predictions):
        percentage = prob * 100
        is_top = rank == 0
        bar_color = "#3b82f6" if is_top else "#93c5fd"
        weight = "bold" if is_top else "normal"
        parts.extend((
            "<div style='margin-bottom: 10px;'>",
            "<div style='display: flex; align-items: center; margin-bottom: 5px;'>",
            f"<span style='font-weight: {weight}; width: 200px; font-size: 1.1rem;'>{style}</span>",
            f"<span style='margin-left: 10px; font-weight: {weight}; width: 60px; text-align: right;'>{percentage:.1f}%</span>",
            "</div>",
            "<div style='height: 10px; width: 100%; background-color: #e5e7eb; border-radius: 5px;'>",
            f"<div style='height: 100%; width: {percentage}%; background-color: {bar_color}; border-radius: 5px;'></div>",
            "</div>",
            "</div>",
        ))
    parts.append("</div>")
    # Joining once avoids rebuilding the string on every append.
    return "".join(parts)


def classify_image(image):
    """Classify an uploaded artwork and build the HTML shown in the UI.

    Args:
        image: the uploaded picture as a PIL image or a NumPy array, or
            ``None`` when nothing has been uploaded.

    Returns:
        A ``(result_html, top_style)`` tuple. ``top_style`` is the display
        name of the best prediction, or ``""`` when there is no image, the
        model is unavailable, or processing failed; in those cases
        ``result_html`` carries the message to show instead.
    """
    # Function-scope import so this block needs no change to the file header.
    from html import escape

    if image is None:
        return "Please upload an image to analyze.", ""

    try:
        # Start-up code sets ``model`` to None when the weights failed to
        # load. Report that directly rather than showing a fake prediction.
        if model is None:
            raise RuntimeError("the model could not be loaded")

        # Wrap raw arrays in a PIL image before preprocessing.
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)

        predictions = predict_art_style(image, model)
        result_html = _render_predictions_html(predictions)
        # The first entry is the most probable style.
        top_style = predictions[0][0]
        return result_html, top_style
    except Exception as e:
        print(f"Error in classify_image: {e}")
        # Exception text is arbitrary, so escape it before embedding in HTML.
        return f"<div style='color: red;'>Error processing image: {escape(str(e))}</div>", ""
# Interpretation function that adds information about the style
# Style descriptions, keyed by display name (words separated by spaces).
_STYLE_DESCRIPTIONS = {
    'Abstract Expressionism': "Abstract Expressionism is characterized by gestural brush-strokes or mark-making, and the impression of spontaneity. Key artists include Jackson Pollock and Willem de Kooning.",
    'Action painting': "Action Painting, a subset of Abstract Expressionism, emphasizes the physical act of painting itself. The canvas was seen as an arena in which to act.",
    'Analytical Cubism': "Analytical Cubism is characterized by geometric shapes, fragmented forms, and a monochromatic palette. Pioneered by Pablo Picasso and Georges Braque.",
    'Art Nouveau Modern': "Art Nouveau features highly stylized, flowing curvilinear designs, often incorporating floral and other plant-inspired motifs.",
    'Baroque': "Baroque art is characterized by drama, rich color, and intense light and shadow. Notable for its grandeur and ornate details.",
    'Color Field Painting': "Color Field Painting is characterized by large areas of a more or less flat single color. Key artists include Mark Rothko and Clyfford Still.",
    'Contemporary Realism': "Contemporary Realism emerged as a counterbalance to Abstract Expressionism, representing subject matter in a straightforward way.",
    'Cubism': "Cubism revolutionized European painting by depicting subjects from multiple viewpoints simultaneously, creating a greater context of perception.",
    'Early Renaissance': "Early Renaissance art marks the transition from Medieval to Renaissance art, with increased realism and perspective. Notable artists include Donatello and Masaccio.",
    'Expressionism': "Expressionism distorts reality for emotional effect, presenting the world solely from a subjective perspective.",
    'Fauvism': "Fauvism is characterized by strong, vibrant colors and wild brushwork. Led by Henri Matisse and André Derain.",
    'High Renaissance': "The High Renaissance represents the pinnacle of Renaissance art, with perfect harmony and balance. Key figures include Leonardo da Vinci, Michelangelo, and Raphael.",
    'Impressionism': "Impressionism captures the momentary, sensory effect of a scene rather than exact details. Famous artists include Claude Monet and Pierre-Auguste Renoir.",
    'Mannerism Late Renaissance': "Mannerism exaggerates proportions and balance, with artificial qualities replacing naturalistic ones. Emerged after the High Renaissance.",
    'Minimalism': "Minimalism uses simple elements, focusing on objectivity and emphasizing the materials. Notable for its extreme simplicity and formal precision.",
    'Naive Art Primitivism': "Naive Art is characterized by simplicity, lack of perspective, and childlike execution. Often created by untrained artists.",
    'New Realism': "New Realism appropriates parts of reality, incorporating actual physical fragments of reality or objects as the artworks themselves.",
    'Northern Renaissance': "Northern Renaissance art is known for its precise details, symbolism, and advanced oil painting techniques. Key figures include Jan van Eyck and Albrecht Dürer.",
    'Pointillism': "Pointillism technique uses small, distinct dots of color applied in patterns to form an image. Developed by Georges Seurat and Paul Signac.",
    'Pop Art': "Pop Art uses imagery from popular culture like advertising and news. Famous artists include Andy Warhol and Roy Lichtenstein.",
    'Post Impressionism': "Post Impressionism extended Impressionism while rejecting its limitations. Key figures include Vincent van Gogh, Paul Cézanne, and Paul Gauguin.",
    'Realism': "Realism depicts subjects as they appear in everyday life, without embellishment or interpretation. Emerged in the mid-19th century.",
    'Rococo': "Rococo art is characterized by ornate decoration, pastel colors, and asymmetrical designs. Popular in the 18th century.",
    'Romanticism': "Romanticism emphasizes emotion, individualism, and glorification of nature and the past. Emerged in the late 18th century.",
    'Symbolism': "Symbolism uses symbolic imagery to express mystical ideas, emotions, and states of mind. Emerged in the late 19th century.",
    'Synthetic Cubism': "Synthetic Cubism is the second phase of Cubism, incorporating collage elements and a broader range of textures and colors.",
    'Ukiyo e': "Ukiyo-e are Japanese woodblock prints depicting landscapes, tales from history, and scenes from everyday life. Popular during the Edo period.",
}


def _normalize_style_name(name):
    """Drop spaces and underscores so both spellings of a style compare equal."""
    return name.replace(' ', '').replace('_', '')


# Descriptions indexed by normalized name, built once for O(1) lookups.
_STYLE_LOOKUP = {
    _normalize_style_name(name): text for name, text in _STYLE_DESCRIPTIONS.items()
}


def interpret_prediction(top_style):
    """Return a short description of an art style.

    Args:
        top_style: style name, written with spaces (``"Pop Art"``),
            underscores (``"Pop_Art"``) or neither; empty or ``None`` means
            no prediction is available.

    Returns:
        The description for the style, a prompt to upload an image when
        ``top_style`` is empty, or a "not available" message naming the
        style when it is unknown.
    """
    if not top_style:
        return "Please upload an image to analyze."

    description = _STYLE_LOOKUP.get(_normalize_style_name(top_style))
    if description is None:
        return f"Information about {top_style} is not available."
    return description
# Load the classifier at start-up. A failure is logged and leaves ``model``
# as None so the rest of the module can still be set up.
print("Loading model...")
try:
    model = load_model()
except Exception as load_error:
    print(f"Failed to load model: {load_error}")
    model = None
else:
    print("Model loaded successfully")
# Set up the Gradio interface
# Static markup and example paths are kept out of the layout code below.
_TITLE_HTML = """
<div style="text-align: center; margin-bottom: 1rem;">
<h1 style="font-size: 2.4rem; font-weight: 700; background: linear-gradient(90deg, #2563EB 0%, #4F46E5 100%); -webkit-background-clip: text; -webkit-text-fill-color: transparent;">Art Style Classifier</h1>
<p style="font-size: 1.3rem;">Upload any artwork to identify its artistic style using AI</p>
</div>
"""

_HOW_IT_WORKS_HTML = """
<div style="font-size: 1.1rem; line-height: 1.6; margin-top: 2rem;">
<h3 style="font-size: 1.4rem; color: #1e40af; margin-bottom: 0.8rem;">How It Works:</h3>
<p>This application uses a deep learning model (ResNet34) trained on a dataset of art from various periods and styles.
The model analyzes the visual characteristics of the uploaded image to identify its artistic style.</p>
<ul>
<li>The model was trained on over 50,000 paintings across 27 different artistic styles</li>
<li>It achieves approximately 74% accuracy in classifying art styles</li>
<li>Works best with complete paintings rather than details or cropped sections</li>
</ul>
</div>
"""

_EXAMPLE_IMAGES = [
    "examples/starry_night.jpg",
    "examples/mona_lisa.jpg",
    "examples/les_demoiselles.jpg",
    "examples/the_scream.jpg",
    "examples/impression_sunrise.jpg",
]


def _clear_outputs():
    """Blank both output components."""
    return None, None


# NOTE(review): the nesting below is reconstructed from the order of the
# components; confirm it matches the intended layout.
with gr.Blocks() as app:
    gr.HTML(_TITLE_HTML)

    with gr.Row():
        # Left column: upload, trigger, examples and the explanatory text.
        with gr.Column(scale=5):
            input_image = gr.Image(label="Upload Artwork", type="pil")
            analyze_btn = gr.Button("Analyze Artwork", variant="primary")
            examples = gr.Examples(
                examples=_EXAMPLE_IMAGES,
                inputs=input_image,
                label="Example Artworks",
                examples_per_page=5,
            )
            gr.HTML(_HOW_IT_WORKS_HTML)

        # Right column: results.
        with gr.Column(scale=5):
            prediction_output = gr.HTML(label="Prediction Results")
            style_info = gr.Markdown(label="Style Information")

    # Step 1 writes the top style name into ``style_info``; step 2 reads it
    # back and replaces it with that style's description.
    classify_step = analyze_btn.click(
        fn=classify_image,
        inputs=[input_image],
        outputs=[prediction_output, style_info],
    )
    classify_step.then(
        fn=interpret_prediction,
        inputs=[style_info],
        outputs=[style_info],
    )

    # Stale results are cleared whenever the input image changes.
    input_image.change(
        fn=_clear_outputs,
        inputs=[],
        outputs=[prediction_output, style_info],
    )
# Launch the application (starts the Gradio server for the ``app`` Blocks UI).
app.launch()