Spaces:
Running
Running
| import os | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torchvision import models, transforms | |
| from PIL import Image | |
| import gradio as gr | |
| import numpy as np | |
# Trained-weights file, resolved relative to the current working directory.
# Report at import time whether it is present; when it is missing, also print
# the working directory and its listing so a misplaced checkpoint is easy to spot.
MODEL_PATH = "model_final.pth"  # Model should be in root directory
if not os.path.exists(MODEL_PATH):
    print(f"Warning: Model not found at {MODEL_PATH}, current directory: {os.getcwd()}")
    print(f"Files in current directory: {os.listdir('.')}")
else:
    print(f"Model found at {MODEL_PATH}")
# Compute device chosen once at import time: CUDA when PyTorch reports an
# available GPU, otherwise the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
print(f"Using device: {DEVICE}")
# Class labels for the classifier: output logit i is reported as ART_STYLES[i],
# and the model's final layer has len(ART_STYLES) outputs. The list is kept in
# alphabetical order so indices line up with the order used at training time
# (presumably the sorted class-folder order of the training set — confirm
# before adding, removing, or reordering entries).
ART_STYLES = [
    "Abstract_Expressionism",
    "Action_painting",
    "Analytical_Cubism",
    "Art_Nouveau_Modern",
    "Baroque",
    "Color_Field_Painting",
    "Contemporary_Realism",
    "Cubism",
    "Early_Renaissance",
    "Expressionism",
    "Fauvism",
    "High_Renaissance",
    "Impressionism",
    "Mannerism_Late_Renaissance",
    "Minimalism",
    "Naive_Art_Primitivism",
    "New_Realism",
    "Northern_Renaissance",
    "Pointillism",
    "Pop_Art",
    "Post_Impressionism",
    "Realism",
    "Rococo",
    "Romanticism",
    "Symbolism",
    "Synthetic_Cubism",
    "Ukiyo_e",
]
# Image preprocessing
def preprocess_image(image):
    """Convert an image into the normalized batch tensor the classifier expects.

    Args:
        image: The artwork to classify. In this app it is a PIL image
            (classify_image converts NumPy arrays with Image.fromarray first).

    Returns:
        A tensor of shape (1, C, 224, 224). The shorter side is resized to 256,
        the image is center-cropped to 224x224, converted with ToTensor, and
        normalized per channel. For PIL input, C is 3 because the image is
        forced to RGB first.
    """
    # ToTensor keeps whatever channel count the image has. An RGBA, greyscale
    # ("L"), palette ("P") or CMYK upload would then fail the 3-channel
    # Normalize below with a broadcast error, so coerce to RGB first.
    # Images that are already RGB are untouched.
    if getattr(image, "mode", "RGB") != "RGB":
        image = image.convert("RGB")

    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        # Commonly used ImageNet channel statistics; presumably the same ones
        # used when model_final.pth was trained — confirm.
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    # Add a leading batch dimension of size 1 for the model.
    return transform(image).unsqueeze(0)
# Load model with error handling
def load_model():
    """Build the ResNet34 art-style classifier and load its trained weights.

    The final fully connected layer is replaced with one output per entry of
    ART_STYLES. The state dict is read from MODEL_PATH onto DEVICE.

    Returns:
        The model, moved to DEVICE and switched to eval mode.

    Raises:
        Any exception from model construction, torch.load or load_state_dict
        (for example a missing file or mismatched keys). The error is printed
        once and then re-raised unchanged.
    """
    try:
        # Architecture only; the trained weights are loaded from MODEL_PATH below.
        model = models.resnet34(weights=None)
        # One logit per art style. Reading in_features avoids hard-coding the
        # width of ResNet34's penultimate layer.
        model.fc = nn.Linear(model.fc.in_features, len(ART_STYLES))
        # NOTE(review): torch.load unpickles the file, so only load trusted
        # checkpoints. Consider weights_only=True if the installed torch
        # version supports it.
        state_dict = torch.load(MODEL_PATH, map_location=DEVICE)
        model.load_state_dict(state_dict)
        model = model.to(DEVICE)
        # Inference mode for layers such as batch norm.
        model.eval()
    except Exception as e:
        print(f"Error in model loading: {e}")
        raise
    return model
