# NOTE: "Spaces / Sleeping" is the Hugging Face Spaces status banner that was
# scraped into this file along with the source; kept here as a comment so the
# file remains valid Python.
| import gradio as gr | |
| import torch | |
| import torch.nn.functional as F | |
| import timm | |
| from PIL import Image | |
| from torchvision import transforms | |
| import os | |
| import requests | |
| import time | |
| from typing import Dict, Tuple, Optional | |
| from io import BytesIO | |
# ----------------------------
# 1. Configuration
# ----------------------------

# timm hub id of the backbone the classifier head was trained on.
MODEL_ARCH = "animetimm/caformer_b36.dbv4-full"
CLASSES = ["Good", "Normal", "Bad"]  # Must match your training order
MODEL_PATH = "best_model.pth"
MODEL_URL = "https://huggingface.co/Shio-Koube/ConvNext-aesthetic-rater/resolve/main/best_model.pth"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Fixed square resolution every image is resized to before inference.
IMAGE_SIZE = 448
MIN_IMAGE_SIZE = 32
MAX_FILE_SIZE_MB = 10

# Sample images offered in the UI so users can try the model immediately.
EXAMPLE_IMAGE_URLS = [
    "https://cdn.donmai.us/sample/06/5d/__mayor_maybelle_eiyuu_densetsu_and_1_more_drawn_by_tinybiard__sample-065de2c8baabf8a25dacd181e14ce900.jpg",
    "https://cdn.donmai.us/sample/b6/d1/__nicole_demara_and_aria_zenless_zone_zero_drawn_by_orcavice__sample-b6d15d4892fb58e06943692736483799.jpg",
    "https://cdn.donmai.us/sample/0a/74/__tarte_and_macaron_kemono_teatime_drawn_by_mamesuzu__sample-0a7476e8f1672a9a95af9faf4a133326.jpg",
    "https://cdn.donmai.us/sample/86/ee/__quiche_kemono_teatime_drawn_by_ntny__sample-86ee7123cd43a19faaa628b71662bad8.jpg",
    "https://cdn.donmai.us/sample/11/43/__lance_crown_mashle_drawn_by_knata09660180__sample-11431fded8a2d4f966c1caa0ec5d515e.jpg",
    "https://cdn.donmai.us/sample/76/0a/__aeria_original_drawn_by_eudetenis__sample-760a0cb3a499c002acbbdd50db899722.jpg",
]
EXAMPLE_DIR = "example_images"
# Filled in by download_example_images() at startup with local file paths.
EXAMPLE_IMAGES = []
def download_example_images():
    """Download the images in EXAMPLE_IMAGE_URLS into EXAMPLE_DIR.

    Sleeps 1 second between downloads to avoid rate limiting, skips files
    that already exist, and appends each available local path to
    EXAMPLE_IMAGES. A failed download is logged and skipped so the app can
    still start with the remaining examples.
    """
    if not os.path.exists(EXAMPLE_DIR):
        os.makedirs(EXAMPLE_DIR)
    print("Downloading example images...")
    for i, url in enumerate(EXAMPLE_IMAGE_URLS, 1):
        filename = f"example_{i}.jpg"
        filepath = os.path.join(EXAMPLE_DIR, filename)
        # Skip if already downloaded
        if os.path.exists(filepath):
            print(f"Example {i}/{len(EXAMPLE_IMAGE_URLS)}: Already exists")
            EXAMPLE_IMAGES.append(filepath)
            continue
        try:
            # Sleep before downloading (except for the first one)
            if i > 1:
                print(f"Waiting 1 second before downloading example {i}...")
                time.sleep(1)
            print(f"Downloading example {i}/{len(EXAMPLE_IMAGE_URLS)}...")
            response = requests.get(url, timeout=30)
            response.raise_for_status()
            # Re-encode via PIL. Convert to RGB first: JPEG cannot store
            # alpha/palette modes, so saving an RGBA/P source as .jpg
            # would raise OSError and lose the example.
            img = Image.open(BytesIO(response.content))
            img.convert("RGB").save(filepath)
            EXAMPLE_IMAGES.append(filepath)
            print(f"✓ Example {i} downloaded")
        except Exception as e:
            print(f"✗ Failed to download example {i}: {e}")
            # Continue with other images even if one fails
    print(f"Downloaded {len(EXAMPLE_IMAGES)}/{len(EXAMPLE_IMAGE_URLS)} example images.")
