# Demo-Anime / app.py
# Author: Shio-Koube — "Update app.py" (commit 91b1e5f, verified)
import gradio as gr
import torch
import torch.nn.functional as F
import timm
from PIL import Image
from torchvision import transforms
import os
import requests
import time
from typing import Dict, Tuple, Optional
from io import BytesIO
# ----------------------------
# 1. Configuration
# ----------------------------
# Hugging Face hub id of the timm architecture (config only; weights come
# from MODEL_URL below).
MODEL_ARCH = "animetimm/caformer_b36.dbv4-full"
CLASSES = ["Good", "Normal", "Bad"] # Must match your training order
# Local checkpoint path and where to fetch it from when missing.
MODEL_PATH = "best_model.pth"
MODEL_URL = "https://huggingface.co/Shio-Koube/ConvNext-aesthetic-rater/resolve/main/best_model.pth"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Fixed image size for inference
IMAGE_SIZE = 448
# Reject uploads smaller than this on either side.
MIN_IMAGE_SIZE = 32
# NOTE(review): only shown in the UI text below — not actually enforced
# anywhere in this file; confirm whether enforcement is intended.
MAX_FILE_SIZE_MB = 10
# Example images for users to try
EXAMPLE_IMAGE_URLS = [
    "https://cdn.donmai.us/sample/06/5d/__mayor_maybelle_eiyuu_densetsu_and_1_more_drawn_by_tinybiard__sample-065de2c8baabf8a25dacd181e14ce900.jpg",
    "https://cdn.donmai.us/sample/b6/d1/__nicole_demara_and_aria_zenless_zone_zero_drawn_by_orcavice__sample-b6d15d4892fb58e06943692736483799.jpg",
    "https://cdn.donmai.us/sample/0a/74/__tarte_and_macaron_kemono_teatime_drawn_by_mamesuzu__sample-0a7476e8f1672a9a95af9faf4a133326.jpg",
    "https://cdn.donmai.us/sample/86/ee/__quiche_kemono_teatime_drawn_by_ntny__sample-86ee7123cd43a19faaa628b71662bad8.jpg",
    "https://cdn.donmai.us/sample/11/43/__lance_crown_mashle_drawn_by_knata09660180__sample-11431fded8a2d4f966c1caa0ec5d515e.jpg",
    "https://cdn.donmai.us/sample/76/0a/__aeria_original_drawn_by_eudetenis__sample-760a0cb3a499c002acbbdd50db899722.jpg",
]
# Download destination for examples; EXAMPLE_IMAGES is filled in at startup
# by download_example_images() with the paths that actually exist.
EXAMPLE_DIR = "example_images"
EXAMPLE_IMAGES = []
def download_example_images():
    """Fetch the images in EXAMPLE_IMAGE_URLS into EXAMPLE_DIR.

    Paths of successfully fetched (or already-present) files are appended to
    the module-level EXAMPLE_IMAGES list, which the Gradio UI uses for its
    examples gallery. A 1-second pause between downloads avoids CDN rate
    limiting. Failures are logged and skipped so the app still starts with a
    partial (or empty) example set.
    """
    global EXAMPLE_IMAGES
    if not os.path.exists(EXAMPLE_DIR):
        os.makedirs(EXAMPLE_DIR)
    print("Downloading example images...")
    for i, url in enumerate(EXAMPLE_IMAGE_URLS, 1):
        filepath = os.path.join(EXAMPLE_DIR, f"example_{i}.jpg")
        # Skip files that survived a previous run
        if os.path.exists(filepath):
            print(f"Example {i}/{len(EXAMPLE_IMAGE_URLS)}: Already exists")
            EXAMPLE_IMAGES.append(filepath)
            continue
        try:
            # Sleep before downloading (except for the first one)
            if i > 1:
                print(f"Waiting 1 second before downloading example {i}...")
                time.sleep(1)
            print(f"Downloading example {i}/{len(EXAMPLE_IMAGE_URLS)}...")
            response = requests.get(url, timeout=30)
            response.raise_for_status()
            # Validate that the payload really is an image, then write the
            # raw bytes untouched. (Previously the image was round-tripped
            # through PIL's save(), which re-encodes the JPEG and degrades
            # quality for no benefit.)
            Image.open(BytesIO(response.content)).verify()
            with open(filepath, "wb") as f:
                f.write(response.content)
            EXAMPLE_IMAGES.append(filepath)
            print(f"✓ Example {i} downloaded")
        except Exception as e:
            # Best-effort: continue with other images even if one fails
            print(f"✗ Failed to download example {i}: {e}")
    print(f"Downloaded {len(EXAMPLE_IMAGES)}/{len(EXAMPLE_IMAGE_URLS)} example images.")
