# Author: ErdemAtak
# Last commit: c8e5239 "Update app.py" (verified)
import html
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models, transforms
from PIL import Image
import gradio as gr
import numpy as np
# Path to the fine-tuned weights; the file is expected in the root directory
# (i.e. relative to the current working directory). Report up front whether it
# is present so a missing checkpoint is easy to diagnose from the logs.
MODEL_PATH = "model_final.pth"
if not os.path.exists(MODEL_PATH):
    print(f"Warning: Model not found at {MODEL_PATH}, current directory: {os.getcwd()}")
    print(f"Files in current directory: {os.listdir('.')}")
else:
    print(f"Model found at {MODEL_PATH}")
# Device configuration: run on the GPU when CUDA is available, otherwise on the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
print(f"Using device: {DEVICE}")
# Class labels, indexed in the same order as the model's output logits.
# The list is kept in sorted order for class-index consistency with training
# (presumably an alphabetical, ImageFolder-style class ordering — confirm
# against the training script before reordering or adding entries).
ART_STYLES = [
    'Abstract_Expressionism',
    'Action_painting',
    'Analytical_Cubism',
    'Art_Nouveau_Modern',
    'Baroque',
    'Color_Field_Painting',
    'Contemporary_Realism',
    'Cubism',
    'Early_Renaissance',
    'Expressionism',
    'Fauvism',
    'High_Renaissance',
    'Impressionism',
    'Mannerism_Late_Renaissance',
    'Minimalism',
    'Naive_Art_Primitivism',
    'New_Realism',
    'Northern_Renaissance',
    'Pointillism',
    'Pop_Art',
    'Post_Impressionism',
    'Realism',
    'Rococo',
    'Romanticism',
    'Symbolism',
    'Synthetic_Cubism',
    'Ukiyo_e',
]
# Image preprocessing
def preprocess_image(image):
    """Turn a PIL image into a normalized ``(1, 3, 224, 224)`` float tensor.

    The pipeline resizes the shorter side to 256 px, center-crops 224x224,
    converts to a tensor and normalizes with the standard ImageNet mean/std.

    Images that are not in RGB mode (e.g. RGBA PNGs, grayscale or palette
    images) are converted to RGB first. Otherwise ``ToTensor`` yields 4 or 1
    channels, which neither the 3-channel ``Normalize`` nor the ResNet input
    layer can handle.

    Args:
        image: A ``PIL.Image.Image``.

    Returns:
        A tensor with a leading batch dimension of 1.
    """
    if image.mode != "RGB":
        image = image.convert("RGB")
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    # Add the batch dimension expected by the model.
    return transform(image).unsqueeze(0)
# Load model with error handling
def load_model():
    """Build a ResNet34 classifier and load the fine-tuned weights from MODEL_PATH.

    The backbone is created without pretrained weights, and its final
    fully-connected layer is replaced so that it emits one logit per entry in
    ART_STYLES. MODEL_PATH must hold a ``state_dict`` for exactly this
    architecture.

    Returns:
        The model, moved to DEVICE and switched to eval mode.

    Raises:
        Whatever ``torch.load`` (e.g. a missing file) or ``load_state_dict``
        (e.g. mismatched keys or shapes) raises. The error is printed first and
        then re-raised so the caller can decide how to degrade.
    """
    model = models.resnet34(weights=None)
    # Size the new head from the backbone instead of hard-coding 512.
    model.fc = nn.Linear(model.fc.in_features, len(ART_STYLES))
    try:
        # NOTE: torch.load unpickles the file; only load checkpoints you trust.
        state_dict = torch.load(MODEL_PATH, map_location=DEVICE)
        model.load_state_dict(state_dict)
    except Exception as e:
        print(f"Error loading model state dict: {e}")
        raise
    model = model.to(DEVICE)
    model.eval()
    return model
