Spaces:
Runtime error
Runtime error
#!/usr/bin/env python3
"""
GPU-OPTIMIZED STYLE PAINTING APP - FIXED VERSION

Keeps ALL original features and fixes only what was broken.

FEATURES PRESERVED:
✅ GPU acceleration with CUDA
✅ Multiple AI style models
✅ Real-time painting interface
✅ Preview vs. AI processing distinction
✅ Auto-processing after a delay
✅ Batch processing mode
✅ Pre-processed styles for speed
✅ NEW: Intensity control for each style

FIXES:
✅ Removed the eraser (as requested) - use the Reset button instead
✅ Fixed Gradio update issues
✅ Each apply creates a new base image
✅ Better state management
"""
| import os | |
| os.environ['MPLCONFIGDIR'] = '/tmp/matplotlib' | |
| import gradio as gr | |
| import torch | |
| import torch.nn as nn | |
| import torchvision.transforms as transforms | |
| from PIL import Image | |
| import numpy as np | |
| import cv2 | |
| import glob | |
| import datetime | |
| import tempfile | |
| import time | |
| import threading | |
| import zipfile | |
| import io | |
| from typing import Dict, Tuple, Optional, List | |
| import warnings | |
| import traceback | |
# Silence every Python warning raised while the app runs.
warnings.filterwarnings(action="ignore")

# Pin CUDA device 0 as the current device whenever a GPU is present.
if torch.cuda.is_available():
    torch.cuda.set_device(0)
    print("π₯ CUDA device set to 0")
| # =========================== | |
| # GPU SETUP (KEEP ORIGINAL) | |
| # =========================== | |
def verify_gpu_setup():
    """Print a GPU status banner and return the torch device to run on.

    When CUDA is available the banner also shows device 0's name and its
    total memory in GB (1e9 bytes per GB).

    Returns:
        torch.device: ``cuda`` when CUDA is available, otherwise ``cpu``.
    """
    has_cuda = torch.cuda.is_available()
    selected = torch.device('cuda' if has_cuda else 'cpu')
    rule = "=" * 50

    print(rule)
    print("π GPU VERIFICATION")
    print(rule)
    print(f"π CUDA Available: {has_cuda}")
    if not has_cuda:
        print("β CUDA NOT AVAILABLE - Running on CPU (will be slow)")
    else:
        name = torch.cuda.get_device_name(0)
        gigabytes = torch.cuda.get_device_properties(0).total_memory / 1e9
        print(f"π Device Name: {name}")
        print(f"π Total GPU Memory: {gigabytes:.2f} GB")
        print("β GPU Ready!")
    print(rule)
    return selected


# Module-wide device used for model loading and inference.
device = verify_gpu_setup()
| # =========================== | |
| # MODEL ARCHITECTURE (KEEP ORIGINAL) | |
| # =========================== | |
class ResidualBlock(nn.Module):
    """Residual unit: two reflection-padded 3x3 conv + InstanceNorm stages
    with a ReLU between them, added back onto the input.

    The layer order inside ``self.block`` fixes the state_dict keys
    (``block.1.*``, ``block.2.*``, ``block.5.*``, ``block.6.*``), so it must
    not change or saved checkpoints stop loading.
    """

    def __init__(self, in_features):
        super().__init__()
        channels = in_features

        def padded_conv_norm():
            # Reflection padding of 1 keeps the spatial size for the 3x3 conv.
            return [
                nn.ReflectionPad2d(1),
                nn.Conv2d(channels, channels, 3),
                nn.InstanceNorm2d(channels, affine=True),
            ]

        layers = padded_conv_norm() + [nn.ReLU(inplace=True)] + padded_conv_norm()
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        # Identity shortcut: output = x + F(x).
        residual = self.block(x)
        return x + residual