# Function to predict art style
def predict_art_style(image, model, top_k=5):
    """Classify an image and return its most likely art styles.

    Args:
        image: Input accepted by preprocess_image (a PIL image in this app).
        model: Classifier producing one logit per entry of ART_STYLES.
        top_k: Number of predictions to return. Defaults to 5 and is clamped
            to the range [1, number of model outputs].

    Returns:
        A list of (display_name, probability, is_top) tuples in order of
        decreasing probability. display_name is the ART_STYLES entry with
        underscores replaced by spaces, probability is a Python float from
        the softmax, and is_top is True only for the first entry. On any
        failure the error is printed and the placeholder
        [("Error in prediction", 1.0, True)] is returned instead of raising.
    """
    try:
        input_tensor = preprocess_image(image).to(DEVICE)
        with torch.no_grad():
            logits = model(input_tensor)
            # Batch of one: keep the single row of class probabilities.
            probabilities = F.softmax(logits, dim=1)[0]
            # torch.topk raises if k exceeds the number of elements.
            k = max(1, min(top_k, probabilities.numel()))
            top_probs, top_indices = torch.topk(probabilities, k)

        ranked = zip(top_probs.cpu().tolist(), top_indices.cpu().tolist())
        return [
            (ART_STYLES[idx].replace('_', ' '), float(prob), rank == 0)
            for rank, (prob, idx) in enumerate(ranked)
        ]
    except Exception as e:
        print(f"Error in prediction: {e}")
        return [("Error in prediction", 1.0, True)]
# Main prediction function for Gradio
def classify_image(image):
    """Run the classifier on an uploaded image and render the results as HTML.

    Args:
        image: The uploaded artwork as a PIL image or NumPy array, or None when
            nothing has been uploaded.

    Returns:
        A (result_html, top_style) tuple.
        - On success: an HTML card with one labelled probability bar per
          prediction, plus the display name of the top prediction.
        - When no image is given: a plain prompt string and "".
        - When the model is not loaded or an error occurs: an error <div>
          and "".
    """
    from html import escape  # local import: this module's top-level imports are not edited here

    if image is None:
        return "Please upload an image to analyze.", ""
    # model is the module-level global, set to None when loading failed.
    # Without this guard the failure would be rendered as a fake
    # "Error in prediction" bar at 100%.
    if model is None:
        return "<div style='color: red;'>Error processing image: model is not loaded.</div>", ""
    try:
        # Image.fromarray keeps the array's channel order as-is (no BGR->RGB swap).
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)

        predictions = predict_art_style(image, model)

        parts = [
            "<div style='font-size: 1.2rem; background-color: #f0f9ff; padding: 1rem; border-radius: 8px; box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.1);'>",
            f"<h3 style='margin-bottom: 15px; color: #1e40af;'>Top {len(predictions)} Predicted Art Styles:</h3>",
        ]
        for rank, (style, prob, _) in enumerate(predictions):
            percentage = prob * 100
            # Highlight the top prediction with a bold label and a darker bar.
            weight = 'bold' if rank == 0 else 'normal'
            bar_color = "#3b82f6" if rank == 0 else "#93c5fd"
            parts.extend([
                "<div style='margin-bottom: 10px;'>",
                "<div style='display: flex; align-items: center; margin-bottom: 5px;'>",
                f"<span style='font-weight: {weight}; width: 200px; font-size: 1.1rem;'>{escape(style)}</span>",
                f"<span style='margin-left: 10px; font-weight: {weight}; width: 60px; text-align: right;'>{percentage:.1f}%</span>",
                "</div>",
                "<div style='height: 10px; width: 100%; background-color: #e5e7eb; border-radius: 5px;'>",
                f"<div style='height: 100%; width: {percentage}%; background-color: {bar_color}; border-radius: 5px;'></div>",
                "</div>",
                "</div>",
            ])
        parts.append("</div>")

        # The top prediction's display name drives the style-information panel.
        top_style = predictions[0][0]
        return "".join(parts), top_style
    except Exception as e:
        print(f"Error in classify_image: {e}")
        # Exception text may contain '<' or '&'; escape it before embedding it in HTML.
        return f"<div style='color: red;'>Error processing image: {escape(str(e))}</div>", ""