# Function to predict art style
def predict_art_style(image, model, top_k=5):
    """Classify *image* with *model* and return the most likely art styles.

    Args:
        image: A PIL image accepted by ``preprocess_image``.
        model: The classifier to run, expected to be the one built by
            ``load_model`` (already on DEVICE and in eval mode).
        top_k: Maximum number of predictions to return (default 5). It is
            clamped to the number of model outputs, so ``torch.topk`` cannot
            fail when the model has fewer classes than requested.

    Returns:
        A list of ``(display_style, probability, is_top)`` tuples ordered from
        most to least likely. ``display_style`` is the ART_STYLES label with
        underscores replaced by spaces, and ``is_top`` is True only for the
        first entry. On any failure the error is printed and the placeholder
        ``[("Error in prediction", 1.0, True)]`` is returned instead of raising.
    """
    try:
        input_tensor = preprocess_image(image).to(DEVICE)
        with torch.no_grad():
            logits = model(input_tensor)
        # Softmax over the class dimension; [0] drops the batch of one.
        probabilities = F.softmax(logits, dim=1)[0]
        k = min(top_k, probabilities.numel())
        top_probs, top_indices = torch.topk(probabilities, k)
        return [
            (ART_STYLES[idx].replace('_', ' '), float(prob), rank == 0)
            for rank, (prob, idx) in enumerate(zip(top_probs.tolist(), top_indices.tolist()))
        ]
    except Exception as e:
        print(f"Error in prediction: {e}")
        return [("Error in prediction", 1.0, True)]
# Main prediction function for Gradio
def _render_predictions_html(predictions):
    """Build the results card: one labelled probability bar per prediction tuple."""
    parts = [
        "<div style='font-size: 1.2rem; background-color: #f0f9ff; padding: 1rem; border-radius: 8px; box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.1);'>",
        "<h3 style='margin-bottom: 15px; color: #1e40af;'>Top 5 Predicted Art Styles:</h3>",
    ]
    for rank, (style, prob, _) in enumerate(predictions):
        percentage = prob * 100
        is_top = rank == 0
        weight = 'bold' if is_top else 'normal'
        bar_color = "#3b82f6" if is_top else "#93c5fd"
        parts.append(
            "<div style='margin-bottom: 10px;'>"
            "<div style='display: flex; align-items: center; margin-bottom: 5px;'>"
            f"<span style='font-weight: {weight}; width: 200px; font-size: 1.1rem;'>{html.escape(style)}</span>"
            f"<span style='margin-left: 10px; font-weight: {weight}; width: 60px; text-align: right;'>{percentage:.1f}%</span>"
            "</div>"
            "<div style='height: 10px; width: 100%; background-color: #e5e7eb; border-radius: 5px;'>"
            f"<div style='height: 100%; width: {percentage}%; background-color: {bar_color}; border-radius: 5px;'></div>"
            "</div>"
            "</div>"
        )
    parts.append("</div>")
    return "".join(parts)


def classify_image(image):
    """Classify an uploaded artwork and render the predictions for the UI.

    Args:
        image: The uploaded artwork as a PIL image or a NumPy array (converted
            with ``Image.fromarray``), or ``None`` when nothing is uploaded.

    Returns:
        A ``(result_html, top_style)`` tuple. ``result_html`` is the HTML card
        of predictions, or a prompt/error message. ``top_style`` is the display
        name of the highest-ranked style, or ``""`` when there is no prediction.
    """
    if image is None:
        return "Please upload an image to analyze.", ""
    # ``model`` is the module-level classifier; it is None when loading failed
    # at start-up. Report that clearly instead of rendering a fake prediction.
    if model is None:
        return "<div style='color: red;'>Error processing image: the model is not loaded.</div>", ""
    try:
        # The preprocessing pipeline expects a PIL image; no channel reordering is done.
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        predictions = predict_art_style(image, model)
        result_html = _render_predictions_html(predictions)
        # The first prediction is the most likely style; it feeds the style-info step.
        top_style = predictions[0][0]
        return result_html, top_style
    except Exception as e:
        print(f"Error in classify_image: {e}")
        # Escape the message: exception text may contain '<' or '&'.
        return f"<div style='color: red;'>Error processing image: {html.escape(str(e))}</div>", ""
