#!/usr/bin/env python """ STYLE TRANSFER APP - Streamlit Version with Regional Transformations All existing features preserved + new local painting capabilities + Unsplash integration """ import os os.environ['MPLCONFIGDIR'] = '/tmp/matplotlib' os.environ['TORCH_HOME'] = '/tmp/torch_cache' os.environ['HF_HOME'] = '/tmp/hf_cache' os.makedirs('/tmp/torch_cache', exist_ok=True) os.makedirs('/tmp/hf_cache', exist_ok=True) import streamlit as st from streamlit_drawable_canvas import st_canvas import torch import torch.nn as nn import torch.nn.functional as F import torchvision.transforms as transforms import torchvision.models as models from torch.utils.data import Dataset, DataLoader from PIL import Image, ImageDraw, ImageFont import numpy as np import glob import datetime import traceback import uuid import warnings import zipfile import io import json import time import shutil import requests import scipy try: import cv2 VIDEO_PROCESSING_AVAILABLE = True except ImportError: VIDEO_PROCESSING_AVAILABLE = False print("OpenCV not available - video processing disabled") import tempfile from pathlib import Path import colorsys warnings.filterwarnings("ignore") # Set page config st.set_page_config( page_title="Style Transfer Studio", # page_icon="", layout="wide", initial_sidebar_state="expanded" ) # Custom CSS for better UI st.markdown(""" """, unsafe_allow_html=True) # Force CUDA if available if torch.cuda.is_available(): torch.cuda.set_device(0) # Set CUDA to be deterministic for consistency torch.backends.cudnn.benchmark = True print("CUDA device set") print(f"CUDA version: {torch.version.cuda}") print(f"PyTorch version: {torch.__version__}") # GPU SETUP device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') print(f"Using device: {device}") if device.type == 'cuda': print(f"GPU: {torch.cuda.get_device_name(0)}") print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB") print(f"Current GPU memory usage: {torch.cuda.memory_allocated() / 1e9:.2f} GB") # =========================== # UNSPLASH API INTEGRATION # =========================== class UnsplashAPI: """Simple Unsplash API integration""" def __init__(self, access_key=None): # Use environment variable only - no secrets.toml warning if access_key: self.access_key = access_key else: self.access_key = os.environ.get("UNSPLASH_ACCESS_KEY") self.base_url = "https://api.unsplash.com" def search_photos(self, query, per_page=20, page=1, orientation=None): """Search photos on Unsplash""" if not self.access_key: return None, "No Unsplash API key configured" headers = {"Authorization": f"Client-ID {self.access_key}"} params = { "query": query, "per_page": per_page, "page": page } if orientation: params["orientation"] = orientation # "landscape", "portrait", "squarish" try: response = requests.get( f"{self.base_url}/search/photos", headers=headers, params=params, timeout=10 ) response.raise_for_status() return response.json(), None except requests.exceptions.RequestException as e: return None, f"Error searching Unsplash: {str(e)}" def get_random_photos(self, count=12, collections=None, query=None): """Get random photos from Unsplash""" if not self.access_key: return None, "No Unsplash API key configured" headers = {"Authorization": f"Client-ID {self.access_key}"} params = {"count": count} if collections: params["collections"] = collections if query: params["query"] = query try: response = requests.get( f"{self.base_url}/photos/random", headers=headers, params=params, timeout=10 ) response.raise_for_status() return response.json(), None except requests.exceptions.RequestException as e: return None, f"Error getting random photos: {str(e)}" def download_photo(self, photo_url, size="regular"): """Download photo from URL""" try: # Add fm=jpg&q=80 for consistent format and quality if "?" in photo_url: photo_url += "&fm=jpg&q=80" else: photo_url += "?fm=jpg&q=80" response = requests.get(photo_url, timeout=30) response.raise_for_status() return Image.open(io.BytesIO(response.content)).convert('RGB') except Exception as e: st.error(f"Error downloading image: {str(e)}") return None def trigger_download(self, download_location): """Trigger download event (required by Unsplash API)""" if not self.access_key or not download_location: return headers = {"Authorization": f"Client-ID {self.access_key}"} try: requests.get(download_location, headers=headers, timeout=5) except: pass # Don't fail if tracking fails # =========================== # MODEL ARCHITECTURES # =========================== class LightweightResidualBlock(nn.Module): """Lightweight residual block with depthwise separable convolutions""" def __init__(self, channels): super(LightweightResidualBlock, self).__init__() self.depthwise = nn.Sequential( nn.ReflectionPad2d(1), nn.Conv2d(channels, channels, 3, groups=channels), nn.InstanceNorm2d(channels, affine=True), nn.ReLU(inplace=True) ) self.pointwise = nn.Sequential( nn.Conv2d(channels, channels, 1), nn.InstanceNorm2d(channels, affine=True) ) def forward(self, x): return x + self.pointwise(self.depthwise(x)) class ResidualBlock(nn.Module): """Standard residual block for CycleGAN""" def __init__(self, in_features): super(ResidualBlock, self).__init__() self.block = nn.Sequential( nn.ReflectionPad2d(1), nn.Conv2d(in_features, in_features, 3), nn.InstanceNorm2d(in_features, affine=True), nn.ReLU(inplace=True), nn.ReflectionPad2d(1), nn.Conv2d(in_features, in_features, 3), nn.InstanceNorm2d(in_features, affine=True) ) def forward(self, x): return x + self.block(x) class Generator(nn.Module): def __init__(self, input_nc=3, output_nc=3, n_residual_blocks=9): super(Generator, self).__init__() # Initial convolution block model = [ nn.ReflectionPad2d(3), nn.Conv2d(input_nc, 64, 7), nn.InstanceNorm2d(64, affine=True), nn.ReLU(inplace=True) ] # Downsampling in_features = 64 out_features = in_features * 2 for _ in range(2): model += [ nn.Conv2d(in_features, out_features, 3, stride=2, padding=1), nn.InstanceNorm2d(out_features, affine=True), nn.ReLU(inplace=True) ] in_features = out_features out_features = in_features * 2 # Residual blocks for _ in range(n_residual_blocks): model += [ResidualBlock(in_features)] # Upsampling out_features = in_features // 2 for _ in range(2): model += [ nn.ConvTranspose2d(in_features, out_features, 3, stride=2, padding=1, output_padding=1), nn.InstanceNorm2d(out_features, affine=True), nn.ReLU(inplace=True) ] in_features = out_features out_features = in_features // 2 # Output layer model += [ nn.ReflectionPad2d(3), nn.Conv2d(64, output_nc, 7), nn.Tanh() ] self.model = nn.Sequential(*model) def forward(self, x): return self.model(x) class LightweightStyleNet(nn.Module): """Lightweight network for fast style transfer training""" def __init__(self, n_residual_blocks=5): super(LightweightStyleNet, self).__init__() # Encoder self.encoder = nn.Sequential( nn.ReflectionPad2d(3), nn.Conv2d(3, 32, 9, stride=1), nn.InstanceNorm2d(32, affine=True), nn.ReLU(inplace=True), nn.Conv2d(32, 64, 3, stride=2, padding=1), nn.InstanceNorm2d(64, affine=True), nn.ReLU(inplace=True), nn.Conv2d(64, 128, 3, stride=2, padding=1), nn.InstanceNorm2d(128, affine=True), nn.ReLU(inplace=True) ) # Residual blocks res_blocks = [] for _ in range(n_residual_blocks): res_blocks.append(LightweightResidualBlock(128)) self.res_blocks = nn.Sequential(*res_blocks) # Decoder self.decoder = nn.Sequential( nn.ConvTranspose2d(128, 64, 3, stride=2, padding=1, output_padding=1), nn.InstanceNorm2d(64, affine=True), nn.ReLU(inplace=True), nn.ConvTranspose2d(64, 32, 3, stride=2, padding=1, output_padding=1), nn.InstanceNorm2d(32, affine=True), nn.ReLU(inplace=True), nn.ReflectionPad2d(3), nn.Conv2d(32, 3, 9, stride=1), nn.Tanh() ) def forward(self, x): h = self.encoder(x) h = self.res_blocks(h) h = self.decoder(h) return h class SimpleVGGFeatures(nn.Module): """Extract features from VGG19 for perceptual loss calculation""" def __init__(self): super(SimpleVGGFeatures, self).__init__() try: vgg = models.vgg19(weights=models.VGG19_Weights.DEFAULT).features except: vgg = models.vgg19(pretrained=True).features self.features = nn.Sequential(*list(vgg.children())[:21]) for param in self.parameters(): param.requires_grad = False def forward(self, x): return self.features(x) # =========================== # ADAIN ARCHITECTURE # =========================== class AdaIN(nn.Module): """Adaptive Instance Normalization layer""" def __init__(self): super(AdaIN, self).__init__() def calc_mean_std(self, feat, eps=1e-5): # Calculate mean and std for AdaIN size = feat.size() assert (len(size) == 4) N, C = size[:2] feat_var = feat.view(N, C, -1).var(dim=2) + eps feat_std = feat_var.sqrt().view(N, C, 1, 1) feat_mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1) return feat_mean, feat_std def forward(self, content_feat, style_feat): size = content_feat.size() style_mean, style_std = self.calc_mean_std(style_feat) content_mean, content_std = self.calc_mean_std(content_feat) normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size) return normalized_feat * style_std.expand(size) + style_mean.expand(size) class VGGEncoder(nn.Module): """VGG-based encoder for AdaIN""" def __init__(self): super(VGGEncoder, self).__init__() # Load pretrained VGG19 try: vgg = models.vgg19(weights=models.VGG19_Weights.DEFAULT).features except: vgg = models.vgg19(pretrained=True).features # Encoder uses layers up to relu4_1 self.enc1 = nn.Sequential(*list(vgg.children())[:2]) # conv1_1, relu1_1 self.enc2 = nn.Sequential(*list(vgg.children())[2:7]) # up to relu2_1 self.enc3 = nn.Sequential(*list(vgg.children())[7:12]) # up to relu3_1 self.enc4 = nn.Sequential(*list(vgg.children())[12:21]) # up to relu4_1 # Freeze encoder weights for param in self.parameters(): param.requires_grad = False def encode(self, x): """Get only the final features for AdaIN""" h1 = self.enc1(x) h2 = self.enc2(h1) h3 = self.enc3(h2) h4 = self.enc4(h3) return h4 def forward(self, x): h1 = self.enc1(x) h2 = self.enc2(h1) h3 = self.enc3(h2) h4 = self.enc4(h3) return h4, h3, h2, h1 # Return intermediate features for skip connections class AdaINDecoder(nn.Module): """Decoder for AdaIN style transfer""" def __init__(self): super(AdaINDecoder, self).__init__() # Decoder mirrors encoder but in reverse self.dec4 = nn.Sequential( nn.ReflectionPad2d((1, 1, 1, 1)), nn.Conv2d(512, 256, (3, 3)), nn.ReLU(), nn.Upsample(scale_factor=2, mode='nearest'), ) self.dec3 = nn.Sequential( nn.ReflectionPad2d((1, 1, 1, 1)), nn.Conv2d(256, 256, (3, 3)), nn.ReLU(), nn.ReflectionPad2d((1, 1, 1, 1)), nn.Conv2d(256, 128, (3, 3)), nn.ReLU(), nn.Upsample(scale_factor=2, mode='nearest'), ) self.dec2 = nn.Sequential( nn.ReflectionPad2d((1, 1, 1, 1)), nn.Conv2d(128, 128, (3, 3)), nn.ReLU(), nn.ReflectionPad2d((1, 1, 1, 1)), nn.Conv2d(128, 64, (3, 3)), nn.ReLU(), nn.Upsample(scale_factor=2, mode='nearest'), ) self.dec1 = nn.Sequential( nn.ReflectionPad2d((1, 1, 1, 1)), nn.Conv2d(64, 64, (3, 3)), nn.ReLU(), nn.ReflectionPad2d((1, 1, 1, 1)), nn.Conv2d(64, 3, (3, 3)), ) def forward(self, x): h = self.dec4(x) h = self.dec3(h) h = self.dec2(h) h = self.dec1(h) return h class AdaINStyleTransfer(nn.Module): """Complete AdaIN style transfer network""" def __init__(self): super(AdaINStyleTransfer, self).__init__() self.encoder = VGGEncoder() self.decoder = AdaINDecoder() self.adain = AdaIN() # Only decoder needs to be trained self.encoder.eval() for param in self.encoder.parameters(): param.requires_grad = False def encode(self, x): return self.encoder.encode(x) # Use the encode method def forward(self, content, style, alpha=1.0): # Encode content and style content_feat = self.encode(content) style_feat = self.encode(style) # Apply AdaIN feat = self.adain(content_feat, style_feat) # Alpha blending in feature space if alpha < 1.0: feat = alpha * feat + (1 - alpha) * content_feat # Decode return self.decoder(feat) # =========================== # DATASET AND LOSS FUNCTIONS # =========================== class StyleTransferDataset(Dataset): """Dataset for training style transfer models with augmentation support""" def __init__(self, content_dir, transform=None, augment_factor=1): self.content_dir = Path(content_dir) self.transform = transform self.augment_factor = augment_factor extensions = ['*.jpg', '*.jpeg', '*.png', '*.bmp'] self.images = [] for ext in extensions: self.images.extend(list(self.content_dir.glob(ext))) self.images.extend(list(self.content_dir.glob(ext.upper()))) print(f"Found {len(self.images)} content images") self.augmented_images = self.images * self.augment_factor if self.augment_factor > 1: print(f"Dataset augmented {self.augment_factor}x to {len(self.augmented_images)} samples") def __len__(self): return len(self.augmented_images) def __getitem__(self, idx): img_path = self.augmented_images[idx % len(self.images)] image = Image.open(img_path).convert('RGB') if self.transform: image = self.transform(image) return image class PerceptualLoss(nn.Module): """Perceptual loss using VGG features""" def __init__(self, vgg_features): super(PerceptualLoss, self).__init__() self.vgg = vgg_features self.mse = nn.MSELoss() def gram_matrix(self, features): b, c, h, w = features.size() features = features.view(b, c, h * w) gram = torch.bmm(features, features.transpose(1, 2)) return gram / (c * h * w) def forward(self, generated, content, style, content_weight=1.0, style_weight=1e5): gen_feat = self.vgg(generated) content_feat = self.vgg(content) style_feat = self.vgg(style) content_loss = self.mse(gen_feat, content_feat) gen_gram = self.gram_matrix(gen_feat) style_gram = self.gram_matrix(style_feat) style_loss = self.mse(gen_gram, style_gram) total_loss = content_weight * content_loss + style_weight * style_loss return total_loss, content_loss, style_loss # =========================== # VIDEO PROCESSING # =========================== class VideoProcessor: """Process videos frame by frame with style transfer""" def __init__(self, system): self.system = system def process_video(self, video_path, style_configs, blend_mode, progress_callback=None): """Process a video file with style transfer""" if not VIDEO_PROCESSING_AVAILABLE: print("Video processing requires OpenCV (cv2) - please install it") return None try: # Handle both string path and file object if hasattr(video_path, 'name'): video_path = video_path.name # Open video cap = cv2.VideoCapture(video_path) fps = int(cap.get(cv2.CAP_PROP_FPS)) width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) # Create temporary output file - always start with mp4 temp_output = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) temp_output.close() # Try different codecs - prioritize MP4-compatible ones codecs_to_try = [ ('mp4v', '.mp4'), # MPEG-4 - most compatible MP4 codec ('MP4V', '.mp4'), # Alternative case ('FMP4', '.mp4'), # Another MPEG-4 variant ('DIVX', '.mp4'), # DivX (sometimes works with MP4) ('XVID', '.avi'), # If MP4 fails, try AVI ('MJPG', '.avi'), # Motion JPEG - fallback ('DIV3', '.avi'), # DivX3 (0, '.avi') # Uncompressed (last resort) ] out = None output_path = None used_codec = None for codec_str, ext in codecs_to_try: try: # Update output filename with appropriate extension if ext == '.mp4': output_path = temp_output.name else: output_path = temp_output.name.replace('.mp4', ext) if codec_str == 0: fourcc = 0 print("Using uncompressed video (larger file size)") else: fourcc = cv2.VideoWriter_fourcc(*codec_str) print(f"Trying codec: {codec_str} for {ext}") # Create writer with specific parameters out = cv2.VideoWriter(output_path, fourcc, fps, (width, height), isColor=True) if out.isOpened(): # Test write a black frame to ensure it really works test_frame = np.zeros((height, width, 3), dtype=np.uint8) out.write(test_frame) out.release() # Re-open for actual writing out = cv2.VideoWriter(output_path, fourcc, fps, (width, height), isColor=True) if out.isOpened(): used_codec = codec_str print(f"✓ Successfully using codec: {codec_str} with {ext}") break out.release() out = None except Exception as e: print(f"Failed with codec {codec_str}: {e}") if out: out.release() out = None continue if out is None or not out.isOpened(): # Last resort: save frames as images and create video differently print("Standard codecs failed. Trying alternative approach...") return self._process_with_frame_saving(cap, style_configs, blend_mode, fps, width, height, total_frames, progress_callback) # Process frames frame_count = 0 while True: ret, frame = cap.read() if not ret: break # Convert BGR to RGB and to PIL Image rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) pil_frame = Image.fromarray(rgb_frame) # Apply style transfer styled_frame = self.system.blend_styles(pil_frame, style_configs, blend_mode) # Convert back to BGR for video styled_array = np.array(styled_frame) bgr_frame = cv2.cvtColor(styled_array, cv2.COLOR_RGB2BGR) # Ensure frame is correct size and type if bgr_frame.shape[:2] != (height, width): bgr_frame = cv2.resize(bgr_frame, (width, height)) out.write(bgr_frame) frame_count += 1 if progress_callback and frame_count % 10 == 0: progress = frame_count / total_frames progress_callback(progress, f"Processing frame {frame_count}/{total_frames}") cap.release() out.release() # Verify the output file if not os.path.exists(output_path) or os.path.getsize(output_path) < 1000: print(f"Output file is too small or doesn't exist (size: {os.path.getsize(output_path) if os.path.exists(output_path) else 0} bytes)") return None print(f"Video successfully saved to: {output_path}") print(f"File size: {os.path.getsize(output_path) / 1024 / 1024:.2f} MB") print(f"Format: {os.path.splitext(output_path)[1]}") print(f"Codec used: {used_codec}") # Clean up original temp file if different if output_path != temp_output.name and os.path.exists(temp_output.name): try: os.unlink(temp_output.name) except: pass return output_path except Exception as e: print(f"Error processing video: {e}") traceback.print_exc() return None def _process_with_frame_saving(self, cap, style_configs, blend_mode, fps, width, height, total_frames, progress_callback): """Alternative processing method: save frames then combine""" try: print("Using frame-saving fallback method...") temp_dir = tempfile.mkdtemp() frame_count = 0 frame_paths = [] # Process and save frames while True: ret, frame = cap.read() if not ret: break # Process frame rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) pil_frame = Image.fromarray(rgb_frame) styled_frame = self.system.blend_styles(pil_frame, style_configs, blend_mode) # Save frame frame_path = os.path.join(temp_dir, f"frame_{frame_count:06d}.png") styled_frame.save(frame_path) frame_paths.append(frame_path) frame_count += 1 if progress_callback and frame_count % 10 == 0: progress = frame_count / total_frames progress_callback(progress, f"Processing frame {frame_count}/{total_frames}") cap.release() if not frame_paths: return None # Try to create video from frames output_path = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False).name # Read first frame to get size first_frame = cv2.imread(frame_paths[0]) h, w = first_frame.shape[:2] # Try simple mp4v codec fourcc = cv2.VideoWriter_fourcc(*'mp4v') out = cv2.VideoWriter(output_path, fourcc, fps, (w, h)) if out.isOpened(): for frame_path in frame_paths: frame = cv2.imread(frame_path) out.write(frame) out.release() # Clean up frames shutil.rmtree(temp_dir) if os.path.exists(output_path) and os.path.getsize(output_path) > 1000: print(f"Successfully created video using frame-saving method") return output_path # Clean up shutil.rmtree(temp_dir) return None except Exception as e: print(f"Frame-saving method failed: {e}") if 'temp_dir' in locals() and os.path.exists(temp_dir): shutil.rmtree(temp_dir) return None # =========================== # MAIN STYLE TRANSFER SYSTEM # =========================== class StyleTransferSystem: def __init__(self): self.device = device self.cyclegan_models = {} self.loaded_generators = {} self.lightweight_models = {} self.transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) ]) self.inverse_transform = transforms.Compose([ transforms.Normalize((-1, -1, -1), (2, 2, 2)), transforms.ToPILImage() ]) self.vgg_transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) self.discover_cyclegan_models() self.models_dir = '/tmp/trained_models' os.makedirs(self.models_dir, exist_ok=True) if VIDEO_PROCESSING_AVAILABLE: self.video_processor = VideoProcessor(self) # Test GPU functionality if self.device.type == 'cuda': self._test_gpu() def _test_gpu(self): """Test if GPU is working correctly""" try: print("\nTesting GPU functionality...") with torch.no_grad(): # Create a small test tensor test_tensor = torch.randn(1, 3, 256, 256).to(self.device) print(f"Test tensor device: {test_tensor.device}") # Perform a simple operation result = test_tensor * 2.0 print(f"Result device: {result.device}") # Check memory usage print(f"GPU memory after test: {torch.cuda.memory_allocated() / 1e9:.3f} GB") # Clean up del test_tensor, result torch.cuda.empty_cache() print("✓ GPU test successful!\n") except Exception as e: print(f"✗ GPU test failed: {e}\n") traceback.print_exc() def discover_cyclegan_models(self): """Find all available CycleGAN models including both AB and BA directions""" print("\nDiscovering CycleGAN models...") # Updated patterns to match your directory structure patterns = [ './models/*_best_*/*generator_*.pth', './models/*_best_*/*.pth', './models/*/*generator*.pth', './models/*/*.pth' ] all_files = set() for pattern in patterns: files = glob.glob(pattern) if files: print(f"Found in {pattern}: {len(files)} items") all_files.update(files) # Also check if models directory exists and list contents if os.path.exists('./models'): print(f"\nModels directory contents:") for folder in os.listdir('./models'): folder_path = os.path.join('./models', folder) if os.path.isdir(folder_path): print(f" {folder}/") for file in os.listdir(folder_path): print(f" - {file}") if file.endswith('.pth'): all_files.add(os.path.join(folder_path, file)) # Group files by base model name model_files = {} for path in all_files: # Skip normal models if 'normal' in path.lower(): continue filename = os.path.basename(path) folder_name = os.path.basename(os.path.dirname(path)) # Extract base name from folder name if '_best_' in folder_name: base_name = folder_name.split('_best_')[0] else: base_name = folder_name if base_name not in model_files: model_files[base_name] = {'AB': None, 'BA': None} # Check filename for direction if 'generator_AB' in filename or 'g_AB' in filename or 'G_AB' in filename: model_files[base_name]['AB'] = path elif 'generator_BA' in filename or 'g_BA' in filename or 'G_BA' in filename: model_files[base_name]['BA'] = path elif 'generator' in filename.lower() and not any(x in filename for x in ['AB', 'BA']): # If no direction specified, assume it's AB if model_files[base_name]['AB'] is None: model_files[base_name]['AB'] = path # Create display names for models model_display_map = { 'photo_bokeh': ('Bokeh', 'Sharp'), 'photo_golden': ('Golden Hour', 'Normal Light'), 'photo_monet': ('Monet Style', 'Photo'), 'photo_seurat': ('Seurat Style', 'Photo'), 'day_night': ('Night', 'Day'), 'summer_winter': ('Winter', 'Summer'), 'foggy_clear': ('Clear', 'Foggy') } # Register available models for base_name, files in model_files.items(): clean_name = base_name.lower().replace('-', '_') if clean_name in model_display_map: style_from, style_to = model_display_map[clean_name] # Register AB direction if available if files['AB']: display_name = f"{style_to} to {style_from}" model_key = f"{clean_name}_AB" self.cyclegan_models[model_key] = { 'path': files['AB'], 'name': display_name, 'base_name': base_name, 'direction': 'AB' } print(f"Registered: {display_name} ({model_key}) -> {files['AB']}") # Register BA direction if available if files['BA']: display_name = f"{style_from} to {style_to}" model_key = f"{clean_name}_BA" self.cyclegan_models[model_key] = { 'path': files['BA'], 'name': display_name, 'base_name': base_name, 'direction': 'BA' } print(f"Registered: {display_name} ({model_key}) -> {files['BA']}") if not self.cyclegan_models: print("No CycleGAN models found!") print("Make sure your model files are in the ./models directory") else: print(f"\nFound {len(self.cyclegan_models)} CycleGAN models\n") def detect_architecture(self, state_dict): """Detect the number of residual blocks in CycleGAN model""" residual_keys = [k for k in state_dict.keys() if 'model.' in k and '.block.' in k] if not residual_keys: return 9 block_indices = set() for key in residual_keys: parts = key.split('.') for i in range(len(parts) - 1): if parts[i] == 'model' and parts[i+1].isdigit(): block_indices.add(int(parts[i+1])) break n_blocks = len(block_indices) return n_blocks if n_blocks > 0 else 9 def load_cyclegan_model(self, model_key): """Load a CycleGAN model""" if model_key in self.loaded_generators: # Ensure cached model is on the correct device model = self.loaded_generators[model_key] model = model.to(self.device) return model if model_key not in self.cyclegan_models: print(f"Model {model_key} not found!") return None model_info = self.cyclegan_models[model_key] try: print(f"Loading {model_info['name']} from {model_info['path']}...") print(f"Target device: {self.device}") # Load with explicit map_location to ensure it goes to the right device state_dict = torch.load(model_info['path'], map_location=self.device) if 'generator' in state_dict: state_dict = state_dict['generator'] n_blocks = self.detect_architecture(state_dict) print(f"Detected {n_blocks} residual blocks") generator = Generator(n_residual_blocks=n_blocks) try: generator.load_state_dict(state_dict, strict=True) print(f"Loaded with strict=True") except: generator.load_state_dict(state_dict, strict=False) print(f"Loaded with strict=False") # Move to device BEFORE any precision changes generator = generator.to(self.device) generator.eval() # Check if model is actually on GPU print(f"Model device after .to(): {next(generator.parameters()).device}") # Skip half precision to ensure GPU usage - half precision can cause issues if self.device.type == 'cuda': print(f"Using full precision (fp32) on GPU") # Test GPU usage with torch.no_grad(): test_input = torch.randn(1, 3, 256, 256).to(self.device) _ = generator(test_input) print(f"GPU test successful") torch.cuda.empty_cache() self.loaded_generators[model_key] = generator print(f"Successfully loaded {model_info['name']} on {self.device}") return generator except Exception as e: print(f"Failed to load {model_info['name']}: {e}") traceback.print_exc() return None def apply_cyclegan_style(self, image, model_key, intensity=1.0): """Apply a CycleGAN style to an image""" if image is None or model_key not in self.cyclegan_models: return None model_info = self.cyclegan_models[model_key] generator = self.load_cyclegan_model(model_key) if generator is None: print(f"Could not load model for {model_info['name']}") return None try: # Ensure model is on GPU generator = generator.to(self.device) # Debug GPU usage if self.device.type == 'cuda': print(f"GPU Memory before style transfer: {torch.cuda.memory_allocated() / 1e9:.2f} GB") original_size = image.size w, h = image.size new_w = ((w + 31) // 32) * 32 new_h = ((h + 31) // 32) * 32 max_size = 1024 if self.device.type == 'cuda' else 512 if new_w > max_size or new_h > max_size: ratio = min(max_size / new_w, max_size / new_h) new_w = int(new_w * ratio) new_h = int(new_h * ratio) new_w = ((new_w + 31) // 32) * 32 new_h = ((new_h + 31) // 32) * 32 image_resized = image.resize((new_w, new_h), Image.LANCZOS) img_tensor = self.transform(image_resized).unsqueeze(0).to(self.device) with torch.no_grad(): # Skip half precision for now to ensure GPU usage if self.device.type == 'cuda': torch.cuda.synchronize() # Ensure GPU operations complete torch.cuda.empty_cache() output = generator(img_tensor) # Ensure output is on the same device output = output.to(self.device) if self.device.type == 'cuda': torch.cuda.synchronize() # Wait for GPU to finish output_img = self.inverse_transform(output.squeeze(0).cpu()) output_img = output_img.resize(original_size, Image.LANCZOS) if self.device.type == 'cuda': print(f"GPU Memory after style transfer: {torch.cuda.memory_allocated() / 1e9:.2f} GB") torch.cuda.empty_cache() if intensity < 1.0: output_array = np.array(output_img, dtype=np.float32) original_array = np.array(image, dtype=np.float32) blended = original_array * (1 - intensity) + output_array * intensity output_img = Image.fromarray(blended.astype(np.uint8)) return output_img except Exception as e: print(f"Error applying style {model_info['name']}: {e}") print(f"Device: {self.device}") print(f"Model device: {next(generator.parameters()).device}") traceback.print_exc() return None def train_lightweight_model(self, style_image, content_dir, model_name, epochs=30, batch_size=4, lr=1e-3, save_interval=5, style_weight=1e5, content_weight=1.0, n_residual_blocks=5, progress_callback=None): """Train a lightweight style transfer model""" model = LightweightStyleNet(n_residual_blocks=n_residual_blocks).to(self.device) optimizer = torch.optim.Adam(model.parameters(), lr=lr) print(f"Model architecture: {n_residual_blocks} residual blocks") print(f"Training device: {self.device}") # Verify model is on GPU if self.device.type == 'cuda': print(f"Model on GPU: {next(model.parameters()).device}") print(f"GPU memory before training: {torch.cuda.memory_allocated() / 1e9:.2f} GB") # Calculate augmentation factor num_content_images = len(list(Path(content_dir).glob('*'))) if num_content_images < 5: augment_factor = 20 elif num_content_images < 10: augment_factor = 10 elif num_content_images < 20: augment_factor = 5 else: augment_factor = 1 # Create dataset with augmentation if num_content_images < 10: transform = transforms.Compose([ transforms.RandomResizedCrop(256, scale=(0.7, 1.2)), transforms.RandomHorizontalFlip(), transforms.RandomRotation(15), transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1), transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) ]) print(f"Using heavy augmentation due to limited images ({num_content_images} provided)") else: transform = transforms.Compose([ transforms.Resize(286), transforms.RandomCrop(256), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) ]) dataset = StyleTransferDataset(content_dir, transform=transform, augment_factor=augment_factor) print(f"Training configuration:") print(f" - Original images: {num_content_images}") print(f" - Augmentation factor: {augment_factor}x") print(f" - Total training samples: {len(dataset)}") print(f" - Residual blocks: {n_residual_blocks}") print(f" - Batch size: {int(batch_size)}") print(f" - Epochs: {epochs}") # Adjust batch size for small datasets if num_content_images == 1: if n_residual_blocks >= 9 and int(batch_size) > 1: actual_batch_size = 1 print(f"Reduced batch size to 1 for single image + {n_residual_blocks} blocks") elif int(batch_size) > 2: actual_batch_size = 2 print(f"Reduced batch size to 2 for single image training") else: actual_batch_size = min(int(batch_size), len(dataset)) else: actual_batch_size = min(int(batch_size), len(dataset)) dataloader = DataLoader(dataset, batch_size=actual_batch_size, shuffle=True, num_workers=0 if num_content_images < 10 else 2) # Prepare style image style_transform = transforms.Compose([ transforms.Resize(800), transforms.CenterCrop(768) ]) style_pil = style_transform(style_image) style_tensor = self.vgg_transform(style_pil).unsqueeze(0).to(self.device) # Create VGG features extractor for loss vgg_features = SimpleVGGFeatures().to(self.device).eval() print(f"VGG features on device: {next(vgg_features.parameters()).device}") # Extract style features once with torch.no_grad(): style_features = vgg_features(style_tensor) # Loss function perceptual_loss = PerceptualLoss(vgg_features) # Training loop model.train() total_steps = 0 for epoch in range(epochs): epoch_loss = 0 for batch_idx, content_batch in enumerate(dataloader): content_batch = content_batch.to(self.device) # Forward pass output = model(content_batch) # Ensure all tensors have the same size target_size = (256, 256) # Convert for VGG output_vgg = [] content_vgg = [] for i in range(output.size(0)): # Denormalize from [-1, 1] to [0, 1] out_img = output[i] * 0.5 + 0.5 cont_img = content_batch[i] * 0.5 + 0.5 # Ensure exact size match if out_img.shape[1:] != (target_size[0], target_size[1]): out_img = F.interpolate(out_img.unsqueeze(0), size=target_size, mode='bilinear', align_corners=False).squeeze(0) if cont_img.shape[1:] != (target_size[0], target_size[1]): cont_img = F.interpolate(cont_img.unsqueeze(0), size=target_size, mode='bilinear', align_corners=False).squeeze(0) # Normalize for VGG out_norm = transforms.Normalize( mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] )(out_img) cont_norm = transforms.Normalize( mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] )(cont_img) output_vgg.append(out_norm) content_vgg.append(cont_norm) output_vgg = torch.stack(output_vgg) content_vgg = torch.stack(content_vgg) # Ensure style tensor matches batch size and dimensions style_vgg = style_tensor.expand(output_vgg.size(0), -1, -1, -1) if style_vgg.shape[2:] != output_vgg.shape[2:]: style_vgg = F.interpolate(style_vgg, size=output_vgg.shape[2:], mode='bilinear', align_corners=False) # Calculate loss loss, content_loss, style_loss = perceptual_loss( output_vgg, content_vgg, style_vgg, content_weight=content_weight, style_weight=style_weight ) # Backward pass optimizer.zero_grad() loss.backward() optimizer.step() epoch_loss += loss.item() total_steps += 1 # Progress callback if progress_callback and total_steps % 10 == 0: progress = (epoch + (batch_idx + 1) / len(dataloader)) / epochs aug_info = f" (aug {num_content_images}→{len(dataset)})" if num_content_images < 20 else "" blocks_info = f", {n_residual_blocks} blocks" progress_callback(progress, f"Epoch {epoch+1}/{epochs}{aug_info}{blocks_info}, Loss: {loss.item():.4f}") # Save checkpoint if (epoch + 1) % int(save_interval) == 0: checkpoint_path = f'{self.models_dir}/{model_name}_epoch_{epoch+1}.pth' torch.save({ 'epoch': epoch + 1, 'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'loss': epoch_loss / len(dataloader), 'n_residual_blocks': n_residual_blocks }, checkpoint_path) print(f"Saved checkpoint: {checkpoint_path}") # Save final model final_path = f'{self.models_dir}/{model_name}_final.pth' torch.save({ 'model_state_dict': model.state_dict(), 'n_residual_blocks': n_residual_blocks }, final_path) print(f"Training complete! Model saved to: {final_path}") # Add to lightweight models self.lightweight_models[model_name] = model return model def load_lightweight_model(self, model_path): """Load a trained lightweight model""" try: # Load directly to the target device state_dict = torch.load(model_path, map_location=self.device) # Check if n_residual_blocks is saved if isinstance(state_dict, dict) and 'n_residual_blocks' in state_dict: n_blocks = state_dict['n_residual_blocks'] print(f"Found saved architecture: {n_blocks} residual blocks") else: # Try to detect from state dict if 'model_state_dict' in state_dict: model_state = state_dict['model_state_dict'] else: model_state = state_dict res_block_keys = [k for k in model_state.keys() if 'res_blocks' in k and 'weight' in k] n_blocks = len(set([k.split('.')[1] for k in res_block_keys if k.startswith('res_blocks')])) or 5 print(f"Detected {n_blocks} residual blocks from model structure") # Create model with detected architecture model = LightweightStyleNet(n_residual_blocks=n_blocks).to(self.device) # Load the weights if 'model_state_dict' in state_dict: model.load_state_dict(state_dict['model_state_dict']) else: model.load_state_dict(state_dict) model.eval() # Verify model is on correct device print(f"Lightweight model loaded on: {next(model.parameters()).device}") return model except Exception as e: print(f"Error loading lightweight model: {e}") # Try with default 5 blocks try: print("Attempting to load with default 5 residual blocks...") model = LightweightStyleNet(n_residual_blocks=5).to(self.device) if model_path.endswith('.pth'): state_dict = torch.load(model_path, map_location=self.device) if 'model_state_dict' in state_dict: model.load_state_dict(state_dict['model_state_dict']) else: model.load_state_dict(state_dict) model.eval() print(f"Fallback model loaded on: {next(model.parameters()).device}") return model except: return None # Inside the StyleTransferSystem class, add these methods: def _create_linear_weight(self, width, height, overlap): """Create linear blending weights for tile edges""" weight = np.ones((height, width, 1), dtype=np.float32) if overlap > 0: # Create gradients for each edge for i in range(overlap): alpha = i / overlap # Top edge weight[i, :] *= alpha # Bottom edge weight[-i-1, :] *= alpha # Left edge weight[:, i] *= alpha # Right edge weight[:, -i-1] *= alpha return weight def _create_gaussian_weight(self, width, height, overlap): """Create Gaussian blending weights for smoother transitions""" weight = np.ones((height, width), dtype=np.float32) # Create 2D Gaussian centered in the tile y, x = np.ogrid[:height, :width] center_y, center_x = height / 2, width / 2 # Distance from center dist_from_center = np.sqrt((x - center_x)**2 + (y - center_y)**2) # Gaussian falloff starting from the edges max_dist = min(height, width) / 2 sigma = max_dist / 2 # Adjust for smoother/sharper transitions # Apply Gaussian only near edges edge_dist = np.minimum( np.minimum(y, height - 1 - y), np.minimum(x, width - 1 - x) ) # Weight is 1 in center, Gaussian falloff near edges weight = np.where( edge_dist < overlap, np.exp(-0.5 * ((overlap - edge_dist) / (overlap/3))**2), 1.0 ) return weight.reshape(height, width, 1) def apply_lightweight_style(self, image, model, intensity=1.0): """Apply style using a lightweight model""" if image is None or model is None: return None try: # Ensure model is on the correct device model = model.to(self.device) model.eval() original_size = image.size transform = transforms.Compose([ transforms.Resize(256), transforms.CenterCrop(256), transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) ]) img_tensor = transform(image).unsqueeze(0).to(self.device) with torch.no_grad(): if self.device.type == 'cuda': torch.cuda.synchronize() output = model(img_tensor) if self.device.type == 'cuda': torch.cuda.synchronize() output_img = self.inverse_transform(output.squeeze(0).cpu()) output_img = output_img.resize(original_size, Image.LANCZOS) if intensity < 1.0: output_array = np.array(output_img, dtype=np.float32) original_array = np.array(image, dtype=np.float32) blended = original_array * (1 - intensity) + output_array * intensity output_img = Image.fromarray(blended.astype(np.uint8)) return output_img except Exception as e: print(f"Error applying lightweight style: {e}") print(f"Device: {self.device}") print(f"Model device: {next(model.parameters()).device}") return None def blend_styles(self, image, style_configs, blend_mode="additive"): """Apply multiple styles with different blending modes""" if not image or not style_configs: return image original = np.array(image, dtype=np.float32) styled_images = [] weights = [] for style_type, model_key, intensity in style_configs: if intensity <= 0: continue if style_type == 'cyclegan': styled = self.apply_cyclegan_style(image, model_key, 1.0) elif style_type == 'lightweight' and model_key in self.lightweight_models: styled = self.apply_lightweight_style(image, self.lightweight_models[model_key], 1.0) else: continue if styled: styled_images.append(np.array(styled, dtype=np.float32)) weights.append(intensity) if not styled_images: return image # Apply blending if blend_mode == "average": result = np.zeros_like(original) total_weight = sum(weights) for img, weight in zip(styled_images, weights): result += img * (weight / total_weight) elif blend_mode == "additive": result = original.copy() for img, weight in zip(styled_images, weights): transformation = img - original result = result + transformation * weight elif blend_mode == "maximum": result = original.copy() for img, weight in zip(styled_images, weights): transformation = (img - original) * weight current_diff = result - original mask = np.abs(transformation) > np.abs(current_diff) result[mask] = original[mask] + transformation[mask] elif blend_mode == "overlay": result = original.copy() for img, weight in zip(styled_images, weights): overlay = np.zeros_like(result) mask = result < 128 overlay[mask] = 2 * img[mask] * result[mask] / 255.0 overlay[~mask] = 255 - 2 * (255 - img[~mask]) * (255 - result[~mask]) / 255.0 result = result * (1 - weight) + overlay * weight else: # "screen" mode result = original.copy() for img, weight in zip(styled_images, weights): screened = 255 - ((255 - result) * (255 - img) / 255.0) if weight > 1.0: diff = screened - result result = result + diff * weight else: result = result * (1 - weight) + screened * weight return Image.fromarray(np.clip(result, 0, 255).astype(np.uint8)) def apply_regional_styles(self, image, combined_mask, regions, base_style_configs=None, blend_mode="additive"): """Apply different styles to painted regions using a combined mask""" if not regions: if base_style_configs: return self.blend_styles(image, base_style_configs, blend_mode) return image original_size = image.size result = np.array(image, dtype=np.float32) # Apply base style if provided if base_style_configs: base_styled = self.blend_styles(image, base_style_configs, blend_mode) result = np.array(base_styled, dtype=np.float32) # Resize mask to match original image if needed if combined_mask is not None and combined_mask.shape[:2] != (original_size[1], original_size[0]): # Resize the combined mask to match the original image combined_mask_pil = Image.fromarray(combined_mask.astype(np.uint8)) combined_mask_resized = combined_mask_pil.resize(original_size, Image.NEAREST) combined_mask = np.array(combined_mask_resized) # Apply each region for i, region in enumerate(regions): if region['style'] is None: continue # Get model key for this region's style model_key = None for key, info in self.cyclegan_models.items(): if info['name'] == region['style']: model_key = key break if not model_key: continue # Apply style to whole image style_configs = [('cyclegan', model_key, region['intensity'])] styled = self.blend_styles(image, style_configs, blend_mode) styled_array = np.array(styled, dtype=np.float32) # Create mask for this region from combined mask if combined_mask is not None: # Region masks are identified by their color index region_mask = (combined_mask == (i + 1)).astype(np.float32) # Ensure mask has same shape as image if len(region_mask.shape) == 2: region_mask_3ch = np.stack([region_mask] * 3, axis=2) else: region_mask_3ch = region_mask # Blend using mask result = result * (1 - region_mask_3ch) + styled_array * region_mask_3ch return Image.fromarray(np.clip(result, 0, 255).astype(np.uint8)) def train_adain_model(self, style_images, content_dir, model_name, epochs=30, batch_size=4, lr=1e-4, save_interval=5, style_weight=10.0, content_weight=1.0, progress_callback=None): """Train an AdaIN-based style transfer model""" model = AdaINStyleTransfer().to(self.device) optimizer = torch.optim.Adam(model.decoder.parameters(), lr=lr) # Add learning rate scheduler scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.8) print(f"Training AdaIN model") print(f"Training device: {self.device}") # Verify model is on GPU if self.device.type == 'cuda': print(f"Model on GPU: {next(model.decoder.parameters()).device}") print(f"GPU memory before training: {torch.cuda.memory_allocated() / 1e9:.2f} GB") # Prepare style images - INCREASED SIZE style_transform = transforms.Compose([ transforms.Resize(600), # Increased from 512 transforms.RandomCrop(512), # Increased from 256 transforms.RandomHorizontalFlip(p=0.5), # Add augmentation transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) style_tensors = [] # Create multiple augmented versions of each style image for style_img in style_images: # Generate 5 augmented versions per style image for _ in range(5): style_tensor = style_transform(style_img).unsqueeze(0).to(self.device) style_tensors.append(style_tensor) print(f"Created {len(style_tensors)} augmented style samples from {len(style_images)} images") # Prepare content dataset - INCREASED SIZE content_transform = transforms.Compose([ transforms.Resize(600), # Increased from 512 transforms.RandomCrop(512), # Increased from 256 transforms.RandomHorizontalFlip(), transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.05), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) dataset = StyleTransferDataset(content_dir, transform=content_transform) dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0) print(f"Training configuration:") print(f" - Style images: {len(style_tensors)}") print(f" - Content images: {len(dataset)}") print(f" - Batch size: {batch_size}") print(f" - Epochs: {epochs}") print(f" - Training resolution: 512x512") # Updated # Loss network (VGG for perceptual loss) - USE MULTIPLE LAYERS class MultiLayerVGG(nn.Module): def __init__(self): super().__init__() vgg = models.vgg19(weights=models.VGG19_Weights.DEFAULT).features self.slice1 = nn.Sequential(*list(vgg.children())[:2]) # relu1_1 self.slice2 = nn.Sequential(*list(vgg.children())[2:7]) # relu2_1 self.slice3 = nn.Sequential(*list(vgg.children())[7:12]) # relu3_1 self.slice4 = nn.Sequential(*list(vgg.children())[12:21]) # relu4_1 for param in self.parameters(): param.requires_grad = False def forward(self, x): h1 = self.slice1(x) h2 = self.slice2(h1) h3 = self.slice3(h2) h4 = self.slice4(h3) return [h1, h2, h3, h4] loss_network = MultiLayerVGG().to(self.device).eval() mse_loss = nn.MSELoss() # Training loop model.train() model.encoder.eval() # Keep encoder frozen total_steps = 0 # Adjust style weight for better quality actual_style_weight = style_weight * 10 # Multiply by 10 for better style transfer for epoch in range(epochs): epoch_loss = 0 for batch_idx, content_batch in enumerate(dataloader): content_batch = content_batch.to(self.device) # Randomly select style images for this batch batch_style = [] for _ in range(content_batch.size(0)): style_idx = np.random.randint(0, len(style_tensors)) batch_style.append(style_tensors[style_idx]) batch_style = torch.cat(batch_style, dim=0) # Forward pass output = model(content_batch, batch_style) # Multi-layer content and style loss with torch.no_grad(): content_feats = loss_network(content_batch) style_feats = loss_network(batch_style) output_feats = loss_network(output) # Content loss - only from relu4_1 content_loss = mse_loss(output_feats[-1], content_feats[-1]) # Style loss - from multiple layers style_loss = 0 style_weights = [0.2, 0.3, 0.5, 1.0] # Give more weight to higher layers def gram_matrix(feat): b, c, h, w = feat.size() feat = feat.view(b, c, h * w) gram = torch.bmm(feat, feat.transpose(1, 2)) return gram / (c * h * w) for i, (output_feat, style_feat, weight) in enumerate(zip(output_feats, style_feats, style_weights)): output_gram = gram_matrix(output_feat) style_gram = gram_matrix(style_feat) style_loss += weight * mse_loss(output_gram, style_gram) style_loss /= len(style_weights) # Total loss loss = content_weight * content_loss + actual_style_weight * style_loss # Backward pass optimizer.zero_grad() loss.backward() # Gradient clipping for stability torch.nn.utils.clip_grad_norm_(model.decoder.parameters(), max_norm=5.0) optimizer.step() epoch_loss += loss.item() total_steps += 1 # Progress callback if progress_callback and total_steps % 10 == 0: progress = (epoch + (batch_idx + 1) / len(dataloader)) / epochs progress_callback(progress, f"Epoch {epoch+1}/{epochs}, Loss: {loss.item():.4f} " f"(Content: {content_loss.item():.4f}, Style: {style_loss.item():.4f})") # Step scheduler scheduler.step() # Save checkpoint if (epoch + 1) % save_interval == 0: checkpoint_path = f'{self.models_dir}/{model_name}_epoch_{epoch+1}.pth' torch.save({ 'epoch': epoch + 1, 'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'scheduler_state_dict': scheduler.state_dict(), 'loss': epoch_loss / len(dataloader), 'model_type': 'adain' }, checkpoint_path) print(f"Saved checkpoint: {checkpoint_path}") # Save final model final_path = f'{self.models_dir}/{model_name}_final.pth' torch.save({ 'model_state_dict': model.state_dict(), 'model_type': 'adain' }, final_path) print(f"Training complete! Model saved to: {final_path}") # Add to lightweight models self.lightweight_models[model_name] = model return model # Update these methods in your StyleTransferSystem class: def apply_adain_style(self, content_image, style_image, model, alpha=1.0, use_tiling=False): """Apply AdaIN-based style transfer with optional tiling""" # Use tiling for large images to maintain quality if use_tiling and (content_image.width > 768 or content_image.height > 768): return self.apply_adain_style_tiled( content_image, style_image, model, alpha, tile_size=512, # Increased from 256 overlap=64, # Increased overlap blend_mode='gaussian' ) if content_image is None or style_image is None or model is None: return None try: model = model.to(self.device) model.eval() original_size = content_image.size # Use higher resolution - find optimal size while maintaining aspect ratio max_dim = 768 # Increased from 256 w, h = content_image.size if w > h: new_w = min(w, max_dim) new_h = int(h * new_w / w) else: new_h = min(h, max_dim) new_w = int(w * new_h / h) # Ensure dimensions are divisible by 8 for better compatibility new_w = (new_w // 8) * 8 new_h = (new_h // 8) * 8 # Transform for AdaIN (VGG normalization) transform = transforms.Compose([ transforms.Resize((new_h, new_w)), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) content_tensor = transform(content_image).unsqueeze(0).to(self.device) style_tensor = transform(style_image).unsqueeze(0).to(self.device) with torch.no_grad(): output = model(content_tensor, style_tensor, alpha=alpha) # Denormalize output = output.squeeze(0).cpu() output = output * torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1) output = output + torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1) output = torch.clamp(output, 0, 1) # Convert to PIL output_img = transforms.ToPILImage()(output) output_img = output_img.resize(original_size, Image.LANCZOS) return output_img except Exception as e: print(f"Error applying AdaIN style: {e}") traceback.print_exc() return None def apply_adain_style_tiled(self, content_image, style_image, model, alpha=1.0, tile_size=256, overlap=32, blend_mode='linear'): """ Apply AdaIN style transfer using tiling for high-quality results. Processes image in overlapping tiles to maintain quality. """ if content_image is None or style_image is None or model is None: return None try: model = model.to(self.device) model.eval() # INCREASED TILE SIZE FOR BETTER QUALITY tile_size = 512 # Override input to use 512 overlap = 64 # Increase overlap proportionally # Prepare transforms transform = transforms.Compose([ transforms.Resize((tile_size, tile_size)), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) # Process style image once (at tile size) style_tensor = transform(style_image).unsqueeze(0).to(self.device) # Get dimensions w, h = content_image.size # Calculate tile positions with overlap stride = tile_size - overlap tiles_x = list(range(0, w - tile_size + 1, stride)) tiles_y = list(range(0, h - tile_size + 1, stride)) # Ensure we cover the entire image if not tiles_x or tiles_x[-1] + tile_size < w: tiles_x.append(max(0, w - tile_size)) if not tiles_y or tiles_y[-1] + tile_size < h: tiles_y.append(max(0, h - tile_size)) # If image is smaller than tile size, just process normally if w <= tile_size and h <= tile_size: return self.apply_adain_style(content_image, style_image, model, alpha, use_tiling=False) print(f"Processing {len(tiles_x) * len(tiles_y)} tiles of size {tile_size}x{tile_size}") # Initialize output and weight arrays output_array = np.zeros((h, w, 3), dtype=np.float32) weight_array = np.zeros((h, w, 1), dtype=np.float32) # Process each tile with torch.no_grad(): for y_idx, y in enumerate(tiles_y): for x_idx, x in enumerate(tiles_x): # Extract tile tile = content_image.crop((x, y, x + tile_size, y + tile_size)) # Transform tile tile_tensor = transform(tile).unsqueeze(0).to(self.device) # Apply AdaIN to tile styled_tensor = model(tile_tensor, style_tensor, alpha=alpha) # Denormalize styled_tensor = styled_tensor.squeeze(0).cpu() denorm_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1) denorm_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1) styled_tensor = styled_tensor * denorm_std + denorm_mean styled_tensor = torch.clamp(styled_tensor, 0, 1) # Convert to numpy styled_tile = styled_tensor.permute(1, 2, 0).numpy() * 255 # Create weight mask for blending - use gaussian by default for better quality weight = self._create_gaussian_weight(tile_size, tile_size, overlap) # Add to output with weights output_array[y:y+tile_size, x:x+tile_size] += styled_tile * weight weight_array[y:y+tile_size, x:x+tile_size] += weight # Normalize by weights output_array = output_array / (weight_array + 1e-8) output_array = np.clip(output_array, 0, 255).astype(np.uint8) return Image.fromarray(output_array) except Exception as e: print(f"Error in tiled AdaIN processing: {e}") traceback.print_exc() # Fallback to standard processing return self.apply_adain_style(content_image, style_image, model, alpha, use_tiling=False) # =========================== # HELPER FUNCTIONS # =========================== def resize_image_for_display(image, max_width=800, max_height=600): """Resize image for display while maintaining aspect ratio""" width, height = image.size # Calculate scaling factor width_scale = max_width / width height_scale = max_height / height scale = min(width_scale, height_scale) # Only scale down, not up if scale < 1: new_width = int(width * scale) new_height = int(height * scale) return image.resize((new_width, new_height), Image.LANCZOS) return image def combine_region_masks(canvas_results, canvas_size): """Combine multiple region masks into a single mask with different values for each region""" combined_mask = np.zeros(canvas_size[:2], dtype=np.uint8) for i, canvas_data in enumerate(canvas_results): if canvas_data is not None and hasattr(canvas_data, 'image_data') and canvas_data.image_data is not None: # Extract alpha channel as mask mask = canvas_data.image_data[:, :, 3] > 0 # Assign region index (1-based) to mask combined_mask[mask] = i + 1 return combined_mask def apply_adain_regional(content_image, style_image, model, canvas_result, alpha=1.0, feather_radius=10, use_tiling=False): """Apply AdaIN style transfer to a painted region only""" if content_image is None or style_image is None or model is None: return None try: # Get the mask from canvas if canvas_result is None or canvas_result.image_data is None: # No mask painted, apply to whole image return system.apply_adain_style(content_image, style_image, model, alpha, use_tiling=use_tiling) # Extract mask from canvas mask_data = canvas_result.image_data[:, :, 3] # Alpha channel mask = mask_data > 0 # Resize mask to match original image size original_size = content_image.size display_size = (canvas_result.image_data.shape[1], canvas_result.image_data.shape[0]) if original_size != display_size: # Convert mask to PIL image for resizing mask_pil = Image.fromarray((mask * 255).astype(np.uint8), mode='L') mask_pil = mask_pil.resize(original_size, Image.NEAREST) mask = np.array(mask_pil) > 128 # Apply feathering to mask edges if requested if feather_radius > 0: from scipy.ndimage import gaussian_filter mask_float = mask.astype(np.float32) mask_float = gaussian_filter(mask_float, sigma=feather_radius) mask_float = np.clip(mask_float, 0, 1) else: mask_float = mask.astype(np.float32) # Apply style to entire image with tiling option styled_full = system.apply_adain_style(content_image, style_image, model, alpha, use_tiling=use_tiling) if styled_full is None: return None # Blend original and styled based on mask original_array = np.array(content_image, dtype=np.float32) styled_array = np.array(styled_full, dtype=np.float32) # Expand mask to 3 channels mask_3ch = np.stack([mask_float] * 3, axis=2) # Blend result_array = original_array * (1 - mask_3ch) + styled_array * mask_3ch result_array = np.clip(result_array, 0, 255).astype(np.uint8) return Image.fromarray(result_array) except Exception as e: print(f"Error applying regional AdaIN style: {e}") traceback.print_exc() return None # =========================== # INITIALIZE SYSTEM AND API # =========================== @st.cache_resource def load_system(): return StyleTransferSystem() @st.cache_resource def get_unsplash_api(): return UnsplashAPI() system = load_system() unsplash = get_unsplash_api() # Get style choices style_choices = sorted([info['name'] for info in system.cyclegan_models.values()]) # =========================== # STREAMLIT APP # =========================== # Main app st.title("Style Transfer") st.markdown("Image and video style transfer with CycleGAN and custom training capabilities") # Sidebar for global settings with st.sidebar: st.header("Settings") # GPU status if torch.cuda.is_available(): gpu_info = torch.cuda.get_device_properties(0) st.success(f"GPU: {gpu_info.name}") # Show memory usage total_memory = gpu_info.total_memory / 1e9 used_memory = torch.cuda.memory_allocated() / 1e9 free_memory = total_memory - used_memory col1, col2 = st.columns(2) with col1: st.metric("Total Memory", f"{total_memory:.2f} GB") with col2: st.metric("Used Memory", f"{used_memory:.2f} GB") # Memory usage bar memory_percentage = (used_memory / total_memory) * 100 st.progress(memory_percentage / 100) st.caption(f"Free: {free_memory:.2f} GB ({100-memory_percentage:.1f}%)") # Check if GPU is actually being used if used_memory < 0.1: st.warning("GPU detected but not in use. Models may be running on CPU.") # Force GPU button if st.button("Force GPU Reset"): torch.cuda.empty_cache() torch.cuda.synchronize() st.rerun() else: st.warning("Running on CPU (GPU not available)") st.caption("For faster processing, use a GPU-enabled environment") st.markdown("---") st.markdown("### Quick Guide") st.markdown(""" - **Style Transfer**: Apply artistic styles to images - **Regional Transform**: Paint areas for local effects - **Video Processing**: Apply styles to videos - **Train Custom**: Create your own style models - **Batch Process**: Process multiple images """) # Unsplash API status st.markdown("---") if unsplash.access_key: st.success("Unsplash API Connected") else: st.info("Add Unsplash API key for image search") # Debug mode st.markdown("---") if st.checkbox("🐛 Debug Mode"): st.code(f""" Device: {device} CUDA Available: {torch.cuda.is_available()} CUDA Version: {torch.version.cuda if torch.cuda.is_available() else 'N/A'} PyTorch Version: {torch.__version__} Models Loaded: {len(system.loaded_generators)} """) # Main tabs tab1, tab2, tab3, tab4, tab5, tab6 = st.tabs([ "Style Transfer", "Regional Transform", "Video Processing", "Train Custom Style", "Batch Processing", "Documentation" ]) # TAB 1: Style Transfer (with Unsplash integration) with tab1: # Unsplash Search Section with st.expander("Search Unsplash for Images", expanded=False): if not unsplash.access_key: st.info(""" To enable Unsplash search: 1. Get a free API key from [Unsplash Developers](https://unsplash.com/developers) 2. Add it to your HuggingFace Space secrets as `UNSPLASH_ACCESS_KEY` """) else: search_col1, search_col2, search_col3 = st.columns([3, 1, 1]) with search_col1: search_query = st.text_input("Search for images", placeholder="e.g., landscape, portrait, abstract art") with search_col2: orientation = st.selectbox("Orientation", ["all", "landscape", "portrait", "squarish"]) with search_col3: search_button = st.button("Search", use_container_width=True) # Random photos button if st.button("Get Random Photos"): with st.spinner("Loading random photos..."): results, error = unsplash.get_random_photos(count=12) if error: st.error(f"Error: {error}") elif results: # Handle both single photo and array of photos photos = results if isinstance(results, list) else [results] st.session_state['unsplash_results'] = photos st.success(f"Loaded {len(photos)} random photos") # Search functionality if search_button and search_query: with st.spinner(f"Searching for '{search_query}'..."): orientation_param = None if orientation == "all" else orientation results, error = unsplash.search_photos(search_query, per_page=12, orientation=orientation_param) if error: st.error(f"Error: {error}") elif results and results.get('results'): st.session_state['unsplash_results'] = results['results'] st.success(f"Found {results['total']} images") else: st.info("No images found. Try a different search term.") # Display results if 'unsplash_results' in st.session_state and st.session_state['unsplash_results']: st.markdown("### Search Results") # Display in a 4-column grid cols = st.columns(4) for idx, photo in enumerate(st.session_state['unsplash_results'][:12]): with cols[idx % 4]: # Show thumbnail st.image(photo['urls']['thumb'], use_column_width=True) # Photo info st.caption(f"By {photo['user']['name']}") # Use button if st.button("Use This", key=f"use_unsplash_{photo['id']}"): with st.spinner("Loading image..."): # Download regular size img = unsplash.download_photo(photo['urls']['regular']) if img: # Store in session state st.session_state['current_image'] = img st.session_state['image_source'] = f"Unsplash: {photo['user']['name']}" st.session_state['unsplash_photo'] = photo # Trigger download tracking (required by Unsplash) if 'links' in photo and 'download_location' in photo['links']: unsplash.trigger_download(photo['links']['download_location']) st.success("Image loaded!") st.rerun() col1, col2 = st.columns(2) with col1: st.header("Input") # Image source selection image_source = st.radio("Image Source", ["Upload", "Unsplash"], horizontal=True) # Initialize input_image to None input_image = None if image_source == "Upload": uploaded_file = st.file_uploader("Choose an image", type=['png', 'jpg', 'jpeg']) if uploaded_file: input_image = Image.open(uploaded_file).convert('RGB') st.session_state['current_image'] = input_image st.session_state['image_source'] = "Uploaded" else: # Handle Unsplash selection if 'current_image' in st.session_state and st.session_state.get('image_source', '').startswith('Unsplash'): input_image = st.session_state['current_image'] else: st.info("Search for an image above") if input_image: # Display the image display_img = resize_image_for_display(input_image, max_width=600, max_height=400) st.image(display_img, caption=st.session_state.get('image_source', 'Image'), use_column_width=True) # Attribution for Unsplash images if 'unsplash_photo' in st.session_state and st.session_state.get('image_source', '').startswith('Unsplash'): photo = st.session_state['unsplash_photo'] st.markdown(f"Photo by [{photo['user']['name']}]({photo['user']['links']['html']}) on [Unsplash]({photo['links']['html']})") st.subheader("Style Configuration") # Up to 3 styles num_styles = st.number_input("Number of styles to apply", 1, 3, 1) style_configs = [] for i in range(num_styles): with st.expander(f"Style {i+1}", expanded=(i==0)): style = st.selectbox(f"Select style", style_choices, key=f"style_{i}") intensity = st.slider(f"Intensity", 0.0, 2.0, 1.0, 0.1, key=f"intensity_{i}") if style and intensity > 0: model_key = None for key, info in system.cyclegan_models.items(): if info['name'] == style: model_key = key break if model_key: style_configs.append(('cyclegan', model_key, intensity)) blend_mode = st.selectbox("Blend Mode", ["additive", "average", "maximum", "overlay", "screen"], index=0) if st.button("Apply Styles", type="primary", use_container_width=True): if style_configs: with st.spinner("Applying styles..."): progress_bar = st.progress(0) status_text = st.empty() # Process with progress updates for i, (_, key, intensity) in enumerate(style_configs): model_name = system.cyclegan_models[key]['name'] progress = (i + 1) / len(style_configs) progress_bar.progress(progress) status_text.text(f"Applying {model_name}...") result = system.blend_styles(input_image, style_configs, blend_mode) st.session_state['last_result'] = result st.session_state['last_style_configs'] = style_configs progress_bar.empty() status_text.empty() with col2: st.header("Result") if 'last_result' in st.session_state: st.image(st.session_state['last_result'], caption="Styled Image", use_column_width=True) # Download button buf = io.BytesIO() st.session_state['last_result'].save(buf, format='PNG') st.download_button( label="Download Result", data=buf.getvalue(), file_name=f"styled_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.png", mime="image/png" ) # TAB 2: Regional Transform with tab2: st.header("Regional Style Transform") st.markdown("Paint different regions to apply different styles locally") # Initialize session state if 'regions' not in st.session_state: st.session_state.regions = [] if 'canvas_results' not in st.session_state: st.session_state.canvas_results = {} if 'regional_image_original' not in st.session_state: st.session_state.regional_image_original = None if 'canvas_ready' not in st.session_state: st.session_state.canvas_ready = True if 'last_applied_regions' not in st.session_state: st.session_state.last_applied_regions = None if 'canvas_key_base' not in st.session_state: st.session_state.canvas_key_base = 0 col1, col2 = st.columns([2, 3]) # Define variables at the top level of tab2 use_base = False base_style = None base_intensity = 1.0 regional_blend_mode = "additive" with col1: # Image source selection regional_image_source = st.radio("Image Source", ["Upload", "Unsplash"], horizontal=True, key="regional_image_source") if regional_image_source == "Upload": uploaded_regional = st.file_uploader("Choose an image", type=['png', 'jpg', 'jpeg'], key="regional_upload") if uploaded_regional: # Load and store original image regional_image_original = Image.open(uploaded_regional).convert('RGB') st.session_state.regional_image_original = regional_image_original else: # Use Unsplash image if available if 'current_image' in st.session_state and st.session_state.get('image_source', '').startswith('Unsplash'): st.session_state.regional_image_original = st.session_state['current_image'] st.success("Using Unsplash image") else: st.info("Please search and select an image from the Style Transfer tab first") if st.session_state.regional_image_original: # Display the original image display_img = resize_image_for_display(st.session_state.regional_image_original, max_width=400, max_height=300) st.image(display_img, caption="Original Image", use_column_width=True) st.subheader("Define Regions") # Base style (optional) with st.expander("Base Style (Optional)", expanded=False): use_base = st.checkbox("Apply base style to entire image") if use_base: base_style = st.selectbox("Base style", style_choices, key="base_style") base_intensity = st.slider("Base intensity", 0.0, 2.0, 1.0, key="base_intensity") # Region management col_btn1, col_btn2, col_btn3 = st.columns(3) with col_btn1: if st.button("Add Region", use_container_width=True): new_region = { 'id': len(st.session_state.regions), 'style': style_choices[0] if style_choices else None, 'intensity': 1.0, 'color': f"hsla({len(st.session_state.regions) * 60}, 70%, 50%, 0.5)" } st.session_state.regions.append(new_region) st.session_state.canvas_ready = True st.rerun() with col_btn2: if st.button("Clear All", use_container_width=True): st.session_state.regions = [] st.session_state.canvas_results = {} if 'regional_result' in st.session_state: del st.session_state['regional_result'] st.session_state.canvas_ready = True st.session_state.canvas_key_base = 0 st.rerun() with col_btn3: if st.button("Reset Result", use_container_width=True): if 'regional_result' in st.session_state: del st.session_state['regional_result'] st.session_state.canvas_ready = True st.rerun() # Configure each region for i, region in enumerate(st.session_state.regions): with st.expander(f"Region {i+1} - {region.get('style', 'None')}", expanded=(i == len(st.session_state.regions) - 1)): col_a, col_b = st.columns(2) with col_a: new_style = st.selectbox( "Style", style_choices, key=f"region_style_{i}", index=style_choices.index(region['style']) if region['style'] in style_choices else 0 ) region['style'] = new_style with col_b: region['intensity'] = st.slider( "Intensity", 0.0, 2.0, region.get('intensity', 1.0), key=f"region_intensity_{i}" ) if st.button(f"Remove Region {i+1}", key=f"remove_region_{i}"): # Remove the region st.session_state.regions.pop(i) # Rebuild canvas results with proper indices old_canvas_results = st.session_state.canvas_results.copy() st.session_state.canvas_results = {} for old_idx, result in old_canvas_results.items(): if old_idx < i: # Keep results before removed index st.session_state.canvas_results[old_idx] = result elif old_idx > i: # Shift results after removed index down by 1 st.session_state.canvas_results[old_idx - 1] = result st.session_state.canvas_ready = True st.session_state.canvas_key_base += 1 st.rerun() # Blend mode regional_blend_mode = st.selectbox("Blend Mode", ["additive", "average", "maximum", "overlay", "screen"], index=0, key="regional_blend") with col2: if st.session_state.regions and st.session_state.regional_image_original: st.subheader("Paint Regions") # Show workflow status if 'regional_result' in st.session_state: if st.session_state.canvas_ready: st.success("Edit Mode - Paint your regions and click 'Apply Regional Styles' when ready") else: st.info("Preview Mode - Click 'Continue Editing' to modify regions") else: st.info("Paint on the canvas below to define regions for each style") # Check if we're in edit mode if not st.session_state.canvas_ready: # Show a preview of the painted regions if 'regional_result' in st.session_state: st.subheader("Current Result") result_display = resize_image_for_display(st.session_state['regional_result'], max_width=600, max_height=400) st.image(result_display, caption="Applied Styles", use_column_width=True) # Create display image display_image = resize_image_for_display(st.session_state.regional_image_original, max_width=600, max_height=400) display_width, display_height = display_image.size # Info message st.info(f"Image resized to {display_width}x{display_height} for display. Original resolution will be used for processing.") # Get current region current_region_idx = st.selectbox( "Select region to paint", range(len(st.session_state.regions)), format_func=lambda x: f"Region {x+1}: {st.session_state.regions[x].get('style', 'None')}" ) current_region = st.session_state.regions[current_region_idx] col_draw1, col_draw2, col_draw3 = st.columns(3) with col_draw1: brush_size = st.slider("Brush Size", 1, 50, 15) with col_draw2: drawing_mode = st.selectbox("Tool", ["freedraw", "line", "rect", "circle"]) with col_draw3: if st.button("Clear This Region"): if current_region_idx in st.session_state.canvas_results: del st.session_state.canvas_results[current_region_idx] st.session_state.canvas_ready = True st.rerun() # Create combined background with all previous regions background_with_regions = display_image.copy() draw = ImageDraw.Draw(background_with_regions, 'RGBA') # Draw all regions on the background for i, region in enumerate(st.session_state.regions): if i in st.session_state.canvas_results: canvas_data = st.session_state.canvas_results[i] if canvas_data is not None and hasattr(canvas_data, 'image_data') and canvas_data.image_data is not None: # Extract mask from canvas data mask = canvas_data.image_data[:, :, 3] > 0 # Create colored overlay for this region # Parse HSLA color more carefully color_str = region['color'].replace('hsla(', '').replace(')', '') color_parts = color_str.split(',') hue = int(color_parts[0]) # Convert HSL to RGB (simplified - assumes 70% saturation, 50% lightness) r, g, b = colorsys.hls_to_rgb(hue/360, 0.5, 0.7) color = (int(r*255), int(g*255), int(b*255)) opacity = 128 if i != current_region_idx else 200 # Draw mask on background for y in range(mask.shape[0]): for x in range(mask.shape[1]): if mask[y, x]: draw.point((x, y), fill=color + (opacity,)) # Canvas for current region stroke_color = current_region['color'].replace('0.5)', '0.8)') # Get initial drawing for current region initial_drawing = None if current_region_idx in st.session_state.canvas_results: canvas_data = st.session_state.canvas_results[current_region_idx] if canvas_data is not None and hasattr(canvas_data, 'json_data'): initial_drawing = canvas_data.json_data canvas_result = st_canvas( fill_color=stroke_color, stroke_width=brush_size, stroke_color=stroke_color, background_image=background_with_regions, update_streamlit=True, height=display_height, width=display_width, drawing_mode=drawing_mode, display_toolbar=True, initial_drawing=initial_drawing, key=f"regional_canvas_{current_region_idx}_{brush_size}_{drawing_mode}" ) # Save canvas result if canvas_result: st.session_state.canvas_results[current_region_idx] = canvas_result # Apply button if st.button("Apply Regional Styles", type="primary", use_container_width=True): with st.spinner("Applying regional styles..."): # Create combined mask from all canvas results combined_mask = combine_region_masks( [st.session_state.canvas_results.get(i) for i in range(len(st.session_state.regions))], (display_height, display_width) ) # Prepare base style configs if enabled base_configs = None if use_base and base_style: base_key = None for key, info in system.cyclegan_models.items(): if info['name'] == base_style: base_key = key break if base_key: base_configs = [('cyclegan', base_key, base_intensity)] # Apply regional styles using original image result = system.apply_regional_styles( st.session_state.regional_image_original, # Use original resolution combined_mask, st.session_state.regions, base_configs, regional_blend_mode ) st.session_state['regional_result'] = result # Show result with fixed size if 'regional_result' in st.session_state: st.subheader("Result") # Add display size control display_size = st.slider("Display Size", 300, 800, 600, 50, key="regional_display_size") # Fixed size display result_display = resize_image_for_display( st.session_state['regional_result'], max_width=display_size, max_height=display_size ) st.image(result_display, caption="Regional Styled Image") # Show actual dimensions st.caption(f"Original size: {st.session_state['regional_result'].size[0]}x{st.session_state['regional_result'].size[1]} pixels") # Download button buf = io.BytesIO() st.session_state['regional_result'].save(buf, format='PNG') st.download_button( label="Download Result", data=buf.getvalue(), file_name=f"regional_styled_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.png", mime="image/png" ) # TAB 3: Video Processing with tab3: st.header("Video Processing") if not VIDEO_PROCESSING_AVAILABLE: st.warning(""" Video processing requires OpenCV to be installed. To enable video processing, add `opencv-python` to your requirements.txt """) else: col1, col2 = st.columns(2) with col1: video_file = st.file_uploader("Upload Video", type=['mp4', 'avi', 'mov']) if video_file: st.video(video_file) st.subheader("Style Configuration") # Style selection (up to 2 for videos) video_styles = [] for i in range(2): with st.expander(f"Style {i+1}", expanded=(i==0)): style = st.selectbox(f"Select style", style_choices, key=f"video_style_{i}") intensity = st.slider(f"Intensity", 0.0, 2.0, 1.0, 0.1, key=f"video_intensity_{i}") if style and intensity > 0: model_key = None for key, info in system.cyclegan_models.items(): if info['name'] == style: model_key = key break if model_key: video_styles.append(('cyclegan', model_key, intensity)) video_blend_mode = st.selectbox("Blend Mode", ["additive", "average", "maximum", "overlay", "screen"], index=0, key="video_blend") if st.button("Process Video", type="primary", use_container_width=True): if video_styles: with st.spinner("Processing video..."): progress_bar = st.progress(0) status_text = st.empty() def progress_callback(p, msg): progress_bar.progress(p) status_text.text(msg) # Save uploaded file temporarily temp_input = tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(video_file.name)[1]) temp_input.write(video_file.read()) temp_input.close() # Process video output_path = system.video_processor.process_video( temp_input.name, video_styles, video_blend_mode, progress_callback ) if output_path and os.path.exists(output_path): # Read the video file immediately try: with open(output_path, 'rb') as f: video_bytes = f.read() # Determine file extension file_ext = os.path.splitext(output_path)[1].lower() # Store in session state st.session_state['video_result_bytes'] = video_bytes st.session_state['video_result_ext'] = file_ext st.session_state['video_result_available'] = True st.session_state['video_is_mp4'] = (file_ext == '.mp4') st.success(f"Video processing complete! Format: {file_ext.upper()}") # Clean up files try: os.unlink(output_path) except: pass except Exception as e: st.error(f"Failed to read processed video: {str(e)}") st.session_state['video_result_available'] = False else: st.error("Failed to process video. Please try a different video or reduce the resolution.") st.session_state['video_result_available'] = False # Cleanup input file try: os.unlink(temp_input.name) except: pass progress_bar.empty() status_text.empty() else: st.warning("Please select at least one style") with col2: st.header("Result") if st.session_state.get('video_result_available', False) and 'video_result_bytes' in st.session_state: try: file_ext = st.session_state.get('video_result_ext', '.mp4') video_bytes = st.session_state['video_result_bytes'] # File info file_size_mb = len(video_bytes) / (1024 * 1024) # Try to display video if st.session_state.get('video_is_mp4', False): # For MP4, should work in browser st.video(video_bytes) st.success(f"Video ready! Size: {file_size_mb:.2f} MB") else: # For non-MP4, show info and download st.info(f"Video format ({file_ext}) may not play in browser. Please download to view.") # Show preview image if possible try: # Try to extract a frame for preview temp_preview = tempfile.NamedTemporaryFile(delete=False, suffix=file_ext) temp_preview.write(video_bytes) temp_preview.close() cap = cv2.VideoCapture(temp_preview.name) ret, frame = cap.read() if ret: # Convert frame to RGB and display rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) st.image(rgb_frame, caption="Video Preview (First Frame)", use_column_width=True) cap.release() os.unlink(temp_preview.name) except: pass # Always provide download button mime_types = { '.mp4': 'video/mp4', '.avi': 'video/x-msvideo', '.mov': 'video/quicktime' } mime_type = mime_types.get(file_ext, 'application/octet-stream') col_dl1, col_dl2 = st.columns(2) with col_dl1: st.download_button( label=f"Download Video ({file_ext.upper()})", data=video_bytes, file_name=f"styled_video_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}{file_ext}", mime=mime_type, use_container_width=True ) with col_dl2: if st.button("Clear Result", use_container_width=True): del st.session_state['video_result_bytes'] st.session_state['video_result_available'] = False if 'video_result_ext' in st.session_state: del st.session_state['video_result_ext'] if 'video_is_mp4' in st.session_state: del st.session_state['video_is_mp4'] st.rerun() # Info about playback if not st.session_state.get('video_is_mp4', False): st.caption("For best compatibility, download and use VLC or another video player.") except Exception as e: st.error(f"Error displaying video: {str(e)}") # Emergency download button if 'video_result_bytes' in st.session_state: st.download_button( label="📥 Download Video (Error occurred)", data=st.session_state['video_result_bytes'], file_name=f"styled_video_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.mp4", mime="application/octet-stream" ) elif st.session_state.get('video_result_available', False): st.warning("Video data not found. Please process the video again.") if st.button("Clear State"): st.session_state['video_result_available'] = False st.rerun() # TAB 4: Training with AdaIN and Regional Application # TAB 4: Training with AdaIN and Regional Application with tab4: st.header("Train Custom Style with AdaIN") st.markdown("Train your own style transfer model using Adaptive Instance Normalization") # Initialize session state for content images if 'content_images_list' not in st.session_state: st.session_state.content_images_list = [] if 'adain_canvas_result' not in st.session_state: st.session_state.adain_canvas_result = None if 'adain_test_image' not in st.session_state: st.session_state.adain_test_image = None col1, col2, col3 = st.columns([1, 1, 1]) with col1: st.subheader("Style Images") style_imgs = st.file_uploader("Upload 1-5 style images", type=['png', 'jpg', 'jpeg'], accept_multiple_files=True, key="train_style_adain") if style_imgs: st.markdown(f"**{len(style_imgs)} style image(s) uploaded**") # Display style images in a grid style_cols = st.columns(min(len(style_imgs), 3)) for idx, style_img in enumerate(style_imgs[:3]): with style_cols[idx % 3]: img = Image.open(style_img).convert('RGB') st.image(img, caption=f"Style {idx+1}", use_column_width=True) if len(style_imgs) > 3: st.caption(f"... and {len(style_imgs) - 3} more") with col2: st.subheader("Content Images") content_imgs = st.file_uploader("Upload content images (10-50 recommended)", type=['png', 'jpg', 'jpeg'], accept_multiple_files=True, key="train_content_adain") if content_imgs: st.markdown(f"**{len(content_imgs)} content image(s) uploaded**") # Store content images in session state for later use st.session_state.content_images_list = content_imgs # Display content images in a grid content_cols = st.columns(min(len(content_imgs), 3)) for idx, content_img in enumerate(content_imgs[:3]): with content_cols[idx % 3]: img = Image.open(content_img).convert('RGB') st.image(img, caption=f"Content {idx+1}", use_column_width=True) if len(content_imgs) > 3: st.caption(f"... and {len(content_imgs) - 3} more") with col3: st.subheader("Training Settings") model_name = st.text_input("Model Name", value=f"adain_style_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}") # IMPROVED DEFAULT VALUES epochs = st.slider("Training Epochs", 10, 100, 50, 5) # Increased default batch_size = st.slider("Batch Size", 1, 8, 4) learning_rate = st.number_input("Learning Rate", 0.00001, 0.001, 0.0001, format="%.5f") with st.expander("Advanced Settings"): # MUCH HIGHER STYLE WEIGHT BY DEFAULT style_weight = st.number_input("Style Weight", 1.0, 1000.0, 100.0, 10.0) content_weight = st.number_input("Content Weight", 0.1, 10.0, 1.0, 0.1) save_interval = st.slider("Save Checkpoint Every N Epochs", 5, 20, 10, 5) st.info("💡 **Pro tip**: For better quality, use Style Weight 100-500x higher than Content Weight") st.markdown("---") # Training button if st.button("Start AdaIN Training", type="primary", use_container_width=True): if style_imgs and content_imgs: if len(content_imgs) < 10: st.warning("For best results, use at least 10 content images") with st.spinner("Training AdaIN model..."): progress_bar = st.progress(0) status_text = st.empty() def progress_callback(p, msg): progress_bar.progress(p) status_text.text(msg) # Create temp directory for content images temp_content_dir = f'/tmp/content_images_{uuid.uuid4().hex}' os.makedirs(temp_content_dir, exist_ok=True) # Save content images for idx, img_file in enumerate(content_imgs): img = Image.open(img_file).convert('RGB') img.save(os.path.join(temp_content_dir, f'content_{idx}.jpg')) # Load style images style_images = [] for style_file in style_imgs: style_img = Image.open(style_file).convert('RGB') style_images.append(style_img) # IMPROVED TRAINING FUNCTION # Multi-layer VGG loss for better quality class MultiLayerVGG(nn.Module): def __init__(self): super().__init__() vgg = models.vgg19(weights=models.VGG19_Weights.DEFAULT).features self.slice1 = nn.Sequential(*list(vgg.children())[:2]) # relu1_1 self.slice2 = nn.Sequential(*list(vgg.children())[2:7]) # relu2_1 self.slice3 = nn.Sequential(*list(vgg.children())[7:12]) # relu3_1 self.slice4 = nn.Sequential(*list(vgg.children())[12:21]) # relu4_1 for param in self.parameters(): param.requires_grad = False def forward(self, x): h1 = self.slice1(x) h2 = self.slice2(h1) h3 = self.slice3(h2) h4 = self.slice4(h3) return [h1, h2, h3, h4] # Create model model = AdaINStyleTransfer().to(system.device) optimizer = torch.optim.Adam(model.decoder.parameters(), lr=learning_rate) scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.8) print(f"Training AdaIN model at 512x512 resolution") print(f"Training device: {system.device}") # Prepare style images - LARGER SIZE style_transform = transforms.Compose([ transforms.Resize(600), # Increased size transforms.RandomCrop(512), # Larger crops transforms.RandomHorizontalFlip(p=0.5), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) style_tensors = [] # Create multiple augmented versions for style_img in style_images: for _ in range(5): # 5 augmented versions per style style_tensor = style_transform(style_img).unsqueeze(0).to(system.device) style_tensors.append(style_tensor) # Prepare content dataset - LARGER SIZE content_transform = transforms.Compose([ transforms.Resize(600), transforms.RandomCrop(512), transforms.RandomHorizontalFlip(), transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.05), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) dataset = StyleTransferDataset(temp_content_dir, transform=content_transform) dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0) # Multi-layer loss network loss_network = MultiLayerVGG().to(system.device).eval() mse_loss = nn.MSELoss() # Training loop model.train() model.encoder.eval() total_steps = 0 # Multiply style weight for better results actual_style_weight = style_weight * 10 for epoch in range(epochs): epoch_loss = 0 epoch_content_loss = 0 epoch_style_loss = 0 for batch_idx, content_batch in enumerate(dataloader): content_batch = content_batch.to(system.device) # Randomly select style images batch_style = [] for _ in range(content_batch.size(0)): style_idx = np.random.randint(0, len(style_tensors)) batch_style.append(style_tensors[style_idx]) batch_style = torch.cat(batch_style, dim=0) # Forward pass output = model(content_batch, batch_style) # Multi-layer loss with torch.no_grad(): content_feats = loss_network(content_batch) style_feats = loss_network(batch_style) output_feats = loss_network(output) # Content loss from relu4_1 content_loss = mse_loss(output_feats[-1], content_feats[-1]) # Style loss from multiple layers style_loss = 0 style_weights = [0.2, 0.3, 0.5, 1.0] def gram_matrix(feat): b, c, h, w = feat.size() feat = feat.view(b, c, h * w) gram = torch.bmm(feat, feat.transpose(1, 2)) return gram / (c * h * w) for i, (output_feat, style_feat, weight) in enumerate(zip(output_feats, style_feats, style_weights)): output_gram = gram_matrix(output_feat) style_gram = gram_matrix(style_feat) style_loss += weight * mse_loss(output_gram, style_gram) style_loss /= len(style_weights) # Total loss loss = content_weight * content_loss + actual_style_weight * style_loss # Backward pass optimizer.zero_grad() loss.backward() torch.nn.utils.clip_grad_norm_(model.decoder.parameters(), max_norm=5.0) optimizer.step() epoch_loss += loss.item() epoch_content_loss += content_loss.item() epoch_style_loss += style_loss.item() total_steps += 1 # Progress callback if progress_callback and total_steps % 10 == 0: progress = (epoch + (batch_idx + 1) / len(dataloader)) / epochs progress_callback(progress, f"Epoch {epoch+1}/{epochs}, Loss: {loss.item():.4f} " f"(C: {content_loss.item():.4f}, S: {style_loss.item():.4f})") # Step scheduler scheduler.step() # Print epoch stats avg_loss = epoch_loss / len(dataloader) print(f"Epoch {epoch+1}: Loss={avg_loss:.4f}, " f"Content={epoch_content_loss/len(dataloader):.4f}, " f"Style={epoch_style_loss/len(dataloader):.4f}") # Save checkpoint if (epoch + 1) % save_interval == 0: checkpoint_path = f'{system.models_dir}/{model_name}_epoch_{epoch+1}.pth' torch.save({ 'epoch': epoch + 1, 'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'scheduler_state_dict': scheduler.state_dict(), 'loss': avg_loss, 'model_type': 'adain' }, checkpoint_path) print(f"Saved checkpoint: {checkpoint_path}") # Save final model final_path = f'{system.models_dir}/{model_name}_final.pth' torch.save({ 'model_state_dict': model.state_dict(), 'model_type': 'adain' }, final_path) # Cleanup shutil.rmtree(temp_content_dir) if model: st.session_state['trained_adain_model'] = model st.session_state['trained_style_images'] = style_images st.session_state['model_path'] = final_path st.success("AdaIN training complete! 🎉") # Add to system's models system.lightweight_models[model_name] = model progress_bar.empty() status_text.empty() else: st.error("Please upload both style and content images") # Testing section with regional application if 'trained_adain_model' in st.session_state: st.markdown("---") st.header("Test Your AdaIN Model") # Application mode selection application_mode = st.radio("Application Mode", ["Whole Image", "Paint Region"], horizontal=True, help="Choose whether to apply style to entire image or paint specific regions") test_col1, test_col2, test_col3 = st.columns([1, 1, 1]) with test_col1: st.subheader("Test Options") # Test image selection test_source = st.radio("Test Image Source", ["Use Content Image", "Upload New", "Use Unsplash Image"], horizontal=True) test_image = None if test_source == "Use Content Image" and st.session_state.content_images_list: # Select from uploaded content images content_idx = st.selectbox("Select content image", range(len(st.session_state.content_images_list)), format_func=lambda x: f"Content Image {x+1}") test_image = Image.open(st.session_state.content_images_list[content_idx]).convert('RGB') elif test_source == "Use Unsplash Image": # Use current Unsplash image if available if 'current_image' in st.session_state and st.session_state.get('image_source', '').startswith('Unsplash'): test_image = st.session_state['current_image'] st.success("Using Unsplash image") else: st.info("Please search and select an image from the Style Transfer tab first") else: # Upload new image test_upload = st.file_uploader("Upload test image", type=['png', 'jpg', 'jpeg'], key="test_adain") if test_upload: test_image = Image.open(test_upload).convert('RGB') # Store test image in session state if test_image: st.session_state['adain_test_image'] = test_image # Style selection for testing if 'trained_style_images' in st.session_state and len(st.session_state['trained_style_images']) > 1: style_idx = st.selectbox("Select style", range(len(st.session_state['trained_style_images'])), format_func=lambda x: f"Style {x+1}") test_style = st.session_state['trained_style_images'][style_idx] elif 'trained_style_images' in st.session_state: test_style = st.session_state['trained_style_images'][0] st.info("Using the single trained style") else: test_style = None # IMPROVED DEFAULTS # Alpha blending control alpha = st.slider("Style Strength (Alpha)", 0.0, 2.0, 1.2, 0.1, help="0 = original content, 1 = full style transfer, >1 = stronger style") # Add tiling option - DEFAULT TO TRUE use_tiling = st.checkbox("Use Tiled Processing", value=True, # Default to True help="Process images in tiles for better quality. Recommended for ALL images.") # Initialize variables with default values brush_size = 30 drawing_mode = "freedraw" feather_radius = 10 # Regional painting options (only show if in paint mode) if application_mode == "Paint Region": st.markdown("---") st.subheader("Painting Options") brush_size = st.slider("Brush Size", 5, 100, 30) drawing_mode = st.selectbox("Drawing Tool", ["freedraw", "line", "rect", "circle", "polygon"], index=0) # Feather/blur the mask edges feather_radius = st.slider("Edge Softness", 0, 50, 10, help="Blur mask edges for smoother transitions") col_btn1, col_btn2 = st.columns(2) with col_btn1: if st.button("Clear Canvas", use_container_width=True): st.session_state['adain_canvas_result'] = None st.rerun() with col_btn2: if st.button("Reset Result", use_container_width=True): if 'adain_styled_result' in st.session_state: del st.session_state['adain_styled_result'] st.rerun() with test_col2: st.subheader("Canvas / Original") if application_mode == "Paint Region" and test_image: # Show canvas for painting display_img = resize_image_for_display(test_image, max_width=400, max_height=400) canvas_width, canvas_height = display_img.size st.info("Paint the areas where you want to apply the style") # Canvas for painting mask canvas_result = st_canvas( fill_color="rgba(255, 0, 0, 0.3)", # Red with transparency stroke_width=brush_size, stroke_color="rgba(255, 0, 0, 0.5)", background_image=display_img, update_streamlit=True, height=canvas_height, width=canvas_width, drawing_mode=drawing_mode, display_toolbar=True, key=f"adain_canvas_{brush_size}_{drawing_mode}" ) # Save canvas result if canvas_result: st.session_state['adain_canvas_result'] = canvas_result # Show style image below canvas if test_style: st.markdown("---") st.image(test_style, caption="Style Image", use_column_width=True) else: # Show original images if test_image: st.image(test_image, caption="Content Image", use_column_width=True) if test_style: st.image(test_style, caption="Style Image", use_column_width=True) with test_col3: st.subheader("Result") # Apply button apply_button = st.button("Apply Style", type="primary", use_container_width=True) if apply_button and test_image and test_style: with st.spinner("Applying style..."): if application_mode == "Whole Image": # Apply to whole image result = system.apply_adain_style( test_image, test_style, st.session_state['trained_adain_model'], alpha=alpha, use_tiling=use_tiling ) else: # Apply to painted region result = apply_adain_regional( test_image, test_style, st.session_state['trained_adain_model'], st.session_state.get('adain_canvas_result'), alpha=alpha, feather_radius=feather_radius, use_tiling=use_tiling ) if result: st.session_state['adain_styled_result'] = result # Show result if available if 'adain_styled_result' in st.session_state: st.image(st.session_state['adain_styled_result'], caption="Styled Result", use_column_width=True) # Quality tips with st.expander("💡 Tips for Better Quality"): st.markdown(""" - **Always use tiling** for best quality - Try **alpha > 1.0** (1.2-1.5) for stronger style - Use **multiple style images** when training - Train for **50+ epochs** for best results - If quality is still poor, retrain with **style weight = 200-500** """) # Download button buf = io.BytesIO() st.session_state['adain_styled_result'].save(buf, format='PNG') st.download_button( label="Download Result", data=buf.getvalue(), file_name=f"adain_styled_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.png", mime="image/png" ) # Model download section st.markdown("---") if 'model_path' in st.session_state and os.path.exists(st.session_state['model_path']): col_dl1, col_dl2 = st.columns(2) with col_dl1: with open(st.session_state['model_path'], 'rb') as f: st.download_button( label="Download Trained AdaIN Model", data=f.read(), file_name=f"{model_name}_final.pth", mime="application/octet-stream", use_container_width=True ) with col_dl2: st.info("This model can be loaded and used for real-time style transfer") # Add this helper function (place it before the tab or with other helper functions) def apply_adain_regional(content_image, style_image, model, canvas_result, alpha=1.0, feather_radius=10, use_tiling=False): """Apply AdaIN style transfer to a painted region only""" if content_image is None or style_image is None or model is None: return None try: # Get the mask from canvas if canvas_result is None or canvas_result.image_data is None: # No mask painted, apply to whole image return system.apply_adain_style(content_image, style_image, model, alpha, use_tiling=use_tiling) # Extract mask from canvas mask_data = canvas_result.image_data[:, :, 3] # Alpha channel mask = mask_data > 0 # Resize mask to match original image size original_size = content_image.size display_size = (canvas_result.image_data.shape[1], canvas_result.image_data.shape[0]) if original_size != display_size: # Convert mask to PIL image for resizing mask_pil = Image.fromarray((mask * 255).astype(np.uint8), mode='L') mask_pil = mask_pil.resize(original_size, Image.NEAREST) mask = np.array(mask_pil) > 128 # Apply feathering to mask edges if requested if feather_radius > 0: from scipy.ndimage import gaussian_filter mask_float = mask.astype(np.float32) mask_float = gaussian_filter(mask_float, sigma=feather_radius) mask_float = np.clip(mask_float, 0, 1) else: mask_float = mask.astype(np.float32) # Apply style to entire image with tiling option styled_full = system.apply_adain_style(content_image, style_image, model, alpha, use_tiling=use_tiling) if styled_full is None: return None # Blend original and styled based on mask original_array = np.array(content_image, dtype=np.float32) styled_array = np.array(styled_full, dtype=np.float32) # Expand mask to 3 channels mask_3ch = np.stack([mask_float] * 3, axis=2) # Blend result_array = original_array * (1 - mask_3ch) + styled_array * mask_3ch result_array = np.clip(result_array, 0, 255).astype(np.uint8) return Image.fromarray(result_array) except Exception as e: print(f"Error applying regional AdaIN style: {e}") traceback.print_exc() return None # TAB 5: Batch Processing with tab5: st.header("Batch Processing") col1, col2 = st.columns(2) with col1: # Image source selection for batch batch_source = st.radio("Image Source", ["Upload Multiple", "Use Current Unsplash Image"], horizontal=True, key="batch_source") batch_files = [] if batch_source == "Upload Multiple": batch_files = st.file_uploader("Upload Images", type=['png', 'jpg', 'jpeg'], accept_multiple_files=True, key="batch_upload") else: # Use current Unsplash image if available if 'current_image' in st.session_state and st.session_state.get('image_source', '').startswith('Unsplash'): batch_files = [st.session_state['current_image']] st.success("Using current Unsplash image for batch processing") else: st.info("Please search and select an image from the Style Transfer tab first") processing_type = st.radio("Processing Type", ["CycleGAN", "Custom Trained Model"]) if processing_type == "CycleGAN": # Style configuration batch_styles = [] for i in range(3): with st.expander(f"Style {i+1}", expanded=(i==0)): style = st.selectbox(f"Select style", style_choices, key=f"batch_style_{i}") intensity = st.slider(f"Intensity", 0.0, 2.0, 1.0, 0.1, key=f"batch_intensity_{i}") if style and intensity > 0: model_key = None for key, info in system.cyclegan_models.items(): if info['name'] == style: model_key = key break if model_key: batch_styles.append(('cyclegan', model_key, intensity)) batch_blend_mode = st.selectbox("Blend Mode", ["additive", "average", "maximum", "overlay", "screen"], index=0, key="batch_blend") else: # Custom model upload custom_model_file = st.file_uploader("Upload Trained Model (.pth)", type=['pth']) if st.button("Process Batch", type="primary", use_container_width=True): if batch_files: with st.spinner("Processing batch..."): progress_bar = st.progress(0) processed_images = [] if processing_type == "CycleGAN" and batch_styles: for idx, file in enumerate(batch_files): progress_bar.progress((idx + 1) / len(batch_files)) # Handle both file uploads and PIL images if isinstance(file, Image.Image): image = file else: image = Image.open(file).convert('RGB') result = system.blend_styles(image, batch_styles, batch_blend_mode) processed_images.append(result) elif processing_type == "Custom Trained Model" and custom_model_file: # Load custom model temp_model = tempfile.NamedTemporaryFile(delete=False, suffix='.pth') temp_model.write(custom_model_file.read()) temp_model.close() model = system.load_lightweight_model(temp_model.name) if model: for idx, file in enumerate(batch_files): progress_bar.progress((idx + 1) / len(batch_files)) # Handle both file uploads and PIL images if isinstance(file, Image.Image): image = file else: image = Image.open(file).convert('RGB') result = system.apply_lightweight_style(image, model) if result: processed_images.append(result) os.unlink(temp_model.name) if processed_images: # Create zip zip_buffer = io.BytesIO() with zipfile.ZipFile(zip_buffer, 'w') as zf: for idx, img in enumerate(processed_images): img_buffer = io.BytesIO() img.save(img_buffer, format='PNG') zf.writestr(f"styled_{idx+1:03d}.png", img_buffer.getvalue()) st.session_state['batch_results'] = processed_images st.session_state['batch_zip'] = zip_buffer.getvalue() progress_bar.empty() with col2: if 'batch_results' in st.session_state: st.header("Results") # Show gallery cols = st.columns(4) for idx, img in enumerate(st.session_state['batch_results'][:8]): cols[idx % 4].image(img, use_column_width=True) if len(st.session_state['batch_results']) > 8: st.info(f"Showing 8 of {len(st.session_state['batch_results'])} processed images") # Download zip st.download_button( label="Download All (ZIP)", data=st.session_state['batch_zip'], file_name=f"batch_styled_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.zip", mime="application/zip" ) # TAB 6: Documentation with tab6: st.markdown(f""" ## Style Transfer System Documentation ### Available CycleGAN Models This system includes pre-trained bidirectional CycleGAN models: {chr(10).join([f'- **{info["name"]}**' for key, info in sorted(system.cyclegan_models.items(), key=lambda item: item[1]["name"])])} ### Features #### Style Transfer - Apply multiple styles simultaneously - Adjustable intensity for each style - Multiple blending modes for creative effects - **NEW**: Search and use images from Unsplash #### Regional Transform - Paint specific regions to apply different styles - Support for multiple regions with different styles - Adjustable brush size and drawing tools - Base style + regional overlays - Persistent brush strokes across regions - Optimized display for large images #### Video Processing - Frame-by-frame style transfer - Maintains temporal consistency - Supports all style combinations and blend modes - Enhanced codec compatibility #### Custom Training - Train on any artistic style with minimal data (1-50 images) - Automatic data augmentation for small datasets - Adjustable model complexity (3-12 residual blocks) ### Model Architecture - **CycleGAN models**: 9-12 residual blocks for high-quality transformations - **Lightweight models**: 3-12 residual blocks (customizable during training) - **Training approach**: Unpaired image-to-image translation ### Technical Details - **Framework**: PyTorch - **GPU Support**: CUDA acceleration when available - **Image Formats**: JPG, PNG, BMP - **Video Formats**: MP4, AVI, MOV - **Model Size**: ~45MB (CycleGAN), 5-15MB (Lightweight) ### Unsplash Integration To use Unsplash image search: 1. Get a free API key from [Unsplash Developers](https://unsplash.com/developers) 2. Add it to your HuggingFace Space secrets as `UNSPLASH_ACCESS_KEY` 3. Search for images directly in the app 4. Automatic attribution for photographers ### Usage Tips 1. **For best results**: Use high-quality input images 2. **Style intensity**: Start with 1.0, adjust to taste 3. **Blending modes**: - 'Additive' for bold effects - 'Average' for subtle blends - 'Overlay' for dramatic contrasts 4. **Regional painting**: - Use larger brush for smooth transitions - Multiple thin layers work better than one thick layer - Previous regions remain visible as you paint new ones 5. **Custom training**: More diverse content images = better generalization 6. **Video processing**: Keep videos under 30 seconds for faster processing ### Regional Transform Guide The regional transform feature allows you to: 1. Define multiple regions by painting on the canvas 2. Assign different styles to each region 3. Control intensity per region 4. Apply an optional base style to the entire image 5. Blend regions using various modes **Tips for Regional Transform:** - Start with a base style for overall coherence - Use semi-transparent brushes for smoother transitions - Overlap regions for interesting blend effects - Experiment with different blend modes per region - All regions are visible while painting for better control """) # Footer st.markdown("---") st.markdown("Style transfer system with CycleGAN models and regional painting capabilities.")