# Normalization (same as training): PIL image -> float tensor scaled to [0,1],
# then standardized with ImageNet channel statistics.
normalize = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# ----------------------------
# 2. Resize to Fixed 448x448
# ----------------------------
def resize_to_448(image: Image.Image) -> Tuple[Image.Image, str]:
    """Resize *image* to exactly IMAGE_SIZE x IMAGE_SIZE pixels.

    Aspect ratio is NOT preserved: the image is stretched to the square
    resolution the model expects.

    Returns:
        Tuple of (resized_image, human-readable info string describing the
        size change).
    """
    src_w, src_h = image.size
    squared = image.resize((IMAGE_SIZE, IMAGE_SIZE), Image.BICUBIC)
    note = f"Original: {src_w}×{src_h} → Resized: {IMAGE_SIZE}×{IMAGE_SIZE}"
    return squared, note
def validate_image(image: Image.Image) -> None:
    """Raise ValueError when *image* is missing or below the minimum size."""
    if image is None:
        raise ValueError("No image provided")
    width, height = image.size
    if min(width, height) < MIN_IMAGE_SIZE:
        raise ValueError(f"Image too small (minimum {MIN_IMAGE_SIZE}×{MIN_IMAGE_SIZE})")
# ----------------------------
# 3. Setup Model
# ----------------------------
def download_model_weights():
    """Fetch MODEL_PATH from MODEL_URL (streaming) when not already on disk.

    Prints a progress percentage while downloading; on failure, removes any
    partially-written file so a broken download is not mistaken for a valid
    checkpoint on the next start.
    """
    if os.path.exists(MODEL_PATH):
        return  # weights already present; nothing to do
    print(f"Weights not found. Downloading from {MODEL_URL}...")
    try:
        response = requests.get(MODEL_URL, stream=True, timeout=30)
        response.raise_for_status()
        total_size = int(response.headers.get('content-length', 0))
        downloaded = 0
        with open(MODEL_PATH, "wb") as f:
            for chunk in response.iter_content(chunk_size=8192):
                f.write(chunk)
                downloaded += len(chunk)
                if total_size > 0:
                    progress = (downloaded / total_size) * 100
                    print(f"Download progress: {progress:.1f}%", end='\r')
        print("\nDownload complete.")
    except Exception as e:
        # Clean up the partial file so a retry starts fresh.
        if os.path.exists(MODEL_PATH):
            os.remove(MODEL_PATH)
        raise RuntimeError(f"Error downloading weights: {e}\n"
                           "Please ensure 'best_model.pth' is in the same folder.")
def load_model():
    """Create the timm classifier, load local weights, and return it in eval mode.

    Accepts either a raw state_dict checkpoint or a training checkpoint that
    wraps it under "model_state_dict". Also runs a warmup forward pass.

    Raises:
        RuntimeError: when model creation or weight loading fails; the
            original exception is chained (``from e``) so the root cause is
            visible in the traceback.
    """
    print(f"Loading {MODEL_ARCH} on {DEVICE}...")
    try:
        model = timm.create_model(
            "hf-hub:" + MODEL_ARCH,
            pretrained=False,
            num_classes=len(CLASSES),
        )
        # NOTE(security): weights_only=False unpickles arbitrary objects;
        # only load checkpoints from a trusted source (here: our own HF repo).
        checkpoint = torch.load(MODEL_PATH, map_location=DEVICE, weights_only=False)
        # Support both raw state_dict and training checkpoint format.
        # isinstance guard: a non-dict checkpoint would raise TypeError on
        # the `in` test instead of falling through to the raw branch.
        if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
            state_dict = checkpoint["model_state_dict"]
        else:
            state_dict = checkpoint
        model.load_state_dict(state_dict)
        model.to(DEVICE)
        model.eval()
        print("Model loaded successfully.")
        # Warmup so the first user-facing request is not slow.
        warmup_model(model)
        return model
    except Exception as e:
        raise RuntimeError(
            f"Failed to load model. Ensure 'timm' is installed and internet is "
            f"active for config download.\nError: {e}"
        ) from e
def warmup_model(model):
    """Run one dummy forward pass so the first real inference is not slow."""
    print("Warming up model...")
    dummy = torch.randn(1, 3, IMAGE_SIZE, IMAGE_SIZE).to(DEVICE)
    with torch.no_grad():
        model(dummy)
    # Release the allocator blocks the dummy batch grabbed.
    if DEVICE == "cuda":
        torch.cuda.empty_cache()
    print("Warmup complete.")