# Normalization (same as training)
# ToTensor scales pixels to [0, 1]; Normalize then applies the standard
# ImageNet channel statistics (the values below are the canonical ImageNet
# mean/std used by timm pretrained models).
normalize = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])
# ----------------------------
# 2. Resize to Fixed 448x448
# ----------------------------
def resize_to_448(image: Image.Image) -> Tuple[Image.Image, str]:
    """
    Resize image to exactly IMAGE_SIZE x IMAGE_SIZE (448x448) pixels.

    Aspect ratio is intentionally NOT preserved — the model expects square
    inputs, so the image is stretched to fit.

    Returns:
        Tuple of (resized_image, info_string) where info_string describes
        the original → resized dimensions for display in the UI.
    """
    original_w, original_h = image.size
    # Image.BICUBIC was removed in Pillow 10 (deprecated since 9.1);
    # Image.Resampling.BICUBIC is the modern spelling. getattr keeps
    # compatibility with pre-9.1 Pillow where Resampling does not exist.
    resample = getattr(Image, "Resampling", Image).BICUBIC
    resized = image.resize((IMAGE_SIZE, IMAGE_SIZE), resample)
    info = f"Original: {original_w}×{original_h} → Resized: {IMAGE_SIZE}×{IMAGE_SIZE}"
    return resized, info
def validate_image(image: Image.Image) -> None:
    """Raise ValueError when *image* is missing or smaller than the minimum.

    Raises:
        ValueError: if image is None, or either side is below MIN_IMAGE_SIZE.
    """
    if image is None:
        raise ValueError("No image provided")
    width, height = image.size
    # Both dimensions must meet the minimum; checking the smaller side covers both.
    if min(width, height) < MIN_IMAGE_SIZE:
        raise ValueError(f"Image too small (minimum {MIN_IMAGE_SIZE}×{MIN_IMAGE_SIZE})")
# ----------------------------
# 3. Setup Model
# ----------------------------
def download_model_weights():
    """Download the model checkpoint to MODEL_PATH if not already present.

    The file is streamed to a temporary ".part" path and atomically renamed
    into place only after the download finishes. This fixes a latent bug in
    the previous version: a crash mid-transfer (not just an exception) could
    leave a truncated MODEL_PATH that the next run would treat as valid.

    Raises:
        RuntimeError: if the download fails for any reason (chained to the
            underlying exception).
    """
    if os.path.exists(MODEL_PATH):
        return
    print(f"Weights not found. Downloading from {MODEL_URL}...")
    tmp_path = MODEL_PATH + ".part"
    try:
        response = requests.get(MODEL_URL, stream=True, timeout=30)
        response.raise_for_status()
        # content-length may be absent; 0 disables the progress display.
        total_size = int(response.headers.get('content-length', 0))
        with open(tmp_path, "wb") as f:
            downloaded = 0
            for chunk in response.iter_content(chunk_size=8192):
                f.write(chunk)
                downloaded += len(chunk)
                if total_size > 0:
                    progress = (downloaded / total_size) * 100
                    print(f"Download progress: {progress:.1f}%", end='\r')
        # Publish the completed file atomically.
        os.replace(tmp_path, MODEL_PATH)
        print("\nDownload complete.")
    except Exception as e:
        # Remove any partial download so the next attempt starts clean.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise RuntimeError(f"Error downloading weights: {e}\n"
                           "Please ensure 'best_model.pth' is in the same folder.") from e
def load_model():
    """Build the architecture, load the local checkpoint, and return it in eval mode.

    The network config comes from the HF hub (pretrained=False, so no remote
    weights are fetched); the weights come from MODEL_PATH. Both a raw
    state_dict and a training checkpoint wrapping it under
    "model_state_dict" are accepted.

    Raises:
        RuntimeError: if model creation or weight loading fails.
    """
    print(f"Loading {MODEL_ARCH} on {DEVICE}...")
    try:
        net = timm.create_model(
            "hf-hub:" + MODEL_ARCH,
            pretrained=False,
            num_classes=len(CLASSES),
        )
        # NOTE(review): weights_only=False unpickles arbitrary objects —
        # acceptable only because the checkpoint comes from a trusted repo.
        checkpoint = torch.load(MODEL_PATH, map_location=DEVICE, weights_only=False)
        # Unwrap the state_dict when this is a full training checkpoint.
        state_dict = checkpoint["model_state_dict"] if "model_state_dict" in checkpoint else checkpoint
        net.load_state_dict(state_dict)
        net = net.to(DEVICE)
        net.eval()
        print("Model loaded successfully.")
        # Warmup
        warmup_model(net)
        return net
    except Exception as e:
        raise RuntimeError(
            f"Failed to load model. Ensure 'timm' is installed and internet is "
            f"active for config download.\nError: {e}"
        )
def warmup_model(model):
    """Run one throwaway forward pass so the first real request isn't slow."""
    print("Warming up model...")
    # Batch of one random image at the inference resolution; output discarded.
    dummy = torch.randn(1, 3, IMAGE_SIZE, IMAGE_SIZE).to(DEVICE)
    with torch.no_grad():
        model(dummy)
    # Free the memory the warmup pass allocated on the GPU.
    if DEVICE == "cuda":
        torch.cuda.empty_cache()
    print("Warmup complete.")