class Generator(nn.Module):
    """ResNet-style image-to-image generator.

    Layout: 7x7 stem to 64 channels, two stride-2 convs (64 -> 128 -> 256),
    ``n_residual_blocks`` residual units at 256 channels, two stride-2
    transposed convs (256 -> 128 -> 64) and a 7x7 output conv with tanh.
    The Sequential order defines the state_dict keys and must stay fixed.
    """

    def __init__(self, input_nc=3, output_nc=3, n_residual_blocks=12):
        super().__init__()
        layers = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_nc, 64, 7),
            nn.InstanceNorm2d(64, affine=True),
            nn.ReLU(inplace=True),
        ]
        channels = 64

        # Downsampling: 64 -> 128 -> 256.
        for _ in range(2):
            wider = channels * 2
            layers.extend([
                nn.Conv2d(channels, wider, 3, stride=2, padding=1),
                nn.InstanceNorm2d(wider, affine=True),
                nn.ReLU(inplace=True),
            ])
            channels = wider

        layers.extend(ResidualBlock(channels) for _ in range(n_residual_blocks))

        # Upsampling: 256 -> 128 -> 64.
        for _ in range(2):
            narrower = channels // 2
            layers.extend([
                nn.ConvTranspose2d(channels, narrower, 3, stride=2, padding=1, output_padding=1),
                nn.InstanceNorm2d(narrower, affine=True),
                nn.ReLU(inplace=True),
            ])
            channels = narrower

        layers.extend([
            nn.ReflectionPad2d(3),
            nn.Conv2d(64, output_nc, 7),
            nn.Tanh(),
        ])
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)
class LegacyGenerator(nn.Module):
    """Encoder/decoder generator with no residual blocks (older checkpoints).

    7x7 stem to 64 channels, two stride-2 convs (64 -> 128 -> 256), two
    stride-2 transposed convs (256 -> 128 -> 64) and a 7x7 tanh output conv.
    The Sequential order defines the state_dict keys and must stay fixed.
    """

    def __init__(self, input_nc=3, output_nc=3):
        super().__init__()

        def conv_norm_relu(conv, channels):
            # conv -> affine InstanceNorm -> in-place ReLU
            return [conv, nn.InstanceNorm2d(channels, affine=True), nn.ReLU(inplace=True)]

        layers = [nn.ReflectionPad2d(3)]
        layers += conv_norm_relu(nn.Conv2d(input_nc, 64, 7), 64)
        layers += conv_norm_relu(nn.Conv2d(64, 128, 3, stride=2, padding=1), 128)
        layers += conv_norm_relu(nn.Conv2d(128, 256, 3, stride=2, padding=1), 256)
        layers += conv_norm_relu(
            nn.ConvTranspose2d(256, 128, 3, stride=2, padding=1, output_padding=1), 128)
        layers += conv_norm_relu(
            nn.ConvTranspose2d(128, 64, 3, stride=2, padding=1, output_padding=1), 64)
        layers += [nn.ReflectionPad2d(3), nn.Conv2d(64, output_nc, 7), nn.Tanh()]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)
def detect_model_architecture(state_dict):
    """Infer how many residual blocks a generator checkpoint contains.

    Residual parameters are expected as ``model.<idx>.block.<n>.<param>``;
    the number of distinct ``<idx>`` values is the block count.

    Returns:
        int: 0 when no key contains both ``model.`` and ``.block.`` (legacy
        layout); otherwise the number of distinct block indices, or 12 when
        such keys exist but none match the expected layout.
    """
    residual_keys = [k for k in state_dict.keys() if 'model.' in k and '.block.' in k]
    if len(residual_keys) == 0:
        return 0

    indices = set()
    for key in residual_keys:
        parts = key.split('.')
        if len(parts) < 3 or parts[2] != 'block':
            continue
        try:
            indices.add(int(parts[1]))
        except ValueError:
            continue
    # Fall back to the default depth when nothing parsed.
    return len(indices) or 12
def create_compatible_generator(state_dict):
    """Build an untrained generator whose layer layout matches ``state_dict``.

    Returns a LegacyGenerator when the checkpoint has no residual blocks,
    otherwise a Generator with the detected block count. Weights are not
    loaded here.
    """
    block_count = detect_model_architecture(state_dict)
    return Generator(n_residual_blocks=block_count) if block_count else LegacyGenerator()