# Short description of each art style, keyed by the display form of the
# ART_STYLES names (underscores replaced by spaces). Built once at import time
# rather than on every call.
_STYLE_DESCRIPTIONS = {
    'Abstract Expressionism': "Abstract Expressionism is characterized by gestural brush-strokes or mark-making, and the impression of spontaneity. Key artists include Jackson Pollock and Willem de Kooning.",
    'Action painting': "Action Painting, a subset of Abstract Expressionism, emphasizes the physical act of painting itself. The canvas was seen as an arena in which to act.",
    'Analytical Cubism': "Analytical Cubism is characterized by geometric shapes, fragmented forms, and a monochromatic palette. Pioneered by Pablo Picasso and Georges Braque.",
    'Art Nouveau Modern': "Art Nouveau features highly stylized, flowing curvilinear designs, often incorporating floral and other plant-inspired motifs.",
    'Baroque': "Baroque art is characterized by drama, rich color, and intense light and shadow. Notable for its grandeur and ornate details.",
    'Color Field Painting': "Color Field Painting is characterized by large areas of a more or less flat single color. Key artists include Mark Rothko and Clyfford Still.",
    'Contemporary Realism': "Contemporary Realism emerged as a counterbalance to Abstract Expressionism, representing subject matter in a straightforward way.",
    'Cubism': "Cubism revolutionized European painting by depicting subjects from multiple viewpoints simultaneously, creating a greater context of perception.",
    'Early Renaissance': "Early Renaissance art marks the transition from Medieval to Renaissance art, with increased realism and perspective. Notable artists include Donatello and Masaccio.",
    'Expressionism': "Expressionism distorts reality for emotional effect, presenting the world solely from a subjective perspective.",
    'Fauvism': "Fauvism is characterized by strong, vibrant colors and wild brushwork. Led by Henri Matisse and André Derain.",
    'High Renaissance': "The High Renaissance represents the pinnacle of Renaissance art, with perfect harmony and balance. Key figures include Leonardo da Vinci, Michelangelo, and Raphael.",
    'Impressionism': "Impressionism captures the momentary, sensory effect of a scene rather than exact details. Famous artists include Claude Monet and Pierre-Auguste Renoir.",
    'Mannerism Late Renaissance': "Mannerism exaggerates proportions and balance, with artificial qualities replacing naturalistic ones. Emerged after the High Renaissance.",
    'Minimalism': "Minimalism uses simple elements, focusing on objectivity and emphasizing the materials. Notable for its extreme simplicity and formal precision.",
    'Naive Art Primitivism': "Naive Art is characterized by simplicity, lack of perspective, and childlike execution. Often created by untrained artists.",
    'New Realism': "New Realism appropriates parts of reality, incorporating actual physical fragments of reality or objects as the artworks themselves.",
    'Northern Renaissance': "Northern Renaissance art is known for its precise details, symbolism, and advanced oil painting techniques. Key figures include Jan van Eyck and Albrecht Dürer.",
    'Pointillism': "Pointillism technique uses small, distinct dots of color applied in patterns to form an image. Developed by Georges Seurat and Paul Signac.",
    'Pop Art': "Pop Art uses imagery from popular culture like advertising and news. Famous artists include Andy Warhol and Roy Lichtenstein.",
    'Post Impressionism': "Post Impressionism extended Impressionism while rejecting its limitations. Key figures include Vincent van Gogh, Paul Cézanne, and Paul Gauguin.",
    'Realism': "Realism depicts subjects as they appear in everyday life, without embellishment or interpretation. Emerged in the mid-19th century.",
    'Rococo': "Rococo art is characterized by ornate decoration, pastel colors, and asymmetrical designs. Popular in the 18th century.",
    'Romanticism': "Romanticism emphasizes emotion, individualism, and glorification of nature and the past. Emerged in the late 18th century.",
    'Symbolism': "Symbolism uses symbolic imagery to express mystical ideas, emotions, and states of mind. Emerged in the late 19th century.",
    'Synthetic Cubism': "Synthetic Cubism is the second phase of Cubism, incorporating collage elements and a broader range of textures and colors.",
    'Ukiyo e': "Ukiyo-e are Japanese woodblock prints depicting landscapes, tales from history, and scenes from everyday life. Popular during the Edo period.",
}