# Initialize model
# Module-level startup: these run once at import time (the usual pattern for
# a Gradio Space) — fetch examples, fetch weights if missing, then load.
download_example_images()
download_model_weights()
model = load_model()
# ----------------------------
# 4. Inference Function
# ----------------------------
def classify_image(image: Optional[Image.Image]) -> Tuple[Dict[str, float], str, str]:
    """
    Run the quality classifier on a single image.

    Returns:
        Tuple of (predictions_dict, info_message, confidence_warning).
        predictions_dict maps each class name to its softmax probability;
        on any error it is empty and info_message carries the error text.
    """
    try:
        if image is None:
            return {}, "❌ No image uploaded", ""
        validate_image(image)
        # The model expects 3-channel RGB input.
        if image.mode != "RGB":
            image = image.convert("RGB")
        start_time = time.time()
        # Preprocess: stretch to the fixed input size, normalize, add batch dim.
        image_resized, resize_info = resize_to_448(image)
        input_tensor = normalize(image_resized).unsqueeze(0).to(DEVICE)
        # bfloat16 autocast for faster inference; softmax over the class axis.
        with torch.no_grad(), torch.autocast(device_type=DEVICE, dtype=torch.bfloat16):
            probabilities = F.softmax(model(input_tensor), dim=-1)
        results = {label: probabilities[0, idx].item() for idx, label in enumerate(CLASSES)}
        inference_time = (time.time() - start_time) * 1000  # ms
        info_message = "\n".join([
            f"✓ {resize_info}",
            f"⚡ Inference time: {inference_time:.1f}ms",
            f"🖥️ Device: {DEVICE.upper()}",
        ])
        # Qualitative read on the top probability for the UI.
        top_prob = max(results.values())
        if top_prob < 0.5:
            confidence_warning = "⚠️ Low confidence - results may be uncertain"
        elif top_prob > 0.9:
            confidence_warning = "✓ High confidence prediction"
        else:
            confidence_warning = ""
        # Release per-request GPU allocations.
        if DEVICE == "cuda":
            torch.cuda.empty_cache()
        return results, info_message, confidence_warning
    except ValueError as e:
        return {}, f"❌ Validation Error: {str(e)}", ""
    except Exception as e:
        return {}, f"❌ Error during classification: {str(e)}", ""
# ----------------------------
# 5. Gradio Interface
# ----------------------------
# Layout: uploader + buttons on the left, predictions/metadata on the right,
# then an examples gallery (when downloads succeeded) and usage tips.
with gr.Blocks(theme=gr.themes.Soft()) as iface:
    gr.Markdown(
        f"""
        # 🎨 Anime Quality Classifier
        ### Model: `{MODEL_ARCH}`
        Upload an anime image to classify its quality as **Good**, **Normal**, or **Bad**.
        - All images are resized to **{IMAGE_SIZE}×{IMAGE_SIZE}** pixels
        - Maximum file size: {MAX_FILE_SIZE_MB}MB
        """
    )
    with gr.Row():
        with gr.Column(scale=1):
            # Input column
            image_input = gr.Image(
                type="pil",
                label="Upload Image",
                height=400
            )
            classify_btn = gr.Button("🔍 Classify Image", variant="primary", size="lg")
            clear_btn = gr.ClearButton([image_input], value="🗑️ Clear")
        with gr.Column(scale=1):
            # Output column: ranked class probabilities plus run metadata
            label_output = gr.Label(
                num_top_classes=len(CLASSES),
                label="📊 Quality Predictions"
            )
            info_output = gr.Textbox(
                label="ℹ️ Processing Info",
                lines=3,
                interactive=False
            )
            confidence_output = gr.Textbox(
                label="Confidence Assessment",
                lines=1,
                interactive=False
            )
    gr.Markdown("### 🖼️ Try These Examples")
    # Only render the gallery when at least one example image downloaded.
    if EXAMPLE_IMAGES:
        gr.Examples(
            examples=[[img] for img in EXAMPLE_IMAGES],
            inputs=image_input,
            outputs=[label_output, info_output, confidence_output],
            fn=classify_image,
            cache_examples=False,
            label="Sample Images"
        )
    else:
        gr.Markdown("*Example images failed to download. You can still upload your own images.*")
    gr.Markdown(
        """
        ---
        **Tips:**
        - Higher quality images generally get better predictions
        - The model works best on anime-style artwork
        - Try different art styles to see how the model responds
        """
    )
    # Set up event handlers
    # Both an explicit button click and a new upload trigger classification
    # (change also fires on clear, where classify_image returns the
    # "no image" message).
    classify_btn.click(
        fn=classify_image,
        inputs=image_input,
        outputs=[label_output, info_output, confidence_output]
    )
    image_input.change(
        fn=classify_image,
        inputs=image_input,
        outputs=[label_output, info_output, confidence_output]
    )
if __name__ == "__main__":
    # Bind to all interfaces so the app is reachable from inside a container
    # (standard for Hugging Face Spaces); no public share tunnel, and show
    # server-side errors in the browser for easier debugging.
    iface.launch(
        share=False,
        server_name="0.0.0.0",
        show_error=True
    )