| # =========================== | |
| # FIXED PROCESSING SYSTEM | |
| # =========================== | |
class StylePaintingSystem:
    """Model registry, image state and painting logic behind the UI.

    Painting records one boolean mask per style. ``create_preview`` tints the
    masked areas with the style's colour, and ``apply_ai_processing`` blends
    the real generator output into those areas to produce a new base image,
    so successive applies build on each other.
    """

    def __init__(self):
        # Model management
        self.style_models = {}        # model_key -> {'path', 'emoji', 'name', 'color'}
        self.loaded_generators = {}   # model_key -> loaded generator (lazy cache)
        self.precomputed_styles = {}  # model_key -> full-strength stylised current_base

        # Current image state (no eraser)
        self.original_image = None    # RGB upload, restored by reset_all
        self.current_base = None      # working image; replaced after each apply
        self.current_display = None   # image last produced for the UI
        self.is_preview = True

        # Painting state
        self.style_masks = {}         # model_key -> (bool mask (H, W), intensity)
        self.active_style = None
        self.active_intensity = 1.0

        # Processing state
        self.auto_timer = None        # pending threading.Timer for auto-apply
        self.auto_delay = 3.0         # seconds
        self.processing_lock = threading.Lock()

        # Transforms: resize to a square and map pixels to [-1, 1]; the inverse
        # maps [-1, 1] back to [0, 1] and converts to PIL.
        self.processing_size = 512
        self.transform = transforms.Compose([
            transforms.Resize((self.processing_size, self.processing_size)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])
        self.inverse_transform = transforms.Compose([
            transforms.Normalize((-1, -1, -1), (2, 2, 2)),
            transforms.ToPILImage(),
        ])

        # (emoji, display name, preview tint), assigned in discovery order.
        self.style_configs = [
            ('πΈ', 'Natural Bokeh', (255, 182, 193)),
            ('π ', 'Natural Golden', (255, 215, 0)),
            ('βοΈ', 'Photo Golden', (255, 165, 0)),
            ('π', 'Day to Night', (25, 25, 112)),
            ('βοΈ', 'Summer Winter', (173, 216, 230)),
            ('π·', 'Photo Bokeh', (255, 228, 196)),
            ('πΌοΈ', 'Photo Monet', (144, 238, 144)),
            ('π¨', 'Photo Seurat', (221, 160, 221)),
            ('π«οΈ', 'Foggy Clear', (220, 220, 220)),
        ]
        self.discover_models()

    # ---------- small helpers ----------

    @staticmethod
    def _to_rgb(image):
        """Return ``image`` as RGB; RGBA is composited onto a white background."""
        if image.mode == 'RGBA':
            rgb_image = Image.new('RGB', image.size, (255, 255, 255))
            rgb_image.paste(image, mask=image.split()[3])
            return rgb_image
        if image.mode != 'RGB':
            return image.convert('RGB')
        return image

    @staticmethod
    def _style_label(model_data):
        """UI label for a style: '<emoji> <name>'."""
        return f"{model_data['emoji']} {model_data['name']}"

    def _find_model_key(self, style_text):
        """Return the model key whose UI label equals ``style_text``, or None."""
        for model_key, model_data in self.style_models.items():
            if self._style_label(model_data) == style_text:
                return model_key
        return None

    def _cancel_auto_timer(self):
        """Cancel and forget any pending auto-process timer."""
        if self.auto_timer is not None:
            self.auto_timer.cancel()
            self.auto_timer = None

    # ---------- models ----------

    def discover_models(self):
        """Scan ./models for generator checkpoints and register them as styles.

        Directories contribute their first ``*generator_AB.pth`` file (folder
        name up to ``_best_`` is the model name); loose ``.pth`` files whose
        path contains ``generator`` are registered directly. The search
        patterns overlap (a ``*_best_*`` folder also matches ``*/`` and its
        checkpoint matches ``*/*.pth``), so each checkpoint file and each
        model name is registered at most once. At most
        ``len(self.style_configs)`` models are kept.
        """
        print("\n" + "="*60)
        print("π DISCOVERING MODELS")
        print("="*60)
        patterns = [
            './models/*_best_*/',
            './models/*/',
            './models/*.pth',
            './models/*/*.pth',
        ]
        all_files = []
        for pattern in patterns:
            # Sorted so style slot assignment does not depend on FS order.
            files = sorted(glob.glob(pattern))
            if files:
                print(f"π Pattern '{pattern}' found {len(files)} items")
                all_files.extend(files)

        discovered = []
        seen_paths = set()
        seen_names = set()
        for path in all_files:
            if os.path.isdir(path):
                ab_files = sorted(glob.glob(os.path.join(path, '*generator_AB.pth')))
                if not ab_files:
                    continue
                folder_name = os.path.basename(os.path.normpath(path))
                model_name = folder_name.split('_best_')[0] if '_best_' in folder_name else folder_name
                model_path = ab_files[0]
            elif path.endswith('.pth') and 'generator' in path:
                model_name = os.path.basename(path).replace('.pth', '').replace('_generator_AB', '')
                model_path = path
            else:
                continue
            # NOTE(review): '*/*.pth' also matches e.g. '*generator_BA.pth'
            # next to an AB checkpoint; those still register as separate styles.
            resolved = os.path.realpath(model_path)
            if resolved in seen_paths or model_name in seen_names:
                continue
            seen_paths.add(resolved)
            seen_names.add(model_name)
            discovered.append((model_name, model_path))

        # zip() stops at the shorter list, capping registration at the configs.
        for (model_name, model_path), (emoji, display_name, color) in zip(discovered, self.style_configs):
            self.style_models[model_name] = {
                'path': model_path,
                'emoji': emoji,
                'name': display_name,
                'color': color,
            }
            print(f"π Registered: {emoji} {display_name} ({model_name})")
        print(f"\nβ Registered {len(self.style_models)} style models")
        print("="*60 + "\n")

    def load_generator(self, model_key):
        """Return the cached eval-mode generator for ``model_key``, loading it once.

        A checkpoint dict wrapping its weights under ``'generator'`` is
        unwrapped. On CUDA the model is converted to half precision. Returns
        None for unknown keys or when loading fails (the error is printed).
        NOTE: ``torch.load`` unpickles the file - only load trusted checkpoints.
        """
        if model_key in self.loaded_generators:
            return self.loaded_generators[model_key]
        if model_key not in self.style_models:
            return None
        try:
            model_path = self.style_models[model_key]['path']
            state_dict = torch.load(model_path, map_location=device)
            if 'generator' in state_dict:
                state_dict = state_dict['generator']
            generator = create_compatible_generator(state_dict)
            generator.load_state_dict(state_dict)
            generator.eval()
            generator = generator.to(device)
            if device.type == 'cuda':
                try:
                    generator = generator.half()
                except Exception:
                    generator = generator.float()
            self.loaded_generators[model_key] = generator
            return generator
        except Exception as e:
            print(f"β Error loading {model_key}: {e}")
            return None

    def process_image_with_style(self, image, model_key, intensity=1.0):
        """Run the ``model_key`` generator over ``image``; return a PIL image or None.

        Inference happens at ``processing_size`` x ``processing_size`` and the
        output is resized back to the input size. ``intensity`` (clamped to
        [0, 1]) is the weight of the stylised image in a linear blend with the
        input; 1.0 returns the raw generator output. Returns None when the
        generator is unavailable or inference fails (the error is printed).
        """
        generator = self.load_generator(model_key)
        if generator is None:
            return None
        try:
            image = self._to_rgb(image)  # the generators expect 3 channels
            intensity = min(max(float(intensity), 0.0), 1.0)
            original_size = image.size
            img_resized = image.resize((self.processing_size, self.processing_size), Image.LANCZOS)
            img_tensor = self.transform(img_resized).unsqueeze(0).to(device)
            if device.type == 'cuda' and next(generator.parameters()).dtype == torch.float16:
                img_tensor = img_tensor.half()
            with torch.no_grad():
                if device.type == 'cuda':
                    torch.cuda.synchronize()
                result_tensor = generator(img_tensor)
                if device.type == 'cuda':
                    torch.cuda.synchronize()
            if result_tensor.dtype == torch.float16:
                result_tensor = result_tensor.float()
            result_tensor = result_tensor.cpu()
            processed_img = self.inverse_transform(result_tensor.squeeze(0))
            processed_img = processed_img.resize(original_size, Image.LANCZOS)
            if intensity < 1.0:
                processed_array = np.array(processed_img, dtype=np.float32)
                original_array = np.array(image, dtype=np.float32)
                blended = original_array * (1 - intensity) + processed_array * intensity
                processed_img = Image.fromarray(np.clip(blended, 0, 255).astype(np.uint8))
            return processed_img
        except Exception as e:
            print(f"β Error processing: {e}")
            return None

    # ---------- image / painting state ----------

    def setup_new_image(self, image, progress_callback=None):
        """Make ``image`` the new original and working image and reset painting.

        The image is normalised to RGB. Up to the first three registered styles
        are precomputed at full strength. ``progress_callback(fraction, text)``
        is called once if given. Returns a status string.
        """
        if image is None:
            return "No image provided"
        image = self._to_rgb(image)
        # Masks are about to be discarded; drop any pending auto-apply for them.
        self._cancel_auto_timer()

        self.original_image = image
        self.current_base = image.copy()
        self.current_display = image
        self.is_preview = True
        self.style_masks = {}
        self.precomputed_styles = {}

        if progress_callback:
            progress_callback(0.5, "Image ready! Precomputing popular styles...")
        for model_key in list(self.style_models)[:3]:
            processed = self.process_image_with_style(image, model_key)
            if processed is not None:
                self.precomputed_styles[model_key] = processed
        return f"β Ready! {len(self.style_models)} styles available on {device.type.upper()}"

    def set_active_style(self, style_choice):
        """Select the style whose UI label equals ``style_choice``; return status text."""
        model_key = self._find_model_key(style_choice)
        if model_key is None:
            return "β Style not found"
        self.active_style = model_key
        return f"π― Selected: {style_choice}"

    def set_intensity(self, intensity):
        """Set the intensity used for subsequently painted pixels; return status text."""
        self.active_intensity = intensity
        return f"Intensity: {int(intensity * 100)}%"

    def get_painted_mask(self, editor_data):
        """Union of all painted editor layers as a bool (H, W) mask of current_base.

        Layers of a different size are resized to the base image. RGBA layers
        count pixels with alpha > 30; other layers count any channel > 10.
        Returns None when there are no layers or no image is loaded.
        """
        if not editor_data or not editor_data.get('layers'):
            return None
        if self.current_base is None:
            return None
        width, height = self.current_base.size
        combined_mask = np.zeros((height, width), dtype=bool)
        for layer in editor_data['layers']:
            if layer is None:
                continue
            layer_array = np.array(layer)
            if layer_array.shape[:2] != (height, width):
                layer_pil = Image.fromarray(layer_array)
                layer_pil = layer_pil.resize((width, height), Image.LANCZOS)
                layer_array = np.array(layer_pil)
            if layer_array.ndim == 2:  # single-channel layer
                mask = layer_array > 10
            elif layer_array.shape[-1] == 4:
                mask = layer_array[:, :, 3] > 30
            else:
                mask = np.any(layer_array > 10, axis=2)
            combined_mask |= mask
        return combined_mask

    def update_painting(self, editor_data):
        """Add newly painted pixels to the active style's mask and refresh the preview.

        Returns (image to display, status text). Schedules auto-processing when
        new pixels were added.
        """
        if self.current_base is None:
            return self.current_base, "Upload an image first"
        painted_mask = self.get_painted_mask(editor_data)
        if painted_mask is None:
            return self.current_display, "No painting detected"
        if not self.active_style:
            return self.current_display, "Select a style first"

        if self.active_style not in self.style_masks:
            self.style_masks[self.active_style] = (np.zeros_like(painted_mask), self.active_intensity)
        existing_mask, _ = self.style_masks[self.active_style]

        # The editor layers hold the strokes of *every* style painted since the
        # last apply, so only pixels not yet claimed by any style are new;
        # otherwise switching styles would hand the old strokes to the new style.
        claimed = np.zeros_like(painted_mask)
        for mask, _ in self.style_masks.values():
            if mask.shape == claimed.shape:
                claimed |= mask
        new_pixels = painted_mask & ~claimed
        if not np.any(new_pixels):
            return self.current_display, "Paint in new areas"

        self.style_masks[self.active_style] = (existing_mask | new_pixels, self.active_intensity)
        self.is_preview = True
        self.schedule_auto_process()
        preview = self.create_preview()
        style_name = self.style_models[self.active_style]['name']
        return preview, f"π¨ Added {np.sum(new_pixels):,} pixels of {style_name} at {int(self.active_intensity*100)}%"

    def create_preview(self):
        """Tint each style's masked area with its colour (30%) over current_base.

        Masks are feathered with a 15x15 Gaussian blur and weighted by their
        intensity. Stores and returns the preview, or None without an image.
        """
        if self.current_base is None:
            return None
        result = np.array(self.current_base, dtype=np.float32)
        for style_key, (mask, intensity) in self.style_masks.items():
            if not np.any(mask) or style_key not in self.style_models:
                continue
            color = np.array(self.style_models[style_key]['color'], dtype=np.float32)
            mask_smooth = cv2.GaussianBlur(mask.astype(np.float32), (15, 15), 0)
            mask_3d = np.stack([mask_smooth] * 3, axis=2)
            overlay = result * 0.7 + color * 0.3
            result = result * (1 - mask_3d * intensity) + overlay * mask_3d * intensity
        self.current_display = Image.fromarray(np.clip(result, 0, 255).astype(np.uint8))
        return self.current_display

    def apply_ai_processing(self, progress_callback=None):
        """Blend real style output into every painted area; the result becomes
        the new base image.

        Each style is generated at full strength (or taken from the cache) and
        blended once with its feathered (21x21 Gaussian) mask scaled by the
        style's intensity. On success masks, cache and any pending auto timer
        are cleared. ``progress_callback`` is accepted but unused.

        Returns:
            (image, status text). The image is None only when no image is loaded.
        """
        if self.current_base is None:
            return self.current_base, "No image loaded"

        with self.processing_lock:
            # Snapshot under the lock: this runs from both the UI and the
            # auto-process timer thread, and update_painting may add masks.
            pending = [(key, mask, intensity)
                       for key, (mask, intensity) in list(self.style_masks.items())
                       if np.any(mask)]
            if not pending:
                return self.current_base, "Nothing to process - paint first!"

            base = self.current_base
            result = np.array(base, dtype=np.float32)
            applied_count = 0
            for style_key, mask, intensity in pending:
                styled_img = self.precomputed_styles.get(style_key)
                if styled_img is None:
                    # Full strength on purpose: intensity is applied exactly once,
                    # in the mask blend below (passing it here too would apply it
                    # twice, and differ from the full-strength cached results).
                    styled_img = self.process_image_with_style(base, style_key)
                    if styled_img is None:
                        continue
                    self.precomputed_styles[style_key] = styled_img
                styled_array = np.array(styled_img, dtype=np.float32)
                mask_smooth = cv2.GaussianBlur(mask.astype(np.float32), (21, 21), 0)
                weight = np.stack([mask_smooth] * 3, axis=2) * intensity
                result = result * (1 - weight) + styled_array * weight
                applied_count += 1

            if applied_count == 0:
                # Keep the masks so the user can retry once the models load.
                return self.current_base, "No styles could be applied - check the model files"

            new_image = Image.fromarray(np.clip(result, 0, 255).astype(np.uint8))
            self.current_base = new_image.copy()
            self.current_display = new_image
            self.is_preview = False
            self.style_masks = {}
            self.precomputed_styles = {}  # cache was computed from the old base
            self._cancel_auto_timer()
        return new_image, f"π₯ Applied {applied_count} styles! Ready for more painting."

    def schedule_auto_process(self):
        """(Re)start the daemon timer that calls auto_process after auto_delay seconds.

        NOTE(review): auto_process discards apply_ai_processing's return value,
        so a timer-driven apply updates internal state but not the Gradio views.
        """
        self._cancel_auto_timer()
        self.auto_timer = threading.Timer(self.auto_delay, self.auto_process)
        self.auto_timer.daemon = True
        self.auto_timer.start()

    def auto_process(self):
        """Timer callback: run apply_ai_processing (result is not returned)."""
        self.apply_ai_processing()

    def reset_all(self):
        """Restore the original image and clear painting state.

        Returns (image, status text); the image is None when nothing is loaded.
        """
        if self.original_image is None:
            return None, "No image loaded"
        self._cancel_auto_timer()
        self.current_base = self.original_image.copy()
        self.current_display = self.original_image
        self.style_masks = {}
        self.precomputed_styles = {}
        self.is_preview = True
        return self.original_image, "π Reset to original"

    def process_batch(self, images, selected_styles_with_intensity, progress_callback=None):
        """Apply every (style label, intensity) pair to every image.

        Returns ``{"image_<i>": {"<label>_<percent>": PIL image}}``; unknown
        labels and failed runs are skipped. ``progress_callback(fraction, text)``
        is called before each run if given.
        """
        results = {}
        total = len(images) * len(selected_styles_with_intensity)
        current = 0
        for img_idx, image in enumerate(images):
            img_results = {}
            for style_text, intensity in selected_styles_with_intensity:
                current += 1
                if progress_callback:
                    progress_callback(current / total, f"Processing image {img_idx+1} with {style_text} at {int(intensity*100)}%")
                model_key = self._find_model_key(style_text)
                if model_key is None:
                    continue
                processed = self.process_image_with_style(image, model_key, intensity)
                if processed is not None:
                    img_results[f"{style_text}_{int(intensity*100)}"] = processed
            results[f"image_{img_idx}"] = img_results
        return results
# ===========================
# GLOBAL SYSTEM INSTANCE
# ===========================
# One shared instance: the Gradio callbacks in this module read and mutate
# it, so every connected browser session shares the same painting state.
system: StylePaintingSystem = StylePaintingSystem()
| # =========================== | |
| # GRADIO INTERFACE FUNCTIONS (FIXED) | |
| # =========================== | |
def on_image_upload(image, progress=gr.Progress()):
    """Load a newly uploaded image into the painting system.

    Returns (result display, status text, painting canvas value). Both image
    outputs receive the system's normalised RGB working image, so the display
    and the canvas show exactly what will be processed (e.g. an RGBA upload
    is shown composited on white, as the system stores it).
    """
    if image is None:
        return None, "Please upload an image", gr.update(value=None)
    status = system.setup_new_image(image, progress)
    working = system.current_base if system.current_base is not None else image
    return working, status, working
def on_style_select(style_choice):
    """Activate the style chosen in the radio group and return status text."""
    return system.set_active_style(style_choice) if style_choice else "Select a style"
def on_intensity_change(intensity):
    """Forward the slider value to the system; returns its status string."""
    status_text = system.set_intensity(intensity)
    return status_text
def on_paint_change(editor_data):
    """Feed the editor's current value to the painting system.

    Returns (image to display, status text).
    """
    if editor_data:
        shown, message = system.update_painting(editor_data)
        return shown, message
    return system.current_display, "Paint to add styles"
def on_apply_ai():
    """Run AI processing on the painted areas and refresh both image views.

    Returns (result image, status text, painting canvas value). Sending the
    new image to the canvas resets it, so painting continues on the new base.
    """
    # Cancel pending auto-processing so the work is not done twice.
    if system.auto_timer:
        system.auto_timer.cancel()
        system.auto_timer = None

    result_img, status = system.apply_ai_processing()
    if result_img is None:
        # Keep the system's own explanation (e.g. "No image loaded") instead
        # of replacing it with a generic failure message.
        return None, status, None

    # Round-trip through PNG so Gradio receives a brand-new image object.
    buffer = io.BytesIO()
    result_img.save(buffer, format='PNG')
    buffer.seek(0)
    new_img = Image.open(buffer)
    new_img.load()  # decode now rather than lazily on Gradio's side
    return new_img, status, new_img
def on_reset():
    """Restore the original upload.

    Returns (display image, status text, canvas image); both images are the
    original, or None when nothing is loaded.
    """
    original, message = system.reset_all()
    return original, message, original
def _coerce_intensity(value, default=1.0):
    """Return ``value`` as a float clamped to [0, 1]; invalid or NaN -> ``default``."""
    try:
        number = float(value)
    except (TypeError, ValueError):
        return default
    if number != number:  # NaN, e.g. an empty Dataframe cell
        return default
    return min(max(number, 0.0), 1.0)


def _load_rgb_image(file_path):
    """Read an image file fully into memory as RGB and close the file.

    Alpha (RGBA, LA, palette transparency) is composited onto white; any
    other mode goes through ``convert('RGB')``.
    """
    with Image.open(file_path) as opened:
        has_alpha = opened.mode in ('RGBA', 'LA') or (
            opened.mode == 'P' and 'transparency' in opened.info
        )
        if not has_alpha:
            return opened.convert('RGB')
        rgba = opened.convert('RGBA')
    flattened = Image.new('RGB', rgba.size, (255, 255, 255))
    flattened.paste(rgba, mask=rgba.split()[3])
    return flattened


def on_batch_process(file_list, selected_styles, intensities, progress=gr.Progress()):
    """Apply each selected style to every uploaded image and bundle the results.

    Args:
        file_list: uploaded files - paths, or objects with a ``.name`` path
            attribute (the form depends on the Gradio version).
        selected_styles: style labels ("<emoji> <name>") to apply.
        intensities: per-style intensities matched to ``selected_styles`` by
            position; missing/non-numeric entries become 1.0 and values are
            clamped to [0, 1].
        progress: Gradio progress tracker.

    Returns:
        (ZIP path or None, up to 8 preview images, status text)
    """
    if not file_list or not selected_styles:
        return None, [], "Upload images and select styles"

    intensities = list(intensities) if intensities is not None else []
    styles_with_intensity = [
        (style, _coerce_intensity(intensities[i]) if i < len(intensities) else 1.0)
        for i, style in enumerate(selected_styles)
    ]

    images = []
    for item in file_list:
        file_path = getattr(item, 'name', item)
        try:
            # Convert every mode to RGB: the generators need 3 channels.
            images.append(_load_rgb_image(file_path))
        except Exception as e:
            print(f"Error loading {file_path}: {e}")
    if not images:
        return None, [], "Failed to load images"

    results = system.process_batch(images, styles_with_intensity, progress)

    outputs = [
        (f"{img_key}_{style_name.replace(' ', '_')}.png", processed_img)
        for img_key, img_results in results.items()
        for style_name, processed_img in img_results.items()
    ]
    if not outputs:
        return None, [], "No styles could be applied - check the model files"

    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    # Fresh directory per run: two batches finishing in the same second must
    # not overwrite each other's ZIP.
    zip_path = os.path.join(tempfile.mkdtemp(prefix="batch_"), f"batch_{timestamp}.zip")
    with zipfile.ZipFile(zip_path, 'w') as zf:
        for filename, processed_img in outputs:
            img_buffer = io.BytesIO()
            processed_img.save(img_buffer, format='PNG')
            zf.writestr(filename, img_buffer.getvalue())

    preview_images = [img for _, img in outputs[:8]]
    return zip_path, preview_images, f"β Processed {len(images)} images"
| # =========================== | |
| # CREATE GRADIO INTERFACE (FIXED) | |
| # =========================== | |
def create_interface():
    """Build the Gradio Blocks UI (paint tab and batch tab) and wire its events.

    Reads the global ``system`` for the available styles and the auto-process
    delay. The radio group's default style is registered as the active style
    so painting works before the user touches the selector.

    Returns:
        gr.Blocks: the interface, not yet launched.
    """
    with gr.Blocks(title="GPU Style Painting", css="""
        .paint-canvas { height: 600px !important; }
    """) as interface:
        gr.Markdown("# π¨ GPU-Optimized Style Painting")
        gr.Markdown(f"**Device:** {device.type.upper()} | **Auto-process:** {system.auto_delay}s")

        with gr.Tabs():
            # PAINTING TAB
            with gr.TabItem("π¨ Paint Mode"):
                with gr.Row():
                    with gr.Column(scale=1):
                        input_image = gr.Image(
                            type="pil",
                            label="Upload Image"
                        )
                        upload_status = gr.Textbox(
                            label="Status",
                            value="Upload an image to start"
                        )

                        # Style selection (no eraser)
                        style_choices = [f"{data['emoji']} {data['name']}"
                                         for data in system.style_models.values()]
                        default_style = style_choices[0] if style_choices else None
                        # The Radio shows default_style as selected, but its
                        # .change handler never runs for the initial value and
                        # system.active_style starts as None; register it here so
                        # painting does not report "Select a style first".
                        initial_style_status = "Select a style to paint"
                        if default_style is not None:
                            initial_style_status = system.set_active_style(default_style)
                        style_selector = gr.Radio(
                            choices=style_choices,
                            label="Select Style",
                            value=default_style
                        )
                        style_status = gr.Textbox(
                            label="Style Status",
                            value=initial_style_status
                        )

                        # Intensity slider
                        intensity_slider = gr.Slider(
                            minimum=0.1,
                            maximum=1.0,
                            value=1.0,
                            step=0.1,
                            label="Style Intensity"
                        )
                        intensity_status = gr.Textbox(
                            label="Intensity",
                            value="Intensity: 100%"
                        )

                        # Control buttons
                        with gr.Row():
                            apply_btn = gr.Button("π₯ Apply AI", variant="primary")
                            reset_btn = gr.Button("π Reset", variant="secondary")

                    with gr.Column(scale=2):
                        # Separate read-only display and editable canvas
                        result_image = gr.Image(
                            label="Result",
                            height=600,
                            type="pil",
                            interactive=False
                        )
                        painting_canvas = gr.ImageEditor(
                            type="pil",
                            label="Paint Canvas",
                            brush=gr.Brush(default_size=30),
                            height=600,
                            image_mode="RGB"
                        )
                        paint_status = gr.Textbox(
                            label="Paint Status",
                            value="Ready"
                        )

            # BATCH TAB (with intensity)
            with gr.TabItem("π¦ Batch Mode"):
                with gr.Row():
                    with gr.Column():
                        batch_files = gr.File(
                            label="Upload Images",
                            file_count="multiple",
                            file_types=["image"]
                        )
                        batch_styles = gr.CheckboxGroup(
                            choices=[f"{data['emoji']} {data['name']}" for data in system.style_models.values()],
                            label="Select Styles"
                        )
                        # Per-style intensity table
                        batch_intensities = gr.Dataframe(
                            headers=["Style", "Intensity"],
                            datatype=["str", "number"],
                            col_count=(2, "fixed"),
                            value=[["Style 1", 1.0], ["Style 2", 1.0], ["Style 3", 1.0]],
                            label="Style Intensities (0.1-1.0)"
                        )
                        batch_btn = gr.Button("π Process Batch", variant="primary")
                        batch_status = gr.Textbox(label="Status")
                    with gr.Column():
                        batch_download = gr.File(label="Download ZIP")
                        batch_preview = gr.Gallery(
                            label="Preview",
                            columns=4,
                            height=400
                        )

        # Wire up events
        input_image.change(
            fn=on_image_upload,
            inputs=[input_image],
            outputs=[result_image, upload_status, painting_canvas]
        )
        style_selector.change(
            fn=on_style_select,
            inputs=[style_selector],
            outputs=[style_status]
        )
        intensity_slider.change(
            fn=on_intensity_change,
            inputs=[intensity_slider],
            outputs=[intensity_status]
        )
        painting_canvas.change(
            fn=on_paint_change,
            inputs=[painting_canvas],
            outputs=[result_image, paint_status],
            queue=False  # immediate preview
        )
        apply_btn.click(
            fn=on_apply_ai,
            inputs=[],
            outputs=[result_image, paint_status, painting_canvas],
            queue=False  # immediate execution
        )
        reset_btn.click(
            fn=on_reset,
            inputs=[],
            outputs=[result_image, paint_status, painting_canvas]
        )

        # Batch processing
        def prepare_batch_intensities(styles):
            """One [style, 1.0] row per selected style."""
            return [[style, 1.0] for style in (styles or [])]

        batch_styles.change(
            fn=prepare_batch_intensities,
            inputs=[batch_styles],
            outputs=[batch_intensities]
        )

        def to_intensity(value):
            """Float in [0, 1] from a Dataframe cell; blank/invalid cells -> 1.0."""
            try:
                number = float(value)
            except (TypeError, ValueError):
                return 1.0
            if number != number:  # NaN
                return 1.0
            return min(max(number, 0.0), 1.0)

        def process_batch_with_df(files, styles, intensity_df, progress=gr.Progress()):
            """Run the batch with each style's intensity looked up by style name.

            Matching by name (not row position) keeps every style paired with
            its own intensity even if rows were removed or reordered; styles
            without a row use 1.0. ``progress`` is declared here so Gradio
            tracks progress for this event.
            """
            if intensity_df is None:
                rows = []
            elif hasattr(intensity_df, 'to_numpy'):  # pandas DataFrame
                rows = intensity_df.to_numpy().tolist()
            else:
                rows = list(intensity_df)
            selected = list(styles or [])
            by_style = {}
            for row in rows:
                if len(row) >= 2 and row[0] in selected and row[0] not in by_style:
                    by_style[row[0]] = to_intensity(row[1])
            intensities = [by_style.get(style, 1.0) for style in selected]
            return on_batch_process(files, styles, intensities, progress)

        batch_btn.click(
            fn=process_batch_with_df,
            inputs=[batch_files, batch_styles, batch_intensities],
            outputs=[batch_download, batch_preview, batch_status]
        )

        gr.Markdown("""
        ## π Instructions
        1. **Upload** an image to start
        2. **Select** a style and adjust intensity
        3. **Paint** on the canvas - see instant preview
        4. **Apply AI** to process the painted areas
        5. **Continue** painting more styles - they build on previous results
        6. **Reset** to return to original image

        **Features:**
        - π GPU accelerated processing
        - π¨ Multiple AI styles with intensity control
        - β‘ Instant preview with color overlays
        - π₯ Progressive application (each apply builds on previous)
        - π¦ Batch processing with per-style intensity
        """)
    return interface
| # =========================== | |
| # LAUNCH | |
| # =========================== | |
if __name__ == "__main__":
    # Start-up banner, then build, queue and serve the UI on all interfaces.
    print("π Starting GPU Style Painting App...")
    print(f"π Device: {device}")
    print(f"β° Auto-process delay: {system.auto_delay}s")
    app = create_interface()
    app.queue()
    app.launch(server_name="0.0.0.0", server_port=7860, share=True)