# Interpretation function that adds information about the style
# Short descriptions keyed by display name (ART_STYLES labels with underscores
# replaced by spaces). Built once at import time instead of on every call.
_STYLE_DESCRIPTIONS = {
    'Abstract Expressionism': "Abstract Expressionism is characterized by gestural brush-strokes or mark-making, and the impression of spontaneity. Key artists include Jackson Pollock and Willem de Kooning.",
    'Action painting': "Action Painting, a subset of Abstract Expressionism, emphasizes the physical act of painting itself. The canvas was seen as an arena in which to act.",
    'Analytical Cubism': "Analytical Cubism is characterized by geometric shapes, fragmented forms, and a monochromatic palette. Pioneered by Pablo Picasso and Georges Braque.",
    'Art Nouveau Modern': "Art Nouveau features highly stylized, flowing curvilinear designs, often incorporating floral and other plant-inspired motifs.",
    'Baroque': "Baroque art is characterized by drama, rich color, and intense light and shadow. Notable for its grandeur and ornate details.",
    'Color Field Painting': "Color Field Painting is characterized by large areas of a more or less flat single color. Key artists include Mark Rothko and Clyfford Still.",
    'Contemporary Realism': "Contemporary Realism emerged as a counterbalance to Abstract Expressionism, representing subject matter in a straightforward way.",
    'Cubism': "Cubism revolutionized European painting by depicting subjects from multiple viewpoints simultaneously, creating a greater context of perception.",
    'Early Renaissance': "Early Renaissance art marks the transition from Medieval to Renaissance art, with increased realism and perspective. Notable artists include Donatello and Masaccio.",
    'Expressionism': "Expressionism distorts reality for emotional effect, presenting the world solely from a subjective perspective.",
    'Fauvism': "Fauvism is characterized by strong, vibrant colors and wild brushwork. Led by Henri Matisse and André Derain.",
    'High Renaissance': "The High Renaissance represents the pinnacle of Renaissance art, with perfect harmony and balance. Key figures include Leonardo da Vinci, Michelangelo, and Raphael.",
    'Impressionism': "Impressionism captures the momentary, sensory effect of a scene rather than exact details. Famous artists include Claude Monet and Pierre-Auguste Renoir.",
    'Mannerism Late Renaissance': "Mannerism exaggerates proportions and balance, with artificial qualities replacing naturalistic ones. Emerged after the High Renaissance.",
    'Minimalism': "Minimalism uses simple elements, focusing on objectivity and emphasizing the materials. Notable for its extreme simplicity and formal precision.",
    'Naive Art Primitivism': "Naive Art is characterized by simplicity, lack of perspective, and childlike execution. Often created by untrained artists.",
    'New Realism': "New Realism appropriates parts of reality, incorporating actual physical fragments of reality or objects as the artworks themselves.",
    'Northern Renaissance': "Northern Renaissance art is known for its precise details, symbolism, and advanced oil painting techniques. Key figures include Jan van Eyck and Albrecht Dürer.",
    'Pointillism': "Pointillism technique uses small, distinct dots of color applied in patterns to form an image. Developed by Georges Seurat and Paul Signac.",
    'Pop Art': "Pop Art uses imagery from popular culture like advertising and news. Famous artists include Andy Warhol and Roy Lichtenstein.",
    'Post Impressionism': "Post Impressionism extended Impressionism while rejecting its limitations. Key figures include Vincent van Gogh, Paul Cézanne, and Paul Gauguin.",
    'Realism': "Realism depicts subjects as they appear in everyday life, without embellishment or interpretation. Emerged in the mid-19th century.",
    'Rococo': "Rococo art is characterized by ornate decoration, pastel colors, and asymmetrical designs. Popular in the 18th century.",
    'Romanticism': "Romanticism emphasizes emotion, individualism, and glorification of nature and the past. Emerged in the late 18th century.",
    'Symbolism': "Symbolism uses symbolic imagery to express mystical ideas, emotions, and states of mind. Emerged in the late 19th century.",
    'Synthetic Cubism': "Synthetic Cubism is the second phase of Cubism, incorporating collage elements and a broader range of textures and colors.",
    'Ukiyo e': "Ukiyo-e are Japanese woodblock prints depicting landscapes, tales from history, and scenes from everyday life. Popular during the Edo period."
}


def _normalize_style_name(name):
    """Return a lookup key that ignores spaces, underscores and letter case."""
    return name.replace(' ', '').replace('_', '').casefold()


# Normalized name -> description, for O(1) lookups from either label form.
_STYLE_LOOKUP = {_normalize_style_name(k): v for k, v in _STYLE_DESCRIPTIONS.items()}