def _normalize_style_name(name):
    """Reduce a style name to case-folded alphanumerics ('Pop_Art', 'pop art' -> 'popart')."""
    return "".join(ch for ch in name.casefold() if ch.isalnum())


# Normalized name -> description, so each lookup is a single dict access.
_DESCRIPTIONS_BY_NORMALIZED_NAME = {
    _normalize_style_name(name): text for name, text in _STYLE_DESCRIPTIONS.items()
}


# Interpretation function that adds information about the style
def interpret_prediction(top_style):
    """Return a short description of the predicted art style.

    Matching ignores case, spaces, underscores and punctuation. For example,
    "Pop Art", "Pop_Art" and "pop art" all match, and "Ukiyo-e" matches
    "Ukiyo e".

    Args:
        top_style: Display name of the top prediction, or an empty/None value
            when there is no prediction.

    Returns:
        - The style's description when the name matches a known style.
        - An upload prompt when top_style is empty or None.
        - Otherwise, a "not available" message naming the style.
    """
    if not top_style:
        return "Please upload an image to analyze."
    description = _DESCRIPTIONS_BY_NORMALIZED_NAME.get(_normalize_style_name(top_style))
    if description is None:
        return f"Information about {top_style} is not available."
    return description
# Load the classifier once at import time. If loading fails for any reason,
# report it and continue with model = None instead of aborting the script.
print("Loading model...")
try:
    model = load_model()
except Exception as e:
    print(f"Failed to load model: {e}")
    model = None
else:
    print("Model loaded successfully")
# Set up the Gradio interface: upload/examples/explanation on the left,
# prediction bars and style description on the right.
with gr.Blocks() as app:
    gr.HTML("""
    <div style="text-align: center; margin-bottom: 1rem;">
        <h1 style="font-size: 2.4rem; font-weight: 700; background: linear-gradient(90deg, #2563EB 0%, #4F46E5 100%); -webkit-background-clip: text; -webkit-text-fill-color: transparent;">Art Style Classifier</h1>
        <p style="font-size: 1.3rem;">Upload any artwork to identify its artistic style using AI</p>
    </div>
    """)
    with gr.Row():
        with gr.Column(scale=5):
            # Image input
            input_image = gr.Image(label="Upload Artwork", type="pil")
            # Analyze button
            analyze_btn = gr.Button("Analyze Artwork", variant="primary")
            # Example images
            examples = gr.Examples(
                examples=[
                    "examples/starry_night.jpg",
                    "examples/mona_lisa.jpg",
                    "examples/les_demoiselles.jpg",
                    "examples/the_scream.jpg",
                    "examples/impression_sunrise.jpg"
                ],
                inputs=input_image,
                label="Example Artworks",
                examples_per_page=5
            )
            # "How it works" section
            gr.HTML("""
            <div style="font-size: 1.1rem; line-height: 1.6; margin-top: 2rem;">
                <h3 style="font-size: 1.4rem; color: #1e40af; margin-bottom: 0.8rem;">How It Works:</h3>
                <p>This application uses a deep learning model (ResNet34) trained on a dataset of art from various periods and styles.
                The model analyzes the visual characteristics of the uploaded image to identify its artistic style.</p>
                <ul>
                    <li>The model was trained on over 50,000 paintings across 27 different artistic styles</li>
                    <li>It achieves approximately 74% accuracy in classifying art styles</li>
                    <li>Works best with complete paintings rather than details or cropped sections</li>
                </ul>
            </div>
            """)
        with gr.Column(scale=5):
            # Outputs
            prediction_output = gr.HTML(label="Prediction Results")
            style_info = gr.Markdown(label="Style Information")

    # Hidden holder for the top predicted style name. Passing it through state
    # rather than through the visible style_info Markdown keeps the raw name
    # from flashing in the info panel before its description replaces it.
    top_style_state = gr.State("")

    # Set up the prediction flow: classify first, then look up the description
    # of the top style.
    analyze_btn.click(
        fn=classify_image,
        inputs=[input_image],
        outputs=[prediction_output, top_style_state],
    ).then(
        fn=interpret_prediction,
        inputs=[top_style_state],
        outputs=[style_info]
    )
    # Clear stale results whenever the input image changes.
    input_image.change(
        fn=lambda: (None, None),
        inputs=[],
        outputs=[prediction_output, style_info]
    )

# Launch the application
app.launch()