# Startup (import-time side effects): fetch example assets and weights,
# then build the model exactly once.
download_example_images()
download_model_weights()
model = load_model()
# ----------------------------
# 4. Inference Function
# ----------------------------
def classify_image(image: Optional[Image.Image]) -> Tuple[Dict[str, float], str, str]:
    """Run the quality classifier on *image*.

    Returns:
        Tuple of (predictions dict mapping class name -> probability,
        info message with resize/timing/device details,
        confidence assessment string — empty when confidence is middling).
    """
    try:
        if image is None:
            return {}, "❌ No image uploaded", ""
        validate_image(image)
        # Model expects 3-channel input.
        if image.mode != "RGB":
            image = image.convert("RGB")
        started = time.time()
        resized, resize_info = resize_to_448(image)
        # Normalize and add the batch dimension.
        batch = normalize(resized).unsqueeze(0).to(DEVICE)
        # bfloat16 autocast for faster inference on both CUDA and CPU.
        with torch.no_grad(), torch.autocast(device_type=DEVICE, dtype=torch.bfloat16):
            logits = model(batch)
            probs = F.softmax(logits, dim=-1)
        results = {label: probs[0, idx].item() for idx, label in enumerate(CLASSES)}
        elapsed_ms = (time.time() - started) * 1000  # preprocessing + forward pass, in ms
        info_message = "\n".join([
            f"✓ {resize_info}",
            f"⚡ Inference time: {elapsed_ms:.1f}ms",
            f"🖥️ Device: {DEVICE.upper()}",
        ])
        top_prob = max(results.values())
        if top_prob < 0.5:
            confidence_warning = "⚠️ Low confidence - results may be uncertain"
        elif top_prob > 0.9:
            confidence_warning = "✓ High confidence prediction"
        else:
            confidence_warning = ""
        # Cleanup GPU allocator between requests.
        if DEVICE == "cuda":
            torch.cuda.empty_cache()
        return results, info_message, confidence_warning
    except ValueError as e:
        return {}, f"❌ Validation Error: {str(e)}", ""
    except Exception as e:
        return {}, f"❌ Error during classification: {str(e)}", ""
# ----------------------------
# 5. Gradio Interface
# ----------------------------
# Build the Blocks UI once at import time; `iface` is launched from __main__.
with gr.Blocks(theme=gr.themes.Soft()) as iface:
    gr.Markdown(
        f"""
        # 🎨 Anime Quality Classifier
        ### Model: `{MODEL_ARCH}`
        Upload an anime image to classify its quality as **Good**, **Normal**, or **Bad**.
        - All images are resized to **{IMAGE_SIZE}×{IMAGE_SIZE}** pixels
        - Maximum file size: {MAX_FILE_SIZE_MB}MB
        """
    )
    with gr.Row():
        # Left column: input image plus action buttons.
        with gr.Column(scale=1):
            upload_box = gr.Image(type="pil", label="Upload Image", height=400)
            run_button = gr.Button("🔍 Classify Image", variant="primary", size="lg")
            clear_button = gr.ClearButton([upload_box], value="🗑️ Clear")
        # Right column: prediction outputs.
        with gr.Column(scale=1):
            predictions_box = gr.Label(
                num_top_classes=len(CLASSES),
                label="📊 Quality Predictions",
            )
            details_box = gr.Textbox(
                label="ℹ️ Processing Info",
                lines=3,
                interactive=False,
            )
            confidence_box = gr.Textbox(
                label="Confidence Assessment",
                lines=1,
                interactive=False,
            )
    gr.Markdown("### 🖼️ Try These Examples")
    # Only show the gallery when at least one example downloaded successfully.
    if EXAMPLE_IMAGES:
        gr.Examples(
            examples=[[path] for path in EXAMPLE_IMAGES],
            inputs=upload_box,
            outputs=[predictions_box, details_box, confidence_box],
            fn=classify_image,
            cache_examples=False,
            label="Sample Images",
        )
    else:
        gr.Markdown("*Example images failed to download. You can still upload your own images.*")
    gr.Markdown(
        """
        ---
        **Tips:**
        - Higher quality images generally get better predictions
        - The model works best on anime-style artwork
        - Try different art styles to see how the model responds
        """
    )
    # Both the explicit button click and a new upload feed the same pipeline.
    run_button.click(
        fn=classify_image,
        inputs=upload_box,
        outputs=[predictions_box, details_box, confidence_box],
    )
    upload_box.change(
        fn=classify_image,
        inputs=upload_box,
        outputs=[predictions_box, details_box, confidence_box],
    )
if __name__ == "__main__":
    # server_name="0.0.0.0" binds all interfaces so the app is reachable
    # from inside a container / HF Space.
    iface.launch(
        share=False,
        server_name="0.0.0.0",
        show_error=True,
    )