def interpret_prediction(top_style):
    """Return a short description of the predicted art style.

    Args:
        top_style: A style name in display form ("Pop Art") or label form
            ("Pop_Art"); spaces, underscores and case are ignored when matching.
            An empty value means no prediction is available.

    Returns:
        The style description, an upload prompt when *top_style* is empty, or
        a "not available" message for unknown styles.
    """
    if not top_style:
        return "Please upload an image to analyze."
    description = _STYLE_LOOKUP.get(_normalize_style_name(top_style))
    if description is None:
        return f"Information about {top_style} is not available."
    return description
# Load the classifier once at start-up. If loading fails, the failure is
# printed and ``model`` is set to None so the UI can still be built.
print("Loading model...")
try:
    model = load_model()
except Exception as load_error:
    print(f"Failed to load model: {load_error}")
    model = None
else:
    print("Model loaded successfully")
# Set up the Gradio interface
# Example artworks offered under the upload box. Paths are relative to the
# current working directory; files that are missing are skipped (with a
# warning, like the MODEL_PATH check) so an absent example image cannot
# prevent the app from starting.
_EXAMPLE_IMAGES = [
    "examples/starry_night.jpg",
    "examples/mona_lisa.jpg",
    "examples/les_demoiselles.jpg",
    "examples/the_scream.jpg",
    "examples/impression_sunrise.jpg",
]
_available_examples = [path for path in _EXAMPLE_IMAGES if os.path.exists(path)]
if len(_available_examples) < len(_EXAMPLE_IMAGES):
    _missing_examples = [path for path in _EXAMPLE_IMAGES if path not in _available_examples]
    print(f"Warning: example images not found: {_missing_examples}")

with gr.Blocks() as app:
    # Page header
    gr.HTML("""
<div style="text-align: center; margin-bottom: 1rem;">
<h1 style="font-size: 2.4rem; font-weight: 700; background: linear-gradient(90deg, #2563EB 0%, #4F46E5 100%); -webkit-background-clip: text; -webkit-text-fill-color: transparent;">Art Style Classifier</h1>
<p style="font-size: 1.3rem;">Upload any artwork to identify its artistic style using AI</p>
</div>
""")
    with gr.Row():
        with gr.Column(scale=5):
            # Image input
            input_image = gr.Image(label="Upload Artwork", type="pil")
            # Analyze button
            analyze_btn = gr.Button("Analyze Artwork", variant="primary")
            # Example images (only those present on disk)
            examples = (
                gr.Examples(
                    examples=_available_examples,
                    inputs=input_image,
                    label="Example Artworks",
                    examples_per_page=5,
                )
                if _available_examples
                else None
            )
            # "How it works" section
            gr.HTML("""
<div style="font-size: 1.1rem; line-height: 1.6; margin-top: 2rem;">
<h3 style="font-size: 1.4rem; color: #1e40af; margin-bottom: 0.8rem;">How It Works:</h3>
<p>This application uses a deep learning model (ResNet34) trained on a dataset of art from various periods and styles.
The model analyzes the visual characteristics of the uploaded image to identify its artistic style.</p>
<ul>
<li>The model was trained on over 50,000 paintings across 27 different artistic styles</li>
<li>It achieves approximately 74% accuracy in classifying art styles</li>
<li>Works best with complete paintings rather than details or cropped sections</li>
</ul>
</div>
""")
        with gr.Column(scale=5):
            # Outputs
            prediction_output = gr.HTML(label="Prediction Results")
            style_info = gr.Markdown(label="Style Information")
    # Prediction flow: classify_image renders the results and writes the top
    # style name into the Markdown component; interpret_prediction then reads
    # that name back and replaces it with the style description.
    analyze_btn.click(
        fn=classify_image,
        inputs=[input_image],
        outputs=[prediction_output, style_info],
    ).then(
        fn=interpret_prediction,
        inputs=[style_info],
        outputs=[style_info],
    )
    # Clear stale results whenever the input image changes.
    input_image.change(
        fn=lambda: (None, None),
        inputs=[],
        outputs=[prediction_output, style_info],
    )
# Launch the application only when run as a script (e.g. ``python app.py``),
# so importing this module does not start a web server as a side effect.
if __name__ == "__main__":
    app.launch()