# Hugging Face Spaces page-scrape artifact ("Spaces: Sleeping") — not part of the app code.
| #!/usr/bin/env python | |
| """ | |
| STYLE TRANSFER APP - Streamlit Version with Regional Transformations | |
| All existing features preserved + new local painting capabilities + Unsplash integration | |
| """ | |
| import os | |
| os.environ['MPLCONFIGDIR'] = '/tmp/matplotlib' | |
| os.environ['TORCH_HOME'] = '/tmp/torch_cache' | |
| os.environ['HF_HOME'] = '/tmp/hf_cache' | |
| os.makedirs('/tmp/torch_cache', exist_ok=True) | |
| os.makedirs('/tmp/hf_cache', exist_ok=True) | |
| import streamlit as st | |
| from streamlit_drawable_canvas import st_canvas | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import torchvision.transforms as transforms | |
| import torchvision.models as models | |
| from torch.utils.data import Dataset, DataLoader | |
| from PIL import Image, ImageDraw, ImageFont | |
| import numpy as np | |
| import glob | |
| import datetime | |
| import traceback | |
| import uuid | |
| import warnings | |
| import zipfile | |
| import io | |
| import json | |
| import time | |
| import shutil | |
| import requests | |
| import scipy | |
| try: | |
| import cv2 | |
| VIDEO_PROCESSING_AVAILABLE = True | |
| except ImportError: | |
| VIDEO_PROCESSING_AVAILABLE = False | |
| print("OpenCV not available - video processing disabled") | |
| import tempfile | |
| from pathlib import Path | |
| import colorsys | |
| warnings.filterwarnings("ignore") | |
# --- Streamlit page setup (must run before any other st.* call) ---
st.set_page_config(
    page_title="Style Transfer Studio",
    # page_icon="",
    layout="wide",
    initial_sidebar_state="expanded"
)
# Custom CSS for better UI: wider tabs, full-width layout, centered drawing
# canvas, and hover effects for the Unsplash image grid.
st.markdown("""
<style>
    .stTabs [data-baseweb="tab-list"] {
        gap: 24px;
    }
    .stTabs [data-baseweb="tab"] {
        height: 50px;
        padding-left: 20px;
        padding-right: 20px;
    }
    .main > div {
        padding-top: 2rem;
    }
    .st-emotion-cache-1y4p8pa {
        max-width: 100%;
    }
    /* Fix canvas container */
    .stDrawableCanvas {
        margin: 0 auto;
    }
    /* Unsplash grid styling */
    .unsplash-grid img {
        border-radius: 8px;
        cursor: pointer;
        transition: transform 0.2s;
    }
    .unsplash-grid img:hover {
        transform: scale(1.05);
    }
</style>
""", unsafe_allow_html=True)
# Force CUDA if available
if torch.cuda.is_available():
    torch.cuda.set_device(0)
    # NOTE(review): the original comment said "deterministic", but
    # cudnn.benchmark=True enables the *non-deterministic* autotuner that
    # picks the fastest convolution algorithm for fixed input sizes.
    torch.backends.cudnn.benchmark = True
    print("CUDA device set")
    print(f"CUDA version: {torch.version.cuda}")
    print(f"PyTorch version: {torch.__version__}")
# GPU SETUP: all models and tensors in this app are moved to this device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
if device.type == 'cuda':
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
    print(f"Current GPU memory usage: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
| # =========================== | |
| # UNSPLASH API INTEGRATION | |
| # =========================== | |
class UnsplashAPI:
    """Minimal Unsplash REST API client.

    Authenticates with a "Client-ID" access key taken from the constructor
    argument or the UNSPLASH_ACCESS_KEY environment variable. All methods
    degrade gracefully — they return (None, error_message) tuples or None
    instead of raising, so the Streamlit UI never crashes on network errors.
    """

    def __init__(self, access_key=None):
        # Use environment variable only - no secrets.toml warning
        if access_key:
            self.access_key = access_key
        else:
            self.access_key = os.environ.get("UNSPLASH_ACCESS_KEY")
        self.base_url = "https://api.unsplash.com"

    def search_photos(self, query, per_page=20, page=1, orientation=None):
        """Search photos on Unsplash.

        Returns (json_dict, None) on success or (None, error_message) on
        failure / missing API key.
        """
        if not self.access_key:
            return None, "No Unsplash API key configured"
        headers = {"Authorization": f"Client-ID {self.access_key}"}
        params = {
            "query": query,
            "per_page": per_page,
            "page": page
        }
        if orientation:
            params["orientation"] = orientation  # "landscape", "portrait", "squarish"
        try:
            response = requests.get(
                f"{self.base_url}/search/photos",
                headers=headers,
                params=params,
                timeout=10
            )
            response.raise_for_status()
            return response.json(), None
        except requests.exceptions.RequestException as e:
            return None, f"Error searching Unsplash: {str(e)}"

    def get_random_photos(self, count=12, collections=None, query=None):
        """Get random photos from Unsplash.

        Returns (json_list, None) on success or (None, error_message).
        """
        if not self.access_key:
            return None, "No Unsplash API key configured"
        headers = {"Authorization": f"Client-ID {self.access_key}"}
        params = {"count": count}
        if collections:
            params["collections"] = collections
        if query:
            params["query"] = query
        try:
            response = requests.get(
                f"{self.base_url}/photos/random",
                headers=headers,
                params=params,
                timeout=10
            )
            response.raise_for_status()
            return response.json(), None
        except requests.exceptions.RequestException as e:
            return None, f"Error getting random photos: {str(e)}"

    def download_photo(self, photo_url, size="regular"):
        """Download a photo URL and return it as an RGB PIL Image, or None.

        NOTE(review): ``size`` is currently unused — callers already pass the
        URL of the size they want; the parameter is kept for API compatibility.
        """
        try:
            # Add fm=jpg&q=80 for consistent format and quality
            if "?" in photo_url:
                photo_url += "&fm=jpg&q=80"
            else:
                photo_url += "?fm=jpg&q=80"
            response = requests.get(photo_url, timeout=30)
            response.raise_for_status()
            return Image.open(io.BytesIO(response.content)).convert('RGB')
        except Exception as e:
            # Best-effort: surface the error in the UI and carry on
            st.error(f"Error downloading image: {str(e)}")
            return None

    def trigger_download(self, download_location):
        """Trigger download event (required by Unsplash API guidelines)."""
        if not self.access_key or not download_location:
            return
        headers = {"Authorization": f"Client-ID {self.access_key}"}
        try:
            requests.get(download_location, headers=headers, timeout=5)
        except requests.exceptions.RequestException:
            pass  # Tracking is best-effort: never fail the caller
| # =========================== | |
| # MODEL ARCHITECTURES | |
| # =========================== | |
class LightweightResidualBlock(nn.Module):
    """Residual block built from a depthwise-separable 3x3 convolution.

    The skip connection adds the input to the depthwise->pointwise result,
    so the block learns a residual correction at low parameter cost.
    """

    def __init__(self, channels):
        super(LightweightResidualBlock, self).__init__()
        # 3x3 per-channel (grouped) convolution with reflection padding.
        depth_layers = [
            nn.ReflectionPad2d(1),
            nn.Conv2d(channels, channels, 3, groups=channels),
            nn.InstanceNorm2d(channels, affine=True),
            nn.ReLU(inplace=True),
        ]
        # 1x1 convolution mixes information across channels.
        point_layers = [
            nn.Conv2d(channels, channels, 1),
            nn.InstanceNorm2d(channels, affine=True),
        ]
        self.depthwise = nn.Sequential(*depth_layers)
        self.pointwise = nn.Sequential(*point_layers)

    def forward(self, x):
        residual = self.pointwise(self.depthwise(x))
        return x + residual
class ResidualBlock(nn.Module):
    """Classic CycleGAN residual block: two reflection-padded 3x3 convs
    with instance norm; ReLU only after the first conv."""

    def __init__(self, in_features):
        super(ResidualBlock, self).__init__()
        layers = []
        for is_last in (False, True):
            layers.append(nn.ReflectionPad2d(1))
            layers.append(nn.Conv2d(in_features, in_features, 3))
            layers.append(nn.InstanceNorm2d(in_features, affine=True))
            if not is_last:
                layers.append(nn.ReLU(inplace=True))
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        return self.block(x) + x
class Generator(nn.Module):
    """CycleGAN generator: 7x7 stem, two downsampling convs, N residual
    blocks, two upsampling transposed convs, 7x7 output conv with tanh.

    Args:
        input_nc: number of input channels (default 3 = RGB).
        output_nc: number of output channels.
        n_residual_blocks: number of ResidualBlock modules at the bottleneck.
    """

    def __init__(self, input_nc=3, output_nc=3, n_residual_blocks=9):
        super(Generator, self).__init__()
        # Initial convolution block (7x7, stride 1, reflection padded)
        model = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_nc, 64, 7),
            nn.InstanceNorm2d(64, affine=True),
            nn.ReLU(inplace=True)
        ]
        # Downsampling: 64 -> 128 -> 256 channels, spatial size / 4
        in_features = 64
        out_features = in_features * 2
        for _ in range(2):
            model += [
                nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
                nn.InstanceNorm2d(out_features, affine=True),
                nn.ReLU(inplace=True)
            ]
            in_features = out_features
            out_features = in_features * 2
        # Residual blocks at the bottleneck resolution
        for _ in range(n_residual_blocks):
            model += [ResidualBlock(in_features)]
        # Upsampling: 256 -> 128 -> 64 channels, spatial size * 4
        out_features = in_features // 2
        for _ in range(2):
            model += [
                nn.ConvTranspose2d(in_features, out_features, 3, stride=2, padding=1, output_padding=1),
                nn.InstanceNorm2d(out_features, affine=True),
                nn.ReLU(inplace=True)
            ]
            in_features = out_features
            out_features = in_features // 2
        # Output layer: use the tracked channel count instead of a magic 64
        # (identical value here, but stays correct if the widths above change).
        model += [
            nn.ReflectionPad2d(3),
            nn.Conv2d(in_features, output_nc, 7),
            nn.Tanh()
        ]
        self.model = nn.Sequential(*model)

    def forward(self, x):
        """Map an image batch in [-1, 1] to a stylized batch in [-1, 1]."""
        return self.model(x)
class LightweightStyleNet(nn.Module):
    """Small encoder / residual / decoder network for fast style transfer
    training; output passes through tanh into [-1, 1]."""

    def __init__(self, n_residual_blocks=5):
        super(LightweightStyleNet, self).__init__()
        # Encoder: 9x9 stem then two stride-2 convs (3 -> 32 -> 64 -> 128 ch).
        encoder_layers = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(3, 32, 9, stride=1),
            nn.InstanceNorm2d(32, affine=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 64, 3, stride=2, padding=1),
            nn.InstanceNorm2d(64, affine=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 128, 3, stride=2, padding=1),
            nn.InstanceNorm2d(128, affine=True),
            nn.ReLU(inplace=True),
        ]
        self.encoder = nn.Sequential(*encoder_layers)
        # Bottleneck: lightweight residual blocks at 128 channels.
        self.res_blocks = nn.Sequential(
            *(LightweightResidualBlock(128) for _ in range(n_residual_blocks))
        )
        # Decoder mirrors the encoder back up to RGB.
        decoder_layers = [
            nn.ConvTranspose2d(128, 64, 3, stride=2, padding=1, output_padding=1),
            nn.InstanceNorm2d(64, affine=True),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, 32, 3, stride=2, padding=1, output_padding=1),
            nn.InstanceNorm2d(32, affine=True),
            nn.ReLU(inplace=True),
            nn.ReflectionPad2d(3),
            nn.Conv2d(32, 3, 9, stride=1),
            nn.Tanh(),
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        return self.decoder(self.res_blocks(self.encoder(x)))
class SimpleVGGFeatures(nn.Module):
    """Frozen VGG19 feature extractor (first 21 layers) used for
    perceptual-loss computation."""

    def __init__(self):
        super(SimpleVGGFeatures, self).__init__()
        try:
            # torchvision >= 0.13 weights API
            vgg = models.vgg19(weights=models.VGG19_Weights.DEFAULT).features
        except (AttributeError, TypeError):
            # Fallback for older torchvision versions without VGG19_Weights
            vgg = models.vgg19(pretrained=True).features
        self.features = nn.Sequential(*list(vgg.children())[:21])
        # Freeze: the extractor is only used to compute losses.
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, x):
        """Return VGG feature maps for the input batch."""
        return self.features(x)
| # =========================== | |
| # ADAIN ARCHITECTURE | |
| # =========================== | |
class AdaIN(nn.Module):
    """Adaptive Instance Normalization.

    Re-normalizes content features so their per-channel mean and std match
    those of the style features.
    """

    def __init__(self):
        super(AdaIN, self).__init__()

    def calc_mean_std(self, feat, eps=1e-5):
        """Per-sample, per-channel mean and std of a (N, C, H, W) tensor."""
        size = feat.size()
        assert (len(size) == 4)
        N, C = size[:2]
        flat = feat.view(N, C, -1)
        # eps keeps the sqrt stable for (near-)constant channels.
        std = (flat.var(dim=2) + eps).sqrt().view(N, C, 1, 1)
        mean = flat.mean(dim=2).view(N, C, 1, 1)
        return mean, std

    def forward(self, content_feat, style_feat):
        shape = content_feat.size()
        style_mean, style_std = self.calc_mean_std(style_feat)
        content_mean, content_std = self.calc_mean_std(content_feat)
        # Whiten the content statistics, then re-color with the style's.
        whitened = (content_feat - content_mean.expand(shape)) / content_std.expand(shape)
        return whitened * style_std.expand(shape) + style_mean.expand(shape)
class VGGEncoder(nn.Module):
    """Frozen VGG19 encoder for AdaIN, split at relu1_1 / relu2_1 /
    relu3_1 / relu4_1."""

    def __init__(self):
        super(VGGEncoder, self).__init__()
        # Load pretrained VGG19 (new weights API first, old API as fallback)
        try:
            vgg = models.vgg19(weights=models.VGG19_Weights.DEFAULT).features
        except (AttributeError, TypeError):
            vgg = models.vgg19(pretrained=True).features
        # Encoder uses layers up to relu4_1
        self.enc1 = nn.Sequential(*list(vgg.children())[:2])     # conv1_1, relu1_1
        self.enc2 = nn.Sequential(*list(vgg.children())[2:7])    # up to relu2_1
        self.enc3 = nn.Sequential(*list(vgg.children())[7:12])   # up to relu3_1
        self.enc4 = nn.Sequential(*list(vgg.children())[12:21])  # up to relu4_1
        # Freeze encoder weights - the encoder is never trained
        for param in self.parameters():
            param.requires_grad = False

    def encode(self, x):
        """Return only the final (relu4_1-level) features for AdaIN."""
        h1 = self.enc1(x)
        h2 = self.enc2(h1)
        h3 = self.enc3(h2)
        return self.enc4(h3)

    def forward(self, x):
        """Return (relu4_1, relu3_1, relu2_1, relu1_1) feature maps."""
        h1 = self.enc1(x)
        h2 = self.enc2(h1)
        h3 = self.enc3(h2)
        h4 = self.enc4(h3)
        return h4, h3, h2, h1  # Intermediate features for skip connections
class AdaINDecoder(nn.Module):
    """Decoder that mirrors the AdaIN VGG encoder back to a 3-channel image."""

    def __init__(self):
        super(AdaINDecoder, self).__init__()

        def conv3x3(in_ch, out_ch):
            # Reflection-padded 3x3 convolution, as in the encoder's style.
            return [nn.ReflectionPad2d((1, 1, 1, 1)), nn.Conv2d(in_ch, out_ch, (3, 3))]

        # Each stage: conv(s) + ReLU, then 2x nearest-neighbour upsampling.
        self.dec4 = nn.Sequential(
            *conv3x3(512, 256), nn.ReLU(),
            nn.Upsample(scale_factor=2, mode='nearest'),
        )
        self.dec3 = nn.Sequential(
            *conv3x3(256, 256), nn.ReLU(),
            *conv3x3(256, 128), nn.ReLU(),
            nn.Upsample(scale_factor=2, mode='nearest'),
        )
        self.dec2 = nn.Sequential(
            *conv3x3(128, 128), nn.ReLU(),
            *conv3x3(128, 64), nn.ReLU(),
            nn.Upsample(scale_factor=2, mode='nearest'),
        )
        # Final stage: no activation, raw 3-channel output.
        self.dec1 = nn.Sequential(
            *conv3x3(64, 64), nn.ReLU(),
            *conv3x3(64, 3),
        )

    def forward(self, x):
        out = self.dec4(x)
        out = self.dec3(out)
        out = self.dec2(out)
        return self.dec1(out)
class AdaINStyleTransfer(nn.Module):
    """End-to-end AdaIN style transfer: frozen VGG encoder + trainable decoder."""

    def __init__(self):
        super(AdaINStyleTransfer, self).__init__()
        self.encoder = VGGEncoder()
        self.decoder = AdaINDecoder()
        self.adain = AdaIN()
        # Only the decoder is trained; keep the encoder frozen in eval mode.
        self.encoder.eval()
        for param in self.encoder.parameters():
            param.requires_grad = False

    def encode(self, x):
        # Delegate to the encoder's single-output path.
        return self.encoder.encode(x)

    def forward(self, content, style, alpha=1.0):
        """Stylize `content` with `style`; alpha controls transfer strength."""
        content_feat = self.encode(content)
        stylized = self.adain(content_feat, self.encode(style))
        # Blend stylized and plain content features when alpha < 1.
        if alpha < 1.0:
            stylized = alpha * stylized + (1 - alpha) * content_feat
        return self.decoder(stylized)
| # =========================== | |
| # DATASET AND LOSS FUNCTIONS | |
| # =========================== | |
class StyleTransferDataset(Dataset):
    """Folder-of-images dataset with repetition-based augmentation.

    The image list is repeated `augment_factor` times; any per-sample
    randomness is expected to come from the `transform` pipeline.
    """

    def __init__(self, content_dir, transform=None, augment_factor=1):
        self.content_dir = Path(content_dir)
        self.transform = transform
        self.augment_factor = augment_factor
        # Collect images for both lower- and upper-case extensions.
        self.images = []
        for ext in ('*.jpg', '*.jpeg', '*.png', '*.bmp'):
            for pattern in (ext, ext.upper()):
                self.images.extend(self.content_dir.glob(pattern))
        print(f"Found {len(self.images)} content images")
        self.augmented_images = self.images * self.augment_factor
        if self.augment_factor > 1:
            print(f"Dataset augmented {self.augment_factor}x to {len(self.augmented_images)} samples")

    def __len__(self):
        return len(self.augmented_images)

    def __getitem__(self, idx):
        # Indexes the repeated list modulo the base image count, which
        # resolves to the same underlying file as plain augmented indexing.
        img_path = self.augmented_images[idx % len(self.images)]
        image = Image.open(img_path).convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        return image
class PerceptualLoss(nn.Module):
    """VGG-feature perceptual loss: MSE content term + Gram-matrix style term."""

    def __init__(self, vgg_features):
        super(PerceptualLoss, self).__init__()
        self.vgg = vgg_features
        self.mse = nn.MSELoss()

    def gram_matrix(self, features):
        """Normalized Gram matrix of a (B, C, H, W) feature tensor."""
        b, c, h, w = features.size()
        flat = features.view(b, c, h * w)
        gram = torch.bmm(flat, flat.transpose(1, 2))
        return gram / (c * h * w)

    def forward(self, generated, content, style, content_weight=1.0, style_weight=1e5):
        """Return (total, content_loss, style_loss) for the three batches."""
        gen_feat = self.vgg(generated)
        content_loss = self.mse(gen_feat, self.vgg(content))
        style_loss = self.mse(self.gram_matrix(gen_feat),
                              self.gram_matrix(self.vgg(style)))
        total = content_weight * content_loss + style_weight * style_loss
        return total, content_loss, style_loss
| # =========================== | |
| # VIDEO PROCESSING | |
| # =========================== | |
class VideoProcessor:
    """Process videos frame by frame with style transfer.

    Wraps a StyleTransferSystem and applies its blend_styles(...) to every
    frame, trying a list of OpenCV codecs until one can actually write the
    output file, with a frame-saving fallback when none can.
    """

    def __init__(self, system):
        # system: object providing blend_styles(pil_image, style_configs, blend_mode)
        self.system = system

    def process_video(self, video_path, style_configs, blend_mode, progress_callback=None):
        """Process a video file with style transfer.

        Args:
            video_path: path string, or file-like object with a .name attribute.
            style_configs: style configuration forwarded to blend_styles.
            blend_mode: blend mode forwarded to blend_styles.
            progress_callback: optional fn(progress_fraction, message), called
                every 10 frames.

        Returns:
            Path of the written video file, or None on any failure.
        """
        if not VIDEO_PROCESSING_AVAILABLE:
            print("Video processing requires OpenCV (cv2) - please install it")
            return None
        try:
            # Handle both string path and file object
            if hasattr(video_path, 'name'):
                video_path = video_path.name
            # Open video and read its geometry/frame-rate metadata
            cap = cv2.VideoCapture(video_path)
            fps = int(cap.get(cv2.CAP_PROP_FPS))
            width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            # Create temporary output file - always start with mp4
            temp_output = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False)
            temp_output.close()
            # Try different codecs - prioritize MP4-compatible ones
            codecs_to_try = [
                ('mp4v', '.mp4'),  # MPEG-4 - most compatible MP4 codec
                ('MP4V', '.mp4'),  # Alternative case
                ('FMP4', '.mp4'),  # Another MPEG-4 variant
                ('DIVX', '.mp4'),  # DivX (sometimes works with MP4)
                ('XVID', '.avi'),  # If MP4 fails, try AVI
                ('MJPG', '.avi'),  # Motion JPEG - fallback
                ('DIV3', '.avi'),  # DivX3
                (0, '.avi')        # Uncompressed (last resort)
            ]
            out = None
            output_path = None
            used_codec = None
            for codec_str, ext in codecs_to_try:
                try:
                    # Update output filename with appropriate extension
                    if ext == '.mp4':
                        output_path = temp_output.name
                    else:
                        output_path = temp_output.name.replace('.mp4', ext)
                    if codec_str == 0:
                        fourcc = 0
                        print("Using uncompressed video (larger file size)")
                    else:
                        fourcc = cv2.VideoWriter_fourcc(*codec_str)
                        print(f"Trying codec: {codec_str} for {ext}")
                    # Create writer with specific parameters
                    out = cv2.VideoWriter(output_path, fourcc, fps, (width, height), isColor=True)
                    if out.isOpened():
                        # Test write a black frame to ensure it really works
                        # (isOpened() alone can report false positives).
                        test_frame = np.zeros((height, width, 3), dtype=np.uint8)
                        out.write(test_frame)
                        out.release()
                        # Re-open for actual writing
                        out = cv2.VideoWriter(output_path, fourcc, fps, (width, height), isColor=True)
                        if out.isOpened():
                            used_codec = codec_str
                            print(f"✓ Successfully using codec: {codec_str} with {ext}")
                            break
                    # This codec did not work: drop the writer and try the next.
                    out.release()
                    out = None
                except Exception as e:
                    print(f"Failed with codec {codec_str}: {e}")
                    if out:
                        out.release()
                        out = None
                    continue
            if out is None or not out.isOpened():
                # Last resort: save frames as images and create video differently
                print("Standard codecs failed. Trying alternative approach...")
                return self._process_with_frame_saving(cap, style_configs, blend_mode, fps, width, height, total_frames, progress_callback)
            # Process frames one by one
            frame_count = 0
            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                # Convert BGR to RGB and to PIL Image
                rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                pil_frame = Image.fromarray(rgb_frame)
                # Apply style transfer
                styled_frame = self.system.blend_styles(pil_frame, style_configs, blend_mode)
                # Convert back to BGR for video
                styled_array = np.array(styled_frame)
                bgr_frame = cv2.cvtColor(styled_array, cv2.COLOR_RGB2BGR)
                # Ensure frame is correct size and type
                if bgr_frame.shape[:2] != (height, width):
                    bgr_frame = cv2.resize(bgr_frame, (width, height))
                out.write(bgr_frame)
                frame_count += 1
                if progress_callback and frame_count % 10 == 0:
                    progress = frame_count / total_frames
                    progress_callback(progress, f"Processing frame {frame_count}/{total_frames}")
            cap.release()
            out.release()
            # Verify the output file actually contains data
            if not os.path.exists(output_path) or os.path.getsize(output_path) < 1000:
                print(f"Output file is too small or doesn't exist (size: {os.path.getsize(output_path) if os.path.exists(output_path) else 0} bytes)")
                return None
            print(f"Video successfully saved to: {output_path}")
            print(f"File size: {os.path.getsize(output_path) / 1024 / 1024:.2f} MB")
            print(f"Format: {os.path.splitext(output_path)[1]}")
            print(f"Codec used: {used_codec}")
            # Clean up original temp file if different
            if output_path != temp_output.name and os.path.exists(temp_output.name):
                try:
                    os.unlink(temp_output.name)
                except:
                    pass
            return output_path
        except Exception as e:
            print(f"Error processing video: {e}")
            traceback.print_exc()
            return None

    def _process_with_frame_saving(self, cap, style_configs, blend_mode, fps, width, height, total_frames, progress_callback):
        """Alternative processing method: save frames then combine.

        Fallback when no VideoWriter codec could be opened directly: writes
        each styled frame to a PNG in a temp dir, then re-encodes them with
        the mp4v codec. Returns the output path or None.
        """
        try:
            print("Using frame-saving fallback method...")
            temp_dir = tempfile.mkdtemp()
            frame_count = 0
            frame_paths = []
            # Process and save frames
            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                # Process frame
                rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                pil_frame = Image.fromarray(rgb_frame)
                styled_frame = self.system.blend_styles(pil_frame, style_configs, blend_mode)
                # Save frame
                frame_path = os.path.join(temp_dir, f"frame_{frame_count:06d}.png")
                styled_frame.save(frame_path)
                frame_paths.append(frame_path)
                frame_count += 1
                if progress_callback and frame_count % 10 == 0:
                    progress = frame_count / total_frames
                    progress_callback(progress, f"Processing frame {frame_count}/{total_frames}")
            cap.release()
            if not frame_paths:
                return None
            # Try to create video from frames
            output_path = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False).name
            # Read first frame to get size
            first_frame = cv2.imread(frame_paths[0])
            h, w = first_frame.shape[:2]
            # Try simple mp4v codec
            fourcc = cv2.VideoWriter_fourcc(*'mp4v')
            out = cv2.VideoWriter(output_path, fourcc, fps, (w, h))
            if out.isOpened():
                for frame_path in frame_paths:
                    frame = cv2.imread(frame_path)
                    out.write(frame)
                out.release()
                # Clean up frames
                shutil.rmtree(temp_dir)
                if os.path.exists(output_path) and os.path.getsize(output_path) > 1000:
                    print(f"Successfully created video using frame-saving method")
                    return output_path
            # Clean up
            shutil.rmtree(temp_dir)
            return None
        except Exception as e:
            print(f"Frame-saving method failed: {e}")
            if 'temp_dir' in locals() and os.path.exists(temp_dir):
                shutil.rmtree(temp_dir)
            return None
| # =========================== | |
| # MAIN STYLE TRANSFER SYSTEM | |
| # =========================== | |
| class StyleTransferSystem: | |
    def __init__(self):
        """Discover models and prepare image transforms on the global device."""
        self.device = device
        self.cyclegan_models = {}    # model_key -> metadata dict (path/name/direction)
        self.loaded_generators = {}  # model_key -> loaded Generator (cache)
        self.lightweight_models = {}
        # CycleGAN-style normalization: PIL image -> tensor in [-1, 1]
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        # Inverse of the above: (x+1)/2 maps [-1, 1] back to a PIL image
        self.inverse_transform = transforms.Compose([
            transforms.Normalize((-1, -1, -1), (2, 2, 2)),
            transforms.ToPILImage()
        ])
        # ImageNet normalization for VGG-based models
        self.vgg_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])
        self.discover_cyclegan_models()
        self.models_dir = '/tmp/trained_models'
        os.makedirs(self.models_dir, exist_ok=True)
        if VIDEO_PROCESSING_AVAILABLE:
            self.video_processor = VideoProcessor(self)
        # Test GPU functionality
        if self.device.type == 'cuda':
            self._test_gpu()
    def _test_gpu(self):
        """Smoke-test the GPU: allocate a tensor, compute, report memory, clean up.

        Failures are logged but never raised — the app keeps running on CPU.
        """
        try:
            print("\nTesting GPU functionality...")
            with torch.no_grad():
                # Create a small test tensor
                test_tensor = torch.randn(1, 3, 256, 256).to(self.device)
                print(f"Test tensor device: {test_tensor.device}")
                # Perform a simple operation
                result = test_tensor * 2.0
                print(f"Result device: {result.device}")
                # Check memory usage
                print(f"GPU memory after test: {torch.cuda.memory_allocated() / 1e9:.3f} GB")
                # Clean up
                del test_tensor, result
                torch.cuda.empty_cache()
            print("✓ GPU test successful!\n")
        except Exception as e:
            print(f"✗ GPU test failed: {e}\n")
            traceback.print_exc()
    def discover_cyclegan_models(self):
        """Find all available CycleGAN models including both AB and BA directions.

        Scans ./models for .pth checkpoints, pairs generator_AB / generator_BA
        files under a shared base name, and registers human-readable entries
        in self.cyclegan_models keyed as "<base>_AB" / "<base>_BA".
        """
        print("\nDiscovering CycleGAN models...")
        # Updated patterns to match your directory structure
        patterns = [
            './models/*_best_*/*generator_*.pth',
            './models/*_best_*/*.pth',
            './models/*/*generator*.pth',
            './models/*/*.pth'
        ]
        all_files = set()
        for pattern in patterns:
            files = glob.glob(pattern)
            if files:
                print(f"Found in {pattern}: {len(files)} items")
                all_files.update(files)
        # Also check if models directory exists and list contents
        if os.path.exists('./models'):
            print(f"\nModels directory contents:")
            for folder in os.listdir('./models'):
                folder_path = os.path.join('./models', folder)
                if os.path.isdir(folder_path):
                    print(f" {folder}/")
                    for file in os.listdir(folder_path):
                        print(f" - {file}")
                        if file.endswith('.pth'):
                            all_files.add(os.path.join(folder_path, file))
        # Group files by base model name -> {'AB': path|None, 'BA': path|None}
        model_files = {}
        for path in all_files:
            # Skip normal models
            if 'normal' in path.lower():
                continue
            filename = os.path.basename(path)
            folder_name = os.path.basename(os.path.dirname(path))
            # Extract base name from folder name (strip "_best_*" suffix)
            if '_best_' in folder_name:
                base_name = folder_name.split('_best_')[0]
            else:
                base_name = folder_name
            if base_name not in model_files:
                model_files[base_name] = {'AB': None, 'BA': None}
            # Check filename for direction
            if 'generator_AB' in filename or 'g_AB' in filename or 'G_AB' in filename:
                model_files[base_name]['AB'] = path
            elif 'generator_BA' in filename or 'g_BA' in filename or 'G_BA' in filename:
                model_files[base_name]['BA'] = path
            elif 'generator' in filename.lower() and not any(x in filename for x in ['AB', 'BA']):
                # If no direction specified, assume it's AB
                if model_files[base_name]['AB'] is None:
                    model_files[base_name]['AB'] = path
        # Display names: base -> (stylized-domain name, source-domain name)
        model_display_map = {
            'photo_bokeh': ('Bokeh', 'Sharp'),
            'photo_golden': ('Golden Hour', 'Normal Light'),
            'photo_monet': ('Monet Style', 'Photo'),
            'photo_seurat': ('Seurat Style', 'Photo'),
            'day_night': ('Night', 'Day'),
            'summer_winter': ('Winter', 'Summer'),
            'foggy_clear': ('Clear', 'Foggy')
        }
        # Register available models (only those with a known display name)
        for base_name, files in model_files.items():
            clean_name = base_name.lower().replace('-', '_')
            if clean_name in model_display_map:
                style_from, style_to = model_display_map[clean_name]
                # Register AB direction if available
                if files['AB']:
                    display_name = f"{style_to} to {style_from}"
                    model_key = f"{clean_name}_AB"
                    self.cyclegan_models[model_key] = {
                        'path': files['AB'],
                        'name': display_name,
                        'base_name': base_name,
                        'direction': 'AB'
                    }
                    print(f"Registered: {display_name} ({model_key}) -> {files['AB']}")
                # Register BA direction if available
                if files['BA']:
                    display_name = f"{style_from} to {style_to}"
                    model_key = f"{clean_name}_BA"
                    self.cyclegan_models[model_key] = {
                        'path': files['BA'],
                        'name': display_name,
                        'base_name': base_name,
                        'direction': 'BA'
                    }
                    print(f"Registered: {display_name} ({model_key}) -> {files['BA']}")
        if not self.cyclegan_models:
            print("No CycleGAN models found!")
            print("Make sure your model files are in the ./models directory")
        else:
            print(f"\nFound {len(self.cyclegan_models)} CycleGAN models\n")
| def detect_architecture(self, state_dict): | |
| """Detect the number of residual blocks in CycleGAN model""" | |
| residual_keys = [k for k in state_dict.keys() if 'model.' in k and '.block.' in k] | |
| if not residual_keys: | |
| return 9 | |
| block_indices = set() | |
| for key in residual_keys: | |
| parts = key.split('.') | |
| for i in range(len(parts) - 1): | |
| if parts[i] == 'model' and parts[i+1].isdigit(): | |
| block_indices.add(int(parts[i+1])) | |
| break | |
| n_blocks = len(block_indices) | |
| return n_blocks if n_blocks > 0 else 9 | |
def load_cyclegan_model(self, model_key):
    """Load (and cache) a CycleGAN generator for the given model key.

    Returns the generator moved to ``self.device``, or None when the key is
    unknown or loading fails. Successfully loaded generators are cached in
    ``self.loaded_generators``.
    """
    if model_key in self.loaded_generators:
        # A cached model may have been loaded under a different device
        # setting, so re-pin it to the current device before returning.
        model = self.loaded_generators[model_key]
        model = model.to(self.device)
        return model
    if model_key not in self.cyclegan_models:
        print(f"Model {model_key} not found!")
        return None
    model_info = self.cyclegan_models[model_key]
    try:
        print(f"Loading {model_info['name']} from {model_info['path']}...")
        print(f"Target device: {self.device}")
        # map_location keeps CPU-only machines from trying to deserialize
        # CUDA tensors and avoids an extra CPU->GPU copy.
        state_dict = torch.load(model_info['path'], map_location=self.device)
        if 'generator' in state_dict:
            # Full training checkpoints nest the generator under a key.
            state_dict = state_dict['generator']
        n_blocks = self.detect_architecture(state_dict)
        print(f"Detected {n_blocks} residual blocks")
        generator = Generator(n_residual_blocks=n_blocks)
        try:
            generator.load_state_dict(state_dict, strict=True)
            print(f"Loaded with strict=True")
        except RuntimeError as load_err:
            # Key/shape mismatch: retry loosely. Was a bare `except:` that
            # silently swallowed every exception type (incl. KeyboardInterrupt).
            print(f"Strict load failed ({load_err}); retrying with strict=False")
            generator.load_state_dict(state_dict, strict=False)
            print(f"Loaded with strict=False")
        # Move to device BEFORE any precision changes
        generator = generator.to(self.device)
        generator.eval()
        print(f"Model device after .to(): {next(generator.parameters()).device}")
        # Skip half precision to ensure GPU usage - half precision can cause issues
        if self.device.type == 'cuda':
            print(f"Using full precision (fp32) on GPU")
            # Smoke-test a forward pass so device problems surface here,
            # not mid-request.
            with torch.no_grad():
                test_input = torch.randn(1, 3, 256, 256).to(self.device)
                _ = generator(test_input)
                print(f"GPU test successful")
            torch.cuda.empty_cache()
        self.loaded_generators[model_key] = generator
        print(f"Successfully loaded {model_info['name']} on {self.device}")
        return generator
    except Exception as e:
        print(f"Failed to load {model_info['name']}: {e}")
        traceback.print_exc()
        return None
def apply_cyclegan_style(self, image, model_key, intensity=1.0):
    """Apply a CycleGAN style to an image.

    Args:
        image: PIL source image (None returns None).
        model_key: key into ``self.cyclegan_models``.
        intensity: blend factor; values < 1.0 mix the stylized output
            back with the original image.

    Returns the stylized PIL image at the original size, or None on failure.
    """
    if image is None or model_key not in self.cyclegan_models:
        return None
    model_info = self.cyclegan_models[model_key]
    generator = self.load_cyclegan_model(model_key)
    if generator is None:
        print(f"Could not load model for {model_info['name']}")
        return None
    try:
        # Ensure model is on GPU
        generator = generator.to(self.device)
        # Debug GPU usage
        if self.device.type == 'cuda':
            print(f"GPU Memory before style transfer: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
        original_size = image.size
        w, h = image.size
        # Round dimensions up to multiples of 32: the generator's
        # down/upsampling stack requires sizes divisible by its total stride.
        new_w = ((w + 31) // 32) * 32
        new_h = ((h + 31) // 32) * 32
        # Cap working resolution to bound memory use (smaller cap on CPU).
        max_size = 1024 if self.device.type == 'cuda' else 512
        if new_w > max_size or new_h > max_size:
            ratio = min(max_size / new_w, max_size / new_h)
            new_w = int(new_w * ratio)
            new_h = int(new_h * ratio)
            # Re-snap to multiples of 32 after scaling down.
            new_w = ((new_w + 31) // 32) * 32
            new_h = ((new_h + 31) // 32) * 32
        image_resized = image.resize((new_w, new_h), Image.LANCZOS)
        img_tensor = self.transform(image_resized).unsqueeze(0).to(self.device)
        with torch.no_grad():
            # Skip half precision for now to ensure GPU usage
            if self.device.type == 'cuda':
                torch.cuda.synchronize()  # Ensure GPU operations complete
                torch.cuda.empty_cache()
            output = generator(img_tensor)
            # Ensure output is on the same device
            output = output.to(self.device)
            if self.device.type == 'cuda':
                torch.cuda.synchronize()  # Wait for GPU to finish
        output_img = self.inverse_transform(output.squeeze(0).cpu())
        # Restore the caller's original resolution.
        output_img = output_img.resize(original_size, Image.LANCZOS)
        if self.device.type == 'cuda':
            print(f"GPU Memory after style transfer: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
            torch.cuda.empty_cache()
        if intensity < 1.0:
            # Linear blend between original and stylized output.
            output_array = np.array(output_img, dtype=np.float32)
            original_array = np.array(image, dtype=np.float32)
            blended = original_array * (1 - intensity) + output_array * intensity
            output_img = Image.fromarray(blended.astype(np.uint8))
        return output_img
    except Exception as e:
        print(f"Error applying style {model_info['name']}: {e}")
        print(f"Device: {self.device}")
        print(f"Model device: {next(generator.parameters()).device}")
        traceback.print_exc()
        return None
def train_lightweight_model(self, style_image, content_dir, model_name,
                            epochs=30, batch_size=4, lr=1e-3,
                            save_interval=5, style_weight=1e5, content_weight=1.0,
                            n_residual_blocks=5, progress_callback=None):
    """Train a lightweight style transfer model.

    Args:
        style_image: PIL image whose style will be learned.
        content_dir: directory of content images used for training.
        model_name: base name for checkpoints written to ``self.models_dir``.
        epochs, batch_size, lr: standard optimization hyperparameters.
        save_interval: save a checkpoint every N epochs.
        style_weight, content_weight: perceptual-loss term weights.
        n_residual_blocks: depth of the LightweightStyleNet to train.
        progress_callback: optional ``fn(progress_fraction, message)``,
            invoked every 10 training steps.

    Returns the trained model (also registered in ``self.lightweight_models``).
    """
    model = LightweightStyleNet(n_residual_blocks=n_residual_blocks).to(self.device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    print(f"Model architecture: {n_residual_blocks} residual blocks")
    print(f"Training device: {self.device}")
    # Verify model is on GPU
    if self.device.type == 'cuda':
        print(f"Model on GPU: {next(model.parameters()).device}")
        print(f"GPU memory before training: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
    # Fewer source images -> heavier augmentation so each epoch still sees
    # a reasonable number of distinct samples.
    num_content_images = len(list(Path(content_dir).glob('*')))
    if num_content_images < 5:
        augment_factor = 20
    elif num_content_images < 10:
        augment_factor = 10
    elif num_content_images < 20:
        augment_factor = 5
    else:
        augment_factor = 1
    # Create dataset with augmentation
    if num_content_images < 10:
        transform = transforms.Compose([
            transforms.RandomResizedCrop(256, scale=(0.7, 1.2)),
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(15),
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        print(f"Using heavy augmentation due to limited images ({num_content_images} provided)")
    else:
        transform = transforms.Compose([
            transforms.Resize(286),
            transforms.RandomCrop(256),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
    dataset = StyleTransferDataset(content_dir, transform=transform, augment_factor=augment_factor)
    print(f"Training configuration:")
    print(f" - Original images: {num_content_images}")
    print(f" - Augmentation factor: {augment_factor}x")
    print(f" - Total training samples: {len(dataset)}")
    print(f" - Residual blocks: {n_residual_blocks}")
    print(f" - Batch size: {int(batch_size)}")
    print(f" - Epochs: {epochs}")
    # Adjust batch size for small datasets
    if num_content_images == 1:
        if n_residual_blocks >= 9 and int(batch_size) > 1:
            actual_batch_size = 1
            print(f"Reduced batch size to 1 for single image + {n_residual_blocks} blocks")
        elif int(batch_size) > 2:
            actual_batch_size = 2
            print(f"Reduced batch size to 2 for single image training")
        else:
            actual_batch_size = min(int(batch_size), len(dataset))
    else:
        actual_batch_size = min(int(batch_size), len(dataset))
    dataloader = DataLoader(dataset, batch_size=actual_batch_size, shuffle=True,
                            num_workers=0 if num_content_images < 10 else 2)
    # Prepare style image
    style_transform = transforms.Compose([
        transforms.Resize(800),
        transforms.CenterCrop(768)
    ])
    style_pil = style_transform(style_image)
    style_tensor = self.vgg_transform(style_pil).unsqueeze(0).to(self.device)
    # Create VGG features extractor for loss
    vgg_features = SimpleVGGFeatures().to(self.device).eval()
    print(f"VGG features on device: {next(vgg_features.parameters()).device}")
    # Extract style features once
    # NOTE(review): style_features is computed here but never referenced
    # below — the loss recomputes style features from style_vgg per batch.
    # Confirm whether this precomputation can be removed.
    with torch.no_grad():
        style_features = vgg_features(style_tensor)
    # Loss function
    perceptual_loss = PerceptualLoss(vgg_features)
    # Training loop
    model.train()
    total_steps = 0
    for epoch in range(epochs):
        epoch_loss = 0
        for batch_idx, content_batch in enumerate(dataloader):
            content_batch = content_batch.to(self.device)
            # Forward pass
            output = model(content_batch)
            # Ensure all tensors have the same size
            target_size = (256, 256)
            # Per-sample conversion for VGG: denormalize from [-1, 1],
            # resize if needed, then apply ImageNet normalization.
            output_vgg = []
            content_vgg = []
            for i in range(output.size(0)):
                # Denormalize from [-1, 1] to [0, 1]
                out_img = output[i] * 0.5 + 0.5
                cont_img = content_batch[i] * 0.5 + 0.5
                # Ensure exact size match
                if out_img.shape[1:] != (target_size[0], target_size[1]):
                    out_img = F.interpolate(out_img.unsqueeze(0), size=target_size, mode='bilinear', align_corners=False).squeeze(0)
                if cont_img.shape[1:] != (target_size[0], target_size[1]):
                    cont_img = F.interpolate(cont_img.unsqueeze(0), size=target_size, mode='bilinear', align_corners=False).squeeze(0)
                # Normalize for VGG
                out_norm = transforms.Normalize(
                    mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225]
                )(out_img)
                cont_norm = transforms.Normalize(
                    mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225]
                )(cont_img)
                output_vgg.append(out_norm)
                content_vgg.append(cont_norm)
            output_vgg = torch.stack(output_vgg)
            content_vgg = torch.stack(content_vgg)
            # Ensure style tensor matches batch size and dimensions
            style_vgg = style_tensor.expand(output_vgg.size(0), -1, -1, -1)
            if style_vgg.shape[2:] != output_vgg.shape[2:]:
                style_vgg = F.interpolate(style_vgg, size=output_vgg.shape[2:], mode='bilinear', align_corners=False)
            # Calculate loss
            loss, content_loss, style_loss = perceptual_loss(
                output_vgg, content_vgg, style_vgg,
                content_weight=content_weight, style_weight=style_weight
            )
            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            total_steps += 1
            # Progress callback
            if progress_callback and total_steps % 10 == 0:
                progress = (epoch + (batch_idx + 1) / len(dataloader)) / epochs
                aug_info = f" (aug {num_content_images}→{len(dataset)})" if num_content_images < 20 else ""
                blocks_info = f", {n_residual_blocks} blocks"
                progress_callback(progress, f"Epoch {epoch+1}/{epochs}{aug_info}{blocks_info}, Loss: {loss.item():.4f}")
        # Save checkpoint
        if (epoch + 1) % int(save_interval) == 0:
            checkpoint_path = f'{self.models_dir}/{model_name}_epoch_{epoch+1}.pth'
            torch.save({
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': epoch_loss / len(dataloader),
                'n_residual_blocks': n_residual_blocks
            }, checkpoint_path)
            print(f"Saved checkpoint: {checkpoint_path}")
    # Save final model
    final_path = f'{self.models_dir}/{model_name}_final.pth'
    torch.save({
        'model_state_dict': model.state_dict(),
        'n_residual_blocks': n_residual_blocks
    }, final_path)
    print(f"Training complete! Model saved to: {final_path}")
    # Add to lightweight models
    self.lightweight_models[model_name] = model
    return model
def load_lightweight_model(self, model_path):
    """Load a trained lightweight style model from ``model_path``.

    The checkpoint may be either a raw state dict or a dict containing a
    'model_state_dict' entry. The residual-block count is read from the
    checkpoint when present, otherwise inferred from parameter names.
    Returns the model in eval mode on ``self.device``, or None on failure.
    """
    try:
        # Load directly to the target device
        state_dict = torch.load(model_path, map_location=self.device)
        # Check if n_residual_blocks is saved
        if isinstance(state_dict, dict) and 'n_residual_blocks' in state_dict:
            n_blocks = state_dict['n_residual_blocks']
            print(f"Found saved architecture: {n_blocks} residual blocks")
        else:
            # Try to detect from state dict
            if 'model_state_dict' in state_dict:
                model_state = state_dict['model_state_dict']
            else:
                model_state = state_dict
            # Keys look like "res_blocks.<idx>....weight"; count distinct indices.
            res_block_keys = [k for k in model_state.keys() if 'res_blocks' in k and 'weight' in k]
            n_blocks = len(set([k.split('.')[1] for k in res_block_keys if k.startswith('res_blocks')])) or 5
            print(f"Detected {n_blocks} residual blocks from model structure")
        # Create model with detected architecture
        model = LightweightStyleNet(n_residual_blocks=n_blocks).to(self.device)
        # Load the weights
        if 'model_state_dict' in state_dict:
            model.load_state_dict(state_dict['model_state_dict'])
        else:
            model.load_state_dict(state_dict)
        model.eval()
        # Verify model is on correct device
        print(f"Lightweight model loaded on: {next(model.parameters()).device}")
        return model
    except Exception as e:
        print(f"Error loading lightweight model: {e}")
        # Fallback: retry with the default architecture of 5 residual blocks.
        try:
            print("Attempting to load with default 5 residual blocks...")
            model = LightweightStyleNet(n_residual_blocks=5).to(self.device)
            if model_path.endswith('.pth'):
                state_dict = torch.load(model_path, map_location=self.device)
                if 'model_state_dict' in state_dict:
                    model.load_state_dict(state_dict['model_state_dict'])
                else:
                    model.load_state_dict(state_dict)
            model.eval()
            print(f"Fallback model loaded on: {next(model.parameters()).device}")
            return model
        except Exception as fallback_err:
            # Was a bare `except:` returning None with no diagnostics;
            # log the failure before giving up.
            print(f"Fallback load failed: {fallback_err}")
            return None
| # Inside the StyleTransferSystem class, add these methods: | |
| def _create_linear_weight(self, width, height, overlap): | |
| """Create linear blending weights for tile edges""" | |
| weight = np.ones((height, width, 1), dtype=np.float32) | |
| if overlap > 0: | |
| # Create gradients for each edge | |
| for i in range(overlap): | |
| alpha = i / overlap | |
| # Top edge | |
| weight[i, :] *= alpha | |
| # Bottom edge | |
| weight[-i-1, :] *= alpha | |
| # Left edge | |
| weight[:, i] *= alpha | |
| # Right edge | |
| weight[:, -i-1] *= alpha | |
| return weight | |
| def _create_gaussian_weight(self, width, height, overlap): | |
| """Create Gaussian blending weights for smoother transitions""" | |
| weight = np.ones((height, width), dtype=np.float32) | |
| # Create 2D Gaussian centered in the tile | |
| y, x = np.ogrid[:height, :width] | |
| center_y, center_x = height / 2, width / 2 | |
| # Distance from center | |
| dist_from_center = np.sqrt((x - center_x)**2 + (y - center_y)**2) | |
| # Gaussian falloff starting from the edges | |
| max_dist = min(height, width) / 2 | |
| sigma = max_dist / 2 # Adjust for smoother/sharper transitions | |
| # Apply Gaussian only near edges | |
| edge_dist = np.minimum( | |
| np.minimum(y, height - 1 - y), | |
| np.minimum(x, width - 1 - x) | |
| ) | |
| # Weight is 1 in center, Gaussian falloff near edges | |
| weight = np.where( | |
| edge_dist < overlap, | |
| np.exp(-0.5 * ((overlap - edge_dist) / (overlap/3))**2), | |
| 1.0 | |
| ) | |
| return weight.reshape(height, width, 1) | |
def apply_lightweight_style(self, image, model, intensity=1.0):
    """Stylize ``image`` with a lightweight model.

    The image is center-cropped/resized to 256x256 for inference, then the
    result is scaled back to the original size. ``intensity`` < 1.0 blends
    the stylized output with the original. Returns None on bad input/error.
    """
    if image is None or model is None:
        return None
    try:
        model = model.to(self.device)
        model.eval()
        source_size = image.size
        preprocess = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(256),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])
        batch = preprocess(image).unsqueeze(0).to(self.device)
        with torch.no_grad():
            if self.device.type == 'cuda':
                torch.cuda.synchronize()
            stylized = model(batch)
            if self.device.type == 'cuda':
                torch.cuda.synchronize()
        result = self.inverse_transform(stylized.squeeze(0).cpu())
        result = result.resize(source_size, Image.LANCZOS)
        if intensity < 1.0:
            # Linear blend between original and stylized pixels.
            mix = (np.array(image, dtype=np.float32) * (1 - intensity)
                   + np.array(result, dtype=np.float32) * intensity)
            result = Image.fromarray(mix.astype(np.uint8))
        return result
    except Exception as e:
        print(f"Error applying lightweight style: {e}")
        print(f"Device: {self.device}")
        print(f"Model device: {next(model.parameters()).device}")
        return None
def blend_styles(self, image, style_configs, blend_mode="additive"):
    """Apply multiple styles with different blending modes.

    Args:
        image: PIL source image.
        style_configs: iterable of (style_type, model_key, intensity)
            tuples; style_type is 'cyclegan' or 'lightweight'. Entries with
            intensity <= 0 (or an unknown type) are skipped.
        blend_mode: "average", "additive", "maximum", "overlay"; any other
            value falls through to screen blending.

    Returns a PIL image; the input image is returned unchanged when nothing
    could be applied.
    """
    if not image or not style_configs:
        return image
    original = np.array(image, dtype=np.float32)
    styled_images = []
    weights = []
    # Render each configured style at full strength; its intensity becomes
    # the blend weight below.
    for style_type, model_key, intensity in style_configs:
        if intensity <= 0:
            continue
        if style_type == 'cyclegan':
            styled = self.apply_cyclegan_style(image, model_key, 1.0)
        elif style_type == 'lightweight' and model_key in self.lightweight_models:
            styled = self.apply_lightweight_style(image, self.lightweight_models[model_key], 1.0)
        else:
            continue
        if styled:
            styled_images.append(np.array(styled, dtype=np.float32))
            weights.append(intensity)
    if not styled_images:
        return image
    # Apply blending
    if blend_mode == "average":
        # Weighted mean of the styled images (original not included).
        result = np.zeros_like(original)
        total_weight = sum(weights)
        for img, weight in zip(styled_images, weights):
            result += img * (weight / total_weight)
    elif blend_mode == "additive":
        # Accumulate each style's delta from the original, scaled by weight.
        result = original.copy()
        for img, weight in zip(styled_images, weights):
            transformation = img - original
            result = result + transformation * weight
    elif blend_mode == "maximum":
        # Per-element, keep whichever style moved furthest from the original.
        result = original.copy()
        for img, weight in zip(styled_images, weights):
            transformation = (img - original) * weight
            current_diff = result - original
            mask = np.abs(transformation) > np.abs(current_diff)
            result[mask] = original[mask] + transformation[mask]
    elif blend_mode == "overlay":
        # Classic overlay: multiply in shadows (<128), screen in highlights.
        result = original.copy()
        for img, weight in zip(styled_images, weights):
            overlay = np.zeros_like(result)
            mask = result < 128
            overlay[mask] = 2 * img[mask] * result[mask] / 255.0
            overlay[~mask] = 255 - 2 * (255 - img[~mask]) * (255 - result[~mask]) / 255.0
            result = result * (1 - weight) + overlay * weight
    else:  # "screen" mode
        result = original.copy()
        for img, weight in zip(styled_images, weights):
            screened = 255 - ((255 - result) * (255 - img) / 255.0)
            if weight > 1.0:
                # Extrapolate past plain screen when weight exceeds 1.
                diff = screened - result
                result = result + diff * weight
            else:
                result = result * (1 - weight) + screened * weight
    return Image.fromarray(np.clip(result, 0, 255).astype(np.uint8))
def apply_regional_styles(self, image, combined_mask, regions, base_style_configs=None, blend_mode="additive"):
    """Apply different styles to painted regions using a combined mask.

    Args:
        image: PIL source image.
        combined_mask: 2-D label mask where value ``i + 1`` marks region i
            (0 = unpainted); may be at display resolution and is resized
            to the image size if needed.
        regions: list of dicts with 'style' (CycleGAN display name or None)
            and 'intensity' keys.
        base_style_configs: optional style configs applied to the whole
            image before regional styling.
        blend_mode: blend mode forwarded to ``blend_styles``.
    """
    if not regions:
        if base_style_configs:
            return self.blend_styles(image, base_style_configs, blend_mode)
        return image
    original_size = image.size
    result = np.array(image, dtype=np.float32)
    # Apply base style if provided
    if base_style_configs:
        base_styled = self.blend_styles(image, base_style_configs, blend_mode)
        result = np.array(base_styled, dtype=np.float32)
    # Resize mask to match original image if needed. PIL size is (w, h)
    # while numpy shape is (h, w) — hence the swapped comparison.
    if combined_mask is not None and combined_mask.shape[:2] != (original_size[1], original_size[0]):
        # Resize the combined mask to match the original image
        combined_mask_pil = Image.fromarray(combined_mask.astype(np.uint8))
        # NEAREST keeps the integer region labels intact (no interpolation).
        combined_mask_resized = combined_mask_pil.resize(original_size, Image.NEAREST)
        combined_mask = np.array(combined_mask_resized)
    # Apply each region
    for i, region in enumerate(regions):
        if region['style'] is None:
            continue
        # Map the region's display name back to a model key.
        model_key = None
        for key, info in self.cyclegan_models.items():
            if info['name'] == region['style']:
                model_key = key
                break
        if not model_key:
            continue
        # Style the whole image, then keep only this region's pixels below.
        style_configs = [('cyclegan', model_key, region['intensity'])]
        styled = self.blend_styles(image, style_configs, blend_mode)
        styled_array = np.array(styled, dtype=np.float32)
        # Create mask for this region from combined mask
        if combined_mask is not None:
            # Region masks are identified by their label value (index + 1).
            region_mask = (combined_mask == (i + 1)).astype(np.float32)
            # Ensure mask has same shape as image
            if len(region_mask.shape) == 2:
                region_mask_3ch = np.stack([region_mask] * 3, axis=2)
            else:
                region_mask_3ch = region_mask
            # Blend using mask
            result = result * (1 - region_mask_3ch) + styled_array * region_mask_3ch
    return Image.fromarray(np.clip(result, 0, 255).astype(np.uint8))
def train_adain_model(self, style_images, content_dir, model_name,
                      epochs=30, batch_size=4, lr=1e-4,
                      save_interval=5, style_weight=10.0, content_weight=1.0,
                      progress_callback=None):
    """Train an AdaIN-based style transfer model.

    Only the decoder is optimized; the encoder stays frozen. Losses come
    from a multi-layer VGG19 extractor: MSE on relu4_1 for content, and
    weighted Gram-matrix MSE across relu1_1..relu4_1 for style.

    Args:
        style_images: list of PIL style images (each augmented 5x).
        content_dir: directory of content training images.
        model_name: base name for checkpoints in ``self.models_dir``.
        epochs, batch_size, lr, save_interval: training hyperparameters.
        style_weight, content_weight: loss term weights (style_weight is
            internally scaled by 10 below).
        progress_callback: optional ``fn(progress_fraction, message)``.

    Returns the trained model (also registered in ``self.lightweight_models``).
    """
    model = AdaINStyleTransfer().to(self.device)
    # Only decoder parameters are trainable.
    optimizer = torch.optim.Adam(model.decoder.parameters(), lr=lr)
    # Add learning rate scheduler
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.8)
    print(f"Training AdaIN model")
    print(f"Training device: {self.device}")
    # Verify model is on GPU
    if self.device.type == 'cuda':
        print(f"Model on GPU: {next(model.decoder.parameters()).device}")
        print(f"GPU memory before training: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
    # Prepare style images - INCREASED SIZE
    style_transform = transforms.Compose([
        transforms.Resize(600),  # Increased from 512
        transforms.RandomCrop(512),  # Increased from 256
        transforms.RandomHorizontalFlip(p=0.5),  # Add augmentation
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
    style_tensors = []
    # Create multiple augmented versions of each style image
    for style_img in style_images:
        # Generate 5 augmented versions per style image
        for _ in range(5):
            style_tensor = style_transform(style_img).unsqueeze(0).to(self.device)
            style_tensors.append(style_tensor)
    print(f"Created {len(style_tensors)} augmented style samples from {len(style_images)} images")
    # Prepare content dataset - INCREASED SIZE
    content_transform = transforms.Compose([
        transforms.Resize(600),  # Increased from 512
        transforms.RandomCrop(512),  # Increased from 256
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.05),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
    dataset = StyleTransferDataset(content_dir, transform=content_transform)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0)
    print(f"Training configuration:")
    print(f" - Style images: {len(style_tensors)}")
    print(f" - Content images: {len(dataset)}")
    print(f" - Batch size: {batch_size}")
    print(f" - Epochs: {epochs}")
    print(f" - Training resolution: 512x512")  # Updated
    # Loss network (VGG for perceptual loss) - USE MULTIPLE LAYERS
    class MultiLayerVGG(nn.Module):
        """Frozen VGG19 feature extractor returning relu1_1..relu4_1 activations."""
        def __init__(self):
            super().__init__()
            vgg = models.vgg19(weights=models.VGG19_Weights.DEFAULT).features
            self.slice1 = nn.Sequential(*list(vgg.children())[:2])  # relu1_1
            self.slice2 = nn.Sequential(*list(vgg.children())[2:7])  # relu2_1
            self.slice3 = nn.Sequential(*list(vgg.children())[7:12])  # relu3_1
            self.slice4 = nn.Sequential(*list(vgg.children())[12:21])  # relu4_1
            for param in self.parameters():
                param.requires_grad = False
        def forward(self, x):
            h1 = self.slice1(x)
            h2 = self.slice2(h1)
            h3 = self.slice3(h2)
            h4 = self.slice4(h3)
            return [h1, h2, h3, h4]
    loss_network = MultiLayerVGG().to(self.device).eval()
    mse_loss = nn.MSELoss()
    # Training loop
    model.train()
    model.encoder.eval()  # Keep encoder frozen
    total_steps = 0
    # Adjust style weight for better quality
    actual_style_weight = style_weight * 10  # Multiply by 10 for better style transfer
    for epoch in range(epochs):
        epoch_loss = 0
        for batch_idx, content_batch in enumerate(dataloader):
            content_batch = content_batch.to(self.device)
            # Randomly select style images for this batch
            batch_style = []
            for _ in range(content_batch.size(0)):
                style_idx = np.random.randint(0, len(style_tensors))
                batch_style.append(style_tensors[style_idx])
            batch_style = torch.cat(batch_style, dim=0)
            # Forward pass
            output = model(content_batch, batch_style)
            # Multi-layer content and style loss. The loss targets don't
            # need gradients; output features must stay on the graph so the
            # decoder receives gradients.
            with torch.no_grad():
                content_feats = loss_network(content_batch)
                style_feats = loss_network(batch_style)
            output_feats = loss_network(output)
            # Content loss - only from relu4_1
            content_loss = mse_loss(output_feats[-1], content_feats[-1])
            # Style loss - from multiple layers
            style_loss = 0
            style_weights = [0.2, 0.3, 0.5, 1.0]  # Give more weight to higher layers
            def gram_matrix(feat):
                # Channel-correlation (Gram) matrix, normalized by feature size.
                b, c, h, w = feat.size()
                feat = feat.view(b, c, h * w)
                gram = torch.bmm(feat, feat.transpose(1, 2))
                return gram / (c * h * w)
            for i, (output_feat, style_feat, weight) in enumerate(zip(output_feats, style_feats, style_weights)):
                output_gram = gram_matrix(output_feat)
                style_gram = gram_matrix(style_feat)
                style_loss += weight * mse_loss(output_gram, style_gram)
            style_loss /= len(style_weights)
            # Total loss
            loss = content_weight * content_loss + actual_style_weight * style_loss
            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            # Gradient clipping for stability
            torch.nn.utils.clip_grad_norm_(model.decoder.parameters(), max_norm=5.0)
            optimizer.step()
            epoch_loss += loss.item()
            total_steps += 1
            # Progress callback
            if progress_callback and total_steps % 10 == 0:
                progress = (epoch + (batch_idx + 1) / len(dataloader)) / epochs
                progress_callback(progress,
                                  f"Epoch {epoch+1}/{epochs}, Loss: {loss.item():.4f} "
                                  f"(Content: {content_loss.item():.4f}, Style: {style_loss.item():.4f})")
        # Step scheduler
        scheduler.step()
        # Save checkpoint
        if (epoch + 1) % save_interval == 0:
            checkpoint_path = f'{self.models_dir}/{model_name}_epoch_{epoch+1}.pth'
            torch.save({
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'loss': epoch_loss / len(dataloader),
                'model_type': 'adain'
            }, checkpoint_path)
            print(f"Saved checkpoint: {checkpoint_path}")
    # Save final model
    final_path = f'{self.models_dir}/{model_name}_final.pth'
    torch.save({
        'model_state_dict': model.state_dict(),
        'model_type': 'adain'
    }, final_path)
    print(f"Training complete! Model saved to: {final_path}")
    # Add to lightweight models
    self.lightweight_models[model_name] = model
    return model
| # Update these methods in your StyleTransferSystem class: | |
def apply_adain_style(self, content_image, style_image, model, alpha=1.0, use_tiling=False):
    """Apply AdaIN-based style transfer with optional tiling.

    Args:
        content_image: PIL image to stylize.
        style_image: PIL image supplying the style statistics.
        model: AdaIN model callable as ``model(content, style, alpha=...)``.
        alpha: style strength.
        use_tiling: process in overlapping tiles when the image exceeds 768px.

    Returns the stylized PIL image at the original size, or None on error.
    """
    # Validate inputs BEFORE touching attributes: the original accessed
    # content_image.width in the tiling branch first, which raised
    # AttributeError on None input when use_tiling was True.
    if content_image is None or style_image is None or model is None:
        return None
    # Use tiling for large images to maintain quality
    if use_tiling and (content_image.width > 768 or content_image.height > 768):
        return self.apply_adain_style_tiled(
            content_image, style_image, model, alpha,
            tile_size=512,  # Increased from 256
            overlap=64,  # Increased overlap
            blend_mode='gaussian'
        )
    try:
        model = model.to(self.device)
        model.eval()
        original_size = content_image.size
        # Work at up to 768px on the long side, preserving aspect ratio.
        max_dim = 768
        w, h = content_image.size
        if w > h:
            new_w = min(w, max_dim)
            new_h = int(h * new_w / w)
        else:
            new_h = min(h, max_dim)
            new_w = int(w * new_h / h)
        # Snap down to multiples of 8 for encoder/decoder compatibility,
        # but never below 8 so tiny or extreme-aspect images don't
        # collapse to a zero dimension.
        new_w = max(8, (new_w // 8) * 8)
        new_h = max(8, (new_h // 8) * 8)
        # Transform for AdaIN (VGG normalization)
        transform = transforms.Compose([
            transforms.Resize((new_h, new_w)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])
        content_tensor = transform(content_image).unsqueeze(0).to(self.device)
        style_tensor = transform(style_image).unsqueeze(0).to(self.device)
        with torch.no_grad():
            output = model(content_tensor, style_tensor, alpha=alpha)
        # Denormalize back to [0, 1]
        output = output.squeeze(0).cpu()
        output = output * torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
        output = output + torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
        output = torch.clamp(output, 0, 1)
        # Convert to PIL and restore the original resolution
        output_img = transforms.ToPILImage()(output)
        output_img = output_img.resize(original_size, Image.LANCZOS)
        return output_img
    except Exception as e:
        print(f"Error applying AdaIN style: {e}")
        traceback.print_exc()
        return None
def apply_adain_style_tiled(self, content_image, style_image, model, alpha=1.0,
                            tile_size=256, overlap=32, blend_mode='linear'):
    """
    Apply AdaIN style transfer using tiling for high-quality results.
    Processes image in overlapping tiles to maintain quality.

    NOTE(review): the ``tile_size`` and ``overlap`` arguments are
    unconditionally overridden to 512/64 below, and ``blend_mode`` is
    ignored (Gaussian weights are always used) — confirm whether these
    parameters should be honored or removed from the signature.
    """
    if content_image is None or style_image is None or model is None:
        return None
    try:
        model = model.to(self.device)
        model.eval()
        # INCREASED TILE SIZE FOR BETTER QUALITY
        tile_size = 512  # Override input to use 512
        overlap = 64  # Increase overlap proportionally
        # Prepare transforms
        transform = transforms.Compose([
            transforms.Resize((tile_size, tile_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])
        # Process style image once (at tile size)
        style_tensor = transform(style_image).unsqueeze(0).to(self.device)
        # Get dimensions
        w, h = content_image.size
        # Calculate tile positions with overlap
        stride = tile_size - overlap
        tiles_x = list(range(0, w - tile_size + 1, stride))
        tiles_y = list(range(0, h - tile_size + 1, stride))
        # Ensure we cover the entire image: add a final edge-aligned tile
        # when the stride grid stops short of the right/bottom border.
        if not tiles_x or tiles_x[-1] + tile_size < w:
            tiles_x.append(max(0, w - tile_size))
        if not tiles_y or tiles_y[-1] + tile_size < h:
            tiles_y.append(max(0, h - tile_size))
        # If image is smaller than tile size, just process normally
        if w <= tile_size and h <= tile_size:
            return self.apply_adain_style(content_image, style_image, model, alpha, use_tiling=False)
        print(f"Processing {len(tiles_x) * len(tiles_y)} tiles of size {tile_size}x{tile_size}")
        # Accumulators: weighted pixel sums and per-pixel total weights.
        output_array = np.zeros((h, w, 3), dtype=np.float32)
        weight_array = np.zeros((h, w, 1), dtype=np.float32)
        # Process each tile
        with torch.no_grad():
            for y_idx, y in enumerate(tiles_y):
                for x_idx, x in enumerate(tiles_x):
                    # Extract tile
                    tile = content_image.crop((x, y, x + tile_size, y + tile_size))
                    # Transform tile
                    tile_tensor = transform(tile).unsqueeze(0).to(self.device)
                    # Apply AdaIN to tile
                    styled_tensor = model(tile_tensor, style_tensor, alpha=alpha)
                    # Denormalize (undo ImageNet normalization)
                    styled_tensor = styled_tensor.squeeze(0).cpu()
                    denorm_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
                    denorm_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
                    styled_tensor = styled_tensor * denorm_std + denorm_mean
                    styled_tensor = torch.clamp(styled_tensor, 0, 1)
                    # Convert to numpy
                    styled_tile = styled_tensor.permute(1, 2, 0).numpy() * 255
                    # Create weight mask for blending - use gaussian by default for better quality
                    weight = self._create_gaussian_weight(tile_size, tile_size, overlap)
                    # Add to output with weights
                    output_array[y:y+tile_size, x:x+tile_size] += styled_tile * weight
                    weight_array[y:y+tile_size, x:x+tile_size] += weight
        # Normalize by weights (epsilon guards against division by zero)
        output_array = output_array / (weight_array + 1e-8)
        output_array = np.clip(output_array, 0, 255).astype(np.uint8)
        return Image.fromarray(output_array)
    except Exception as e:
        print(f"Error in tiled AdaIN processing: {e}")
        traceback.print_exc()
        # Fallback to standard processing
        return self.apply_adain_style(content_image, style_image, model, alpha, use_tiling=False)
| # =========================== | |
| # HELPER FUNCTIONS | |
| # =========================== | |
def resize_image_for_display(image, max_width=800, max_height=600):
    """Scale an image down to fit within max_width x max_height, keeping aspect ratio.

    Images that already fit inside the bounds are returned unchanged — this
    helper never upscales.
    """
    width, height = image.size
    # Take the tighter of the two constraints so both dimensions fit.
    scale = min(max_width / width, max_height / height)
    if scale >= 1:
        # Already small enough: hand back the original object untouched.
        return image
    target = (int(width * scale), int(height * scale))
    return image.resize(target, Image.LANCZOS)
def combine_region_masks(canvas_results, canvas_size):
    """Merge per-region canvas masks into a single label mask.

    Every pixel painted on a region's canvas is tagged with that region's
    1-based index; unpainted pixels stay 0. Where strokes overlap, later
    regions overwrite earlier ones.
    """
    labels = np.zeros(canvas_size[:2], dtype=np.uint8)
    for region_idx, canvas_data in enumerate(canvas_results):
        usable = (
            canvas_data is not None
            and hasattr(canvas_data, 'image_data')
            and canvas_data.image_data is not None
        )
        if usable:
            # Any non-zero alpha counts as "painted" for this region.
            painted = canvas_data.image_data[:, :, 3] > 0
            labels[painted] = region_idx + 1
    return labels
def apply_adain_regional(content_image, style_image, model, canvas_result, alpha=1.0, feather_radius=10, use_tiling=False):
    """Run AdaIN style transfer restricted to the region painted on the canvas.

    Falls back to a whole-image transfer when nothing was painted. Returns a
    PIL image, or None when inputs are missing or an error occurs.
    """
    if content_image is None or style_image is None or model is None:
        return None
    try:
        no_mask = canvas_result is None or canvas_result.image_data is None
        if no_mask:
            # Nothing painted: style the entire image instead.
            return system.apply_adain_style(content_image, style_image, model, alpha, use_tiling=use_tiling)
        # The canvas alpha channel marks which pixels were painted.
        painted = canvas_result.image_data[:, :, 3] > 0
        # The canvas runs at display resolution; scale the mask up to the
        # content image's native size when they differ.
        canvas_size = (canvas_result.image_data.shape[1], canvas_result.image_data.shape[0])
        target_size = content_image.size
        if target_size != canvas_size:
            mask_img = Image.fromarray((painted * 255).astype(np.uint8), mode='L')
            mask_img = mask_img.resize(target_size, Image.NEAREST)
            painted = np.array(mask_img) > 128
        if feather_radius > 0:
            # Soften the mask edge so the styled region blends into the original.
            from scipy.ndimage import gaussian_filter
            blend_weights = gaussian_filter(painted.astype(np.float32), sigma=feather_radius)
            blend_weights = np.clip(blend_weights, 0, 1)
        else:
            blend_weights = painted.astype(np.float32)
        # Style the full frame once, then composite it through the mask.
        styled_full = system.apply_adain_style(content_image, style_image, model, alpha, use_tiling=use_tiling)
        if styled_full is None:
            return None
        base = np.array(content_image, dtype=np.float32)
        styled = np.array(styled_full, dtype=np.float32)
        # Per-pixel linear blend between original and styled content.
        weights_rgb = np.stack([blend_weights] * 3, axis=2)
        blended = base * (1 - weights_rgb) + styled * weights_rgb
        blended = np.clip(blended, 0, 255).astype(np.uint8)
        return Image.fromarray(blended)
    except Exception as e:
        print(f"Error applying regional AdaIN style: {e}")
        traceback.print_exc()
        return None
| # =========================== | |
| # INITIALIZE SYSTEM AND API | |
| # =========================== | |
def load_system():
    """Build and return the application's StyleTransferSystem backend."""
    backend = StyleTransferSystem()
    return backend
def get_unsplash_api():
    """Build and return the Unsplash API client."""
    client = UnsplashAPI()
    return client
# Shared backend objects used by every tab below.
# NOTE(review): these are constructed on every Streamlit rerun since
# load_system()/get_unsplash_api() are not cached — consider st.cache_resource
# if construction turns out to be expensive; confirm against app behavior.
system = load_system()
unsplash = get_unsplash_api()
# Human-readable style names for the UI dropdowns, sorted alphabetically.
style_choices = sorted([info['name'] for info in system.cyclegan_models.values()])
| # =========================== | |
| # STREAMLIT APP | |
| # =========================== | |
# Main app: page title, global sidebar, and the six top-level tabs.
st.title("Style Transfer")
st.markdown("Image and video style transfer with CycleGAN and custom training capabilities")
# Sidebar for global settings: GPU/CPU status, quick guide, API status, debug.
with st.sidebar:
    st.header("Settings")
    # GPU status: show device name and a live memory readout when CUDA exists.
    if torch.cuda.is_available():
        gpu_info = torch.cuda.get_device_properties(0)
        st.success(f"GPU: {gpu_info.name}")
        # Show memory usage (values converted from bytes to GB).
        total_memory = gpu_info.total_memory / 1e9
        used_memory = torch.cuda.memory_allocated() / 1e9
        free_memory = total_memory - used_memory
        col1, col2 = st.columns(2)
        with col1:
            st.metric("Total Memory", f"{total_memory:.2f} GB")
        with col2:
            st.metric("Used Memory", f"{used_memory:.2f} GB")
        # Memory usage bar (st.progress expects a 0..1 fraction).
        memory_percentage = (used_memory / total_memory) * 100
        st.progress(memory_percentage / 100)
        st.caption(f"Free: {free_memory:.2f} GB ({100-memory_percentage:.1f}%)")
        # Check if GPU is actually being used; near-zero allocation suggests
        # the models landed on CPU despite CUDA being available.
        if used_memory < 0.1:
            st.warning("GPU detected but not in use. Models may be running on CPU.")
        # Force GPU button: drop cached allocations and rerun the script.
        if st.button("Force GPU Reset"):
            torch.cuda.empty_cache()
            torch.cuda.synchronize()
            st.rerun()
    else:
        st.warning("Running on CPU (GPU not available)")
        st.caption("For faster processing, use a GPU-enabled environment")
    st.markdown("---")
    st.markdown("### Quick Guide")
    st.markdown("""
    - **Style Transfer**: Apply artistic styles to images
    - **Regional Transform**: Paint areas for local effects
    - **Video Processing**: Apply styles to videos
    - **Train Custom**: Create your own style models
    - **Batch Process**: Process multiple images
    """)
    # Unsplash API status indicator.
    st.markdown("---")
    if unsplash.access_key:
        st.success("Unsplash API Connected")
    else:
        st.info("Add Unsplash API key for image search")
    # Debug mode: dump environment/runtime details.
    # NOTE(review): `device` is assumed to be defined earlier in this file
    # (not visible in this chunk) — confirm it exists at module level.
    st.markdown("---")
    if st.checkbox("🐛 Debug Mode"):
        st.code(f"""
        Device: {device}
        CUDA Available: {torch.cuda.is_available()}
        CUDA Version: {torch.version.cuda if torch.cuda.is_available() else 'N/A'}
        PyTorch Version: {torch.__version__}
        Models Loaded: {len(system.loaded_generators)}
        """)
# Main tabs — each is populated by its own `with tabN:` block below.
tab1, tab2, tab3, tab4, tab5, tab6 = st.tabs([
    "Style Transfer",
    "Regional Transform",
    "Video Processing",
    "Train Custom Style",
    "Batch Processing",
    "Documentation"
])
# TAB 1: Style Transfer (with Unsplash integration)
with tab1:
    # Unsplash Search Section: lets the user pick a content image from Unsplash.
    with st.expander("Search Unsplash for Images", expanded=False):
        if not unsplash.access_key:
            st.info("""
            To enable Unsplash search:
            1. Get a free API key from [Unsplash Developers](https://unsplash.com/developers)
            2. Add it to your HuggingFace Space secrets as `UNSPLASH_ACCESS_KEY`
            """)
        else:
            search_col1, search_col2, search_col3 = st.columns([3, 1, 1])
            with search_col1:
                search_query = st.text_input("Search for images", placeholder="e.g., landscape, portrait, abstract art")
            with search_col2:
                orientation = st.selectbox("Orientation", ["all", "landscape", "portrait", "squarish"])
            with search_col3:
                search_button = st.button("Search", use_container_width=True)
            # Random photos button: fills the results grid without a query.
            if st.button("Get Random Photos"):
                with st.spinner("Loading random photos..."):
                    results, error = unsplash.get_random_photos(count=12)
                    if error:
                        st.error(f"Error: {error}")
                    elif results:
                        # Handle both single photo and array of photos
                        photos = results if isinstance(results, list) else [results]
                        st.session_state['unsplash_results'] = photos
                        st.success(f"Loaded {len(photos)} random photos")
            # Search functionality: runs when the Search button was clicked.
            if search_button and search_query:
                with st.spinner(f"Searching for '{search_query}'..."):
                    # "all" means no orientation filter for the API.
                    orientation_param = None if orientation == "all" else orientation
                    results, error = unsplash.search_photos(search_query, per_page=12, orientation=orientation_param)
                    if error:
                        st.error(f"Error: {error}")
                    elif results and results.get('results'):
                        st.session_state['unsplash_results'] = results['results']
                        st.success(f"Found {results['total']} images")
                    else:
                        st.info("No images found. Try a different search term.")
            # Display results (persisted in session state across reruns).
            if 'unsplash_results' in st.session_state and st.session_state['unsplash_results']:
                st.markdown("### Search Results")
                # Display in a 4-column grid, capped at 12 photos.
                cols = st.columns(4)
                for idx, photo in enumerate(st.session_state['unsplash_results'][:12]):
                    with cols[idx % 4]:
                        # Show thumbnail
                        st.image(photo['urls']['thumb'], use_column_width=True)
                        # Photo info / attribution.
                        st.caption(f"By {photo['user']['name']}")
                        # Use button: adopt this photo as the working image.
                        if st.button("Use This", key=f"use_unsplash_{photo['id']}"):
                            with st.spinner("Loading image..."):
                                # Download regular size
                                img = unsplash.download_photo(photo['urls']['regular'])
                                if img:
                                    # Store in session state so other tabs can reuse it.
                                    st.session_state['current_image'] = img
                                    st.session_state['image_source'] = f"Unsplash: {photo['user']['name']}"
                                    st.session_state['unsplash_photo'] = photo
                                    # Trigger download tracking (required by Unsplash API guidelines).
                                    if 'links' in photo and 'download_location' in photo['links']:
                                        unsplash.trigger_download(photo['links']['download_location'])
                                    st.success("Image loaded!")
                                    st.rerun()
    col1, col2 = st.columns(2)
    with col1:
        st.header("Input")
        # Image source selection
        image_source = st.radio("Image Source", ["Upload", "Unsplash"], horizontal=True)
        # Initialize input_image to None
        input_image = None
        if image_source == "Upload":
            uploaded_file = st.file_uploader("Choose an image", type=['png', 'jpg', 'jpeg'])
            if uploaded_file:
                input_image = Image.open(uploaded_file).convert('RGB')
                st.session_state['current_image'] = input_image
                st.session_state['image_source'] = "Uploaded"
        else:
            # Handle Unsplash selection: reuse the image chosen in the expander above.
            if 'current_image' in st.session_state and st.session_state.get('image_source', '').startswith('Unsplash'):
                input_image = st.session_state['current_image']
            else:
                st.info("Search for an image above")
        if input_image:
            # Display the image (downscaled for the UI only).
            display_img = resize_image_for_display(input_image, max_width=600, max_height=400)
            st.image(display_img, caption=st.session_state.get('image_source', 'Image'), use_column_width=True)
            # Attribution for Unsplash images (required by Unsplash guidelines).
            if 'unsplash_photo' in st.session_state and st.session_state.get('image_source', '').startswith('Unsplash'):
                photo = st.session_state['unsplash_photo']
                st.markdown(f"Photo by [{photo['user']['name']}]({photo['user']['links']['html']}) on [Unsplash]({photo['links']['html']})")
        st.subheader("Style Configuration")
        # Up to 3 styles can be layered in one pass.
        num_styles = st.number_input("Number of styles to apply", 1, 3, 1)
        style_configs = []
        for i in range(num_styles):
            with st.expander(f"Style {i+1}", expanded=(i==0)):
                style = st.selectbox(f"Select style", style_choices, key=f"style_{i}")
                intensity = st.slider(f"Intensity", 0.0, 2.0, 1.0, 0.1, key=f"intensity_{i}")
                if style and intensity > 0:
                    # Map the display name back to the model registry key.
                    model_key = None
                    for key, info in system.cyclegan_models.items():
                        if info['name'] == style:
                            model_key = key
                            break
                    if model_key:
                        style_configs.append(('cyclegan', model_key, intensity))
        blend_mode = st.selectbox("Blend Mode",
                                  ["additive", "average", "maximum", "overlay", "screen"],
                                  index=0)
        if st.button("Apply Styles", type="primary", use_container_width=True):
            if style_configs:
                with st.spinner("Applying styles..."):
                    progress_bar = st.progress(0)
                    status_text = st.empty()
                    # Progress loop is informational only; the actual styling
                    # happens in the single blend_styles() call below.
                    for i, (_, key, intensity) in enumerate(style_configs):
                        model_name = system.cyclegan_models[key]['name']
                        progress = (i + 1) / len(style_configs)
                        progress_bar.progress(progress)
                        status_text.text(f"Applying {model_name}...")
                    result = system.blend_styles(input_image, style_configs, blend_mode)
                    st.session_state['last_result'] = result
                    st.session_state['last_style_configs'] = style_configs
                    progress_bar.empty()
                    status_text.empty()
    with col2:
        st.header("Result")
        if 'last_result' in st.session_state:
            st.image(st.session_state['last_result'], caption="Styled Image", use_column_width=True)
            # Download button: serialize the result to PNG bytes in memory.
            buf = io.BytesIO()
            st.session_state['last_result'].save(buf, format='PNG')
            st.download_button(
                label="Download Result",
                data=buf.getvalue(),
                file_name=f"styled_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.png",
                mime="image/png"
            )
# TAB 2: Regional Transform — paint masks on a canvas to style image regions
# independently. State machine lives entirely in st.session_state.
with tab2:
    st.header("Regional Style Transform")
    st.markdown("Paint different regions to apply different styles locally")
    # Initialize session state keys up-front so later code can read them freely.
    if 'regions' not in st.session_state:
        st.session_state.regions = []
    if 'canvas_results' not in st.session_state:
        st.session_state.canvas_results = {}
    if 'regional_image_original' not in st.session_state:
        st.session_state.regional_image_original = None
    if 'canvas_ready' not in st.session_state:
        st.session_state.canvas_ready = True
    if 'last_applied_regions' not in st.session_state:
        st.session_state.last_applied_regions = None
    if 'canvas_key_base' not in st.session_state:
        st.session_state.canvas_key_base = 0
    col1, col2 = st.columns([2, 3])
    # Define variables at the top level of tab2 — defaults used when the
    # widgets below are not rendered (e.g. no image loaded yet).
    use_base = False
    base_style = None
    base_intensity = 1.0
    regional_blend_mode = "additive"
    with col1:
        # Image source selection
        regional_image_source = st.radio("Image Source", ["Upload", "Unsplash"], horizontal=True, key="regional_image_source")
        if regional_image_source == "Upload":
            uploaded_regional = st.file_uploader("Choose an image", type=['png', 'jpg', 'jpeg'], key="regional_upload")
            if uploaded_regional:
                # Load and store original image at full resolution.
                regional_image_original = Image.open(uploaded_regional).convert('RGB')
                st.session_state.regional_image_original = regional_image_original
        else:
            # Use Unsplash image if available (chosen in the Style Transfer tab).
            if 'current_image' in st.session_state and st.session_state.get('image_source', '').startswith('Unsplash'):
                st.session_state.regional_image_original = st.session_state['current_image']
                st.success("Using Unsplash image")
            else:
                st.info("Please search and select an image from the Style Transfer tab first")
        if st.session_state.regional_image_original:
            # Display the original image (downscaled for the UI only).
            display_img = resize_image_for_display(st.session_state.regional_image_original, max_width=400, max_height=300)
            st.image(display_img, caption="Original Image", use_column_width=True)
            st.subheader("Define Regions")
            # Base style (optional): applied to the whole image under the regions.
            with st.expander("Base Style (Optional)", expanded=False):
                use_base = st.checkbox("Apply base style to entire image")
                if use_base:
                    base_style = st.selectbox("Base style", style_choices, key="base_style")
                    base_intensity = st.slider("Base intensity", 0.0, 2.0, 1.0, key="base_intensity")
            # Region management buttons.
            col_btn1, col_btn2, col_btn3 = st.columns(3)
            with col_btn1:
                if st.button("Add Region", use_container_width=True):
                    # Each region gets a distinct hue (60° apart) for its overlay.
                    new_region = {
                        'id': len(st.session_state.regions),
                        'style': style_choices[0] if style_choices else None,
                        'intensity': 1.0,
                        'color': f"hsla({len(st.session_state.regions) * 60}, 70%, 50%, 0.5)"
                    }
                    st.session_state.regions.append(new_region)
                    st.session_state.canvas_ready = True
                    st.rerun()
            with col_btn2:
                if st.button("Clear All", use_container_width=True):
                    st.session_state.regions = []
                    st.session_state.canvas_results = {}
                    if 'regional_result' in st.session_state:
                        del st.session_state['regional_result']
                    st.session_state.canvas_ready = True
                    st.session_state.canvas_key_base = 0
                    st.rerun()
            with col_btn3:
                if st.button("Reset Result", use_container_width=True):
                    if 'regional_result' in st.session_state:
                        del st.session_state['regional_result']
                    st.session_state.canvas_ready = True
                    st.rerun()
            # Configure each region: style, intensity, and a remove button.
            for i, region in enumerate(st.session_state.regions):
                with st.expander(f"Region {i+1} - {region.get('style', 'None')}", expanded=(i == len(st.session_state.regions) - 1)):
                    col_a, col_b = st.columns(2)
                    with col_a:
                        new_style = st.selectbox(
                            "Style",
                            style_choices,
                            key=f"region_style_{i}",
                            index=style_choices.index(region['style']) if region['style'] in style_choices else 0
                        )
                        region['style'] = new_style
                    with col_b:
                        region['intensity'] = st.slider(
                            "Intensity",
                            0.0, 2.0,
                            region.get('intensity', 1.0),
                            key=f"region_intensity_{i}"
                        )
                    if st.button(f"Remove Region {i+1}", key=f"remove_region_{i}"):
                        # Remove the region
                        st.session_state.regions.pop(i)
                        # Rebuild canvas results so stored masks keep matching
                        # their region index after the removal.
                        old_canvas_results = st.session_state.canvas_results.copy()
                        st.session_state.canvas_results = {}
                        for old_idx, result in old_canvas_results.items():
                            if old_idx < i:
                                # Keep results before removed index
                                st.session_state.canvas_results[old_idx] = result
                            elif old_idx > i:
                                # Shift results after removed index down by 1
                                st.session_state.canvas_results[old_idx - 1] = result
                        st.session_state.canvas_ready = True
                        st.session_state.canvas_key_base += 1
                        st.rerun()
            # Blend mode for combining region styles.
            regional_blend_mode = st.selectbox("Blend Mode",
                                               ["additive", "average", "maximum", "overlay", "screen"],
                                               index=0, key="regional_blend")
    with col2:
        if st.session_state.regions and st.session_state.regional_image_original:
            st.subheader("Paint Regions")
            # Show workflow status depending on edit/preview state.
            if 'regional_result' in st.session_state:
                if st.session_state.canvas_ready:
                    st.success("Edit Mode - Paint your regions and click 'Apply Regional Styles' when ready")
                else:
                    st.info("Preview Mode - Click 'Continue Editing' to modify regions")
            else:
                st.info("Paint on the canvas below to define regions for each style")
            # Check if we're in edit mode; if not, surface the current result above the canvas.
            if not st.session_state.canvas_ready:
                # Show a preview of the painted regions
                if 'regional_result' in st.session_state:
                    st.subheader("Current Result")
                    result_display = resize_image_for_display(st.session_state['regional_result'], max_width=600, max_height=400)
                    st.image(result_display, caption="Applied Styles", use_column_width=True)
            # Create display image — the canvas operates at this reduced size;
            # processing later uses the full-resolution original.
            display_image = resize_image_for_display(st.session_state.regional_image_original, max_width=600, max_height=400)
            display_width, display_height = display_image.size
            # Info message
            st.info(f"Image resized to {display_width}x{display_height} for display. Original resolution will be used for processing.")
            # Get current region to paint.
            current_region_idx = st.selectbox(
                "Select region to paint",
                range(len(st.session_state.regions)),
                format_func=lambda x: f"Region {x+1}: {st.session_state.regions[x].get('style', 'None')}"
            )
            current_region = st.session_state.regions[current_region_idx]
            col_draw1, col_draw2, col_draw3 = st.columns(3)
            with col_draw1:
                brush_size = st.slider("Brush Size", 1, 50, 15)
            with col_draw2:
                drawing_mode = st.selectbox("Tool", ["freedraw", "line", "rect", "circle"])
            with col_draw3:
                if st.button("Clear This Region"):
                    if current_region_idx in st.session_state.canvas_results:
                        del st.session_state.canvas_results[current_region_idx]
                    st.session_state.canvas_ready = True
                    st.rerun()
            # Create combined background with all previous regions painted in.
            background_with_regions = display_image.copy()
            draw = ImageDraw.Draw(background_with_regions, 'RGBA')
            # Draw all regions on the background
            for i, region in enumerate(st.session_state.regions):
                if i in st.session_state.canvas_results:
                    canvas_data = st.session_state.canvas_results[i]
                    if canvas_data is not None and hasattr(canvas_data, 'image_data') and canvas_data.image_data is not None:
                        # Extract mask from canvas data (alpha channel > 0).
                        mask = canvas_data.image_data[:, :, 3] > 0
                        # Create colored overlay for this region
                        # Parse HSLA color more carefully
                        color_str = region['color'].replace('hsla(', '').replace(')', '')
                        color_parts = color_str.split(',')
                        hue = int(color_parts[0])
                        # Convert HSL to RGB (simplified - assumes 70% saturation, 50% lightness)
                        r, g, b = colorsys.hls_to_rgb(hue/360, 0.5, 0.7)
                        color = (int(r*255), int(g*255), int(b*255))
                        # Highlight the currently selected region more strongly.
                        opacity = 128 if i != current_region_idx else 200
                        # Draw mask on background
                        # NOTE(review): per-pixel draw.point loop is O(W*H) per
                        # region — a vectorized paste would be much faster.
                        for y in range(mask.shape[0]):
                            for x in range(mask.shape[1]):
                                if mask[y, x]:
                                    draw.point((x, y), fill=color + (opacity,))
            # Canvas for current region: brighter stroke than the overlay fill.
            stroke_color = current_region['color'].replace('0.5)', '0.8)')
            # Get initial drawing for current region so existing strokes reload.
            initial_drawing = None
            if current_region_idx in st.session_state.canvas_results:
                canvas_data = st.session_state.canvas_results[current_region_idx]
                if canvas_data is not None and hasattr(canvas_data, 'json_data'):
                    initial_drawing = canvas_data.json_data
            # The key includes region/brush/tool so changing any of them
            # remounts the canvas widget.
            canvas_result = st_canvas(
                fill_color=stroke_color,
                stroke_width=brush_size,
                stroke_color=stroke_color,
                background_image=background_with_regions,
                update_streamlit=True,
                height=display_height,
                width=display_width,
                drawing_mode=drawing_mode,
                display_toolbar=True,
                initial_drawing=initial_drawing,
                key=f"regional_canvas_{current_region_idx}_{brush_size}_{drawing_mode}"
            )
            # Save canvas result so other regions' overlays can use it.
            if canvas_result:
                st.session_state.canvas_results[current_region_idx] = canvas_result
            # Apply button
            if st.button("Apply Regional Styles", type="primary", use_container_width=True):
                with st.spinner("Applying regional styles..."):
                    # Create combined mask from all canvas results (one label per region).
                    combined_mask = combine_region_masks(
                        [st.session_state.canvas_results.get(i) for i in range(len(st.session_state.regions))],
                        (display_height, display_width)
                    )
                    # Prepare base style configs if enabled.
                    base_configs = None
                    if use_base and base_style:
                        base_key = None
                        for key, info in system.cyclegan_models.items():
                            if info['name'] == base_style:
                                base_key = key
                                break
                        if base_key:
                            base_configs = [('cyclegan', base_key, base_intensity)]
                    # Apply regional styles using the original full-resolution image.
                    result = system.apply_regional_styles(
                        st.session_state.regional_image_original,  # Use original resolution
                        combined_mask,
                        st.session_state.regions,
                        base_configs,
                        regional_blend_mode
                    )
                    st.session_state['regional_result'] = result
        # Show result with adjustable display size.
        if 'regional_result' in st.session_state:
            st.subheader("Result")
            # Add display size control
            display_size = st.slider("Display Size", 300, 800, 600, 50, key="regional_display_size")
            # Fixed size display
            result_display = resize_image_for_display(
                st.session_state['regional_result'],
                max_width=display_size,
                max_height=display_size
            )
            st.image(result_display, caption="Regional Styled Image")
            # Show actual dimensions
            st.caption(f"Original size: {st.session_state['regional_result'].size[0]}x{st.session_state['regional_result'].size[1]} pixels")
            # Download button: serialize to PNG bytes in memory.
            buf = io.BytesIO()
            st.session_state['regional_result'].save(buf, format='PNG')
            st.download_button(
                label="Download Result",
                data=buf.getvalue(),
                file_name=f"regional_styled_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.png",
                mime="image/png"
            )
# TAB 3: Video Processing — apply up to two styles to an uploaded video.
# Requires OpenCV (guarded by VIDEO_PROCESSING_AVAILABLE from the import probe).
with tab3:
    st.header("Video Processing")
    if not VIDEO_PROCESSING_AVAILABLE:
        st.warning("""
        Video processing requires OpenCV to be installed.
        To enable video processing, add `opencv-python` to your requirements.txt
        """)
    else:
        col1, col2 = st.columns(2)
        with col1:
            video_file = st.file_uploader("Upload Video", type=['mp4', 'avi', 'mov'])
            if video_file:
                st.video(video_file)
                st.subheader("Style Configuration")
                # Style selection (up to 2 for videos)
                video_styles = []
                for i in range(2):
                    with st.expander(f"Style {i+1}", expanded=(i==0)):
                        style = st.selectbox(f"Select style", style_choices, key=f"video_style_{i}")
                        intensity = st.slider(f"Intensity", 0.0, 2.0, 1.0, 0.1, key=f"video_intensity_{i}")
                        if style and intensity > 0:
                            # Map the display name back to the model registry key.
                            model_key = None
                            for key, info in system.cyclegan_models.items():
                                if info['name'] == style:
                                    model_key = key
                                    break
                            if model_key:
                                video_styles.append(('cyclegan', model_key, intensity))
                video_blend_mode = st.selectbox("Blend Mode",
                                                ["additive", "average", "maximum", "overlay", "screen"],
                                                index=0, key="video_blend")
                if st.button("Process Video", type="primary", use_container_width=True):
                    if video_styles:
                        with st.spinner("Processing video..."):
                            progress_bar = st.progress(0)
                            status_text = st.empty()
                            # Callback handed to the processor so it can report progress.
                            def progress_callback(p, msg):
                                progress_bar.progress(p)
                                status_text.text(msg)
                            # Save uploaded file temporarily (processor works on paths).
                            temp_input = tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(video_file.name)[1])
                            temp_input.write(video_file.read())
                            temp_input.close()
                            # Process video
                            output_path = system.video_processor.process_video(
                                temp_input.name, video_styles, video_blend_mode, progress_callback
                            )
                            if output_path and os.path.exists(output_path):
                                # Read the video file immediately so the temp file
                                # can be deleted; bytes live in session state.
                                try:
                                    with open(output_path, 'rb') as f:
                                        video_bytes = f.read()
                                    # Determine file extension
                                    file_ext = os.path.splitext(output_path)[1].lower()
                                    # Store in session state
                                    st.session_state['video_result_bytes'] = video_bytes
                                    st.session_state['video_result_ext'] = file_ext
                                    st.session_state['video_result_available'] = True
                                    st.session_state['video_is_mp4'] = (file_ext == '.mp4')
                                    st.success(f"Video processing complete! Format: {file_ext.upper()}")
                                    # Clean up files (best-effort).
                                    try:
                                        os.unlink(output_path)
                                    except:
                                        pass
                                except Exception as e:
                                    st.error(f"Failed to read processed video: {str(e)}")
                                    st.session_state['video_result_available'] = False
                            else:
                                st.error("Failed to process video. Please try a different video or reduce the resolution.")
                                st.session_state['video_result_available'] = False
                            # Cleanup input file (best-effort).
                            try:
                                os.unlink(temp_input.name)
                            except:
                                pass
                            progress_bar.empty()
                            status_text.empty()
                    else:
                        st.warning("Please select at least one style")
        with col2:
            st.header("Result")
            if st.session_state.get('video_result_available', False) and 'video_result_bytes' in st.session_state:
                try:
                    file_ext = st.session_state.get('video_result_ext', '.mp4')
                    video_bytes = st.session_state['video_result_bytes']
                    # File info
                    file_size_mb = len(video_bytes) / (1024 * 1024)
                    # Try to display video inline; only MP4 reliably plays in browsers.
                    if st.session_state.get('video_is_mp4', False):
                        # For MP4, should work in browser
                        st.video(video_bytes)
                        st.success(f"Video ready! Size: {file_size_mb:.2f} MB")
                    else:
                        # For non-MP4, show info and download
                        st.info(f"Video format ({file_ext}) may not play in browser. Please download to view.")
                        # Show preview image if possible (first frame via OpenCV).
                        try:
                            # Try to extract a frame for preview
                            temp_preview = tempfile.NamedTemporaryFile(delete=False, suffix=file_ext)
                            temp_preview.write(video_bytes)
                            temp_preview.close()
                            cap = cv2.VideoCapture(temp_preview.name)
                            ret, frame = cap.read()
                            if ret:
                                # Convert frame to RGB and display (OpenCV reads BGR).
                                rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                                st.image(rgb_frame, caption="Video Preview (First Frame)", use_column_width=True)
                            cap.release()
                            os.unlink(temp_preview.name)
                        except:
                            pass
                    # Always provide download button regardless of inline playback.
                    mime_types = {
                        '.mp4': 'video/mp4',
                        '.avi': 'video/x-msvideo',
                        '.mov': 'video/quicktime'
                    }
                    mime_type = mime_types.get(file_ext, 'application/octet-stream')
                    col_dl1, col_dl2 = st.columns(2)
                    with col_dl1:
                        st.download_button(
                            label=f"Download Video ({file_ext.upper()})",
                            data=video_bytes,
                            file_name=f"styled_video_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}{file_ext}",
                            mime=mime_type,
                            use_container_width=True
                        )
                    with col_dl2:
                        if st.button("Clear Result", use_container_width=True):
                            # Drop all video-result keys from session state.
                            del st.session_state['video_result_bytes']
                            st.session_state['video_result_available'] = False
                            if 'video_result_ext' in st.session_state:
                                del st.session_state['video_result_ext']
                            if 'video_is_mp4' in st.session_state:
                                del st.session_state['video_is_mp4']
                            st.rerun()
                    # Info about playback
                    if not st.session_state.get('video_is_mp4', False):
                        st.caption("For best compatibility, download and use VLC or another video player.")
                except Exception as e:
                    st.error(f"Error displaying video: {str(e)}")
                    # Emergency download button so the result is still retrievable.
                    if 'video_result_bytes' in st.session_state:
                        st.download_button(
                            label="📥 Download Video (Error occurred)",
                            data=st.session_state['video_result_bytes'],
                            file_name=f"styled_video_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.mp4",
                            mime="application/octet-stream"
                        )
            elif st.session_state.get('video_result_available', False):
                # Flag says available but the bytes are gone — inconsistent state.
                st.warning("Video data not found. Please process the video again.")
                if st.button("Clear State"):
                    st.session_state['video_result_available'] = False
                    st.rerun()
# TAB 4: Training with AdaIN and Regional Application
| with tab4: | |
| st.header("Train Custom Style with AdaIN") | |
| st.markdown("Train your own style transfer model using Adaptive Instance Normalization") | |
| # Initialize session state for content images | |
| if 'content_images_list' not in st.session_state: | |
| st.session_state.content_images_list = [] | |
| if 'adain_canvas_result' not in st.session_state: | |
| st.session_state.adain_canvas_result = None | |
| if 'adain_test_image' not in st.session_state: | |
| st.session_state.adain_test_image = None | |
| col1, col2, col3 = st.columns([1, 1, 1]) | |
| with col1: | |
| st.subheader("Style Images") | |
| style_imgs = st.file_uploader("Upload 1-5 style images", type=['png', 'jpg', 'jpeg'], | |
| accept_multiple_files=True, key="train_style_adain") | |
| if style_imgs: | |
| st.markdown(f"**{len(style_imgs)} style image(s) uploaded**") | |
| # Display style images in a grid | |
| style_cols = st.columns(min(len(style_imgs), 3)) | |
| for idx, style_img in enumerate(style_imgs[:3]): | |
| with style_cols[idx % 3]: | |
| img = Image.open(style_img).convert('RGB') | |
| st.image(img, caption=f"Style {idx+1}", use_column_width=True) | |
| if len(style_imgs) > 3: | |
| st.caption(f"... and {len(style_imgs) - 3} more") | |
| with col2: | |
| st.subheader("Content Images") | |
| content_imgs = st.file_uploader("Upload content images (10-50 recommended)", | |
| type=['png', 'jpg', 'jpeg'], | |
| accept_multiple_files=True, | |
| key="train_content_adain") | |
| if content_imgs: | |
| st.markdown(f"**{len(content_imgs)} content image(s) uploaded**") | |
| # Store content images in session state for later use | |
| st.session_state.content_images_list = content_imgs | |
| # Display content images in a grid | |
| content_cols = st.columns(min(len(content_imgs), 3)) | |
| for idx, content_img in enumerate(content_imgs[:3]): | |
| with content_cols[idx % 3]: | |
| img = Image.open(content_img).convert('RGB') | |
| st.image(img, caption=f"Content {idx+1}", use_column_width=True) | |
| if len(content_imgs) > 3: | |
| st.caption(f"... and {len(content_imgs) - 3} more") | |
| with col3: | |
| st.subheader("Training Settings") | |
| model_name = st.text_input("Model Name", | |
| value=f"adain_style_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}") | |
| # IMPROVED DEFAULT VALUES | |
| epochs = st.slider("Training Epochs", 10, 100, 50, 5) # Increased default | |
| batch_size = st.slider("Batch Size", 1, 8, 4) | |
| learning_rate = st.number_input("Learning Rate", 0.00001, 0.001, 0.0001, format="%.5f") | |
| with st.expander("Advanced Settings"): | |
| # MUCH HIGHER STYLE WEIGHT BY DEFAULT | |
| style_weight = st.number_input("Style Weight", 1.0, 1000.0, 100.0, 10.0) | |
| content_weight = st.number_input("Content Weight", 0.1, 10.0, 1.0, 0.1) | |
| save_interval = st.slider("Save Checkpoint Every N Epochs", 5, 20, 10, 5) | |
| st.info("💡 **Pro tip**: For better quality, use Style Weight 100-500x higher than Content Weight") | |
| st.markdown("---") | |
| # Training button | |
| if st.button("Start AdaIN Training", type="primary", use_container_width=True): | |
| if style_imgs and content_imgs: | |
| if len(content_imgs) < 10: | |
| st.warning("For best results, use at least 10 content images") | |
| with st.spinner("Training AdaIN model..."): | |
| progress_bar = st.progress(0) | |
| status_text = st.empty() | |
| def progress_callback(p, msg): | |
| progress_bar.progress(p) | |
| status_text.text(msg) | |
| # Create temp directory for content images | |
| temp_content_dir = f'/tmp/content_images_{uuid.uuid4().hex}' | |
| os.makedirs(temp_content_dir, exist_ok=True) | |
| # Save content images | |
| for idx, img_file in enumerate(content_imgs): | |
| img = Image.open(img_file).convert('RGB') | |
| img.save(os.path.join(temp_content_dir, f'content_{idx}.jpg')) | |
| # Load style images | |
| style_images = [] | |
| for style_file in style_imgs: | |
| style_img = Image.open(style_file).convert('RGB') | |
| style_images.append(style_img) | |
| # IMPROVED TRAINING FUNCTION | |
| # Multi-layer VGG loss for better quality | |
| class MultiLayerVGG(nn.Module): | |
| def __init__(self): | |
| super().__init__() | |
| vgg = models.vgg19(weights=models.VGG19_Weights.DEFAULT).features | |
| self.slice1 = nn.Sequential(*list(vgg.children())[:2]) # relu1_1 | |
| self.slice2 = nn.Sequential(*list(vgg.children())[2:7]) # relu2_1 | |
| self.slice3 = nn.Sequential(*list(vgg.children())[7:12]) # relu3_1 | |
| self.slice4 = nn.Sequential(*list(vgg.children())[12:21]) # relu4_1 | |
| for param in self.parameters(): | |
| param.requires_grad = False | |
| def forward(self, x): | |
| h1 = self.slice1(x) | |
| h2 = self.slice2(h1) | |
| h3 = self.slice3(h2) | |
| h4 = self.slice4(h3) | |
| return [h1, h2, h3, h4] | |
| # Create model | |
| model = AdaINStyleTransfer().to(system.device) | |
| optimizer = torch.optim.Adam(model.decoder.parameters(), lr=learning_rate) | |
| scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.8) | |
| print(f"Training AdaIN model at 512x512 resolution") | |
| print(f"Training device: {system.device}") | |
| # Prepare style images - LARGER SIZE | |
| style_transform = transforms.Compose([ | |
| transforms.Resize(600), # Increased size | |
| transforms.RandomCrop(512), # Larger crops | |
| transforms.RandomHorizontalFlip(p=0.5), | |
| transforms.ToTensor(), | |
| transforms.Normalize(mean=[0.485, 0.456, 0.406], | |
| std=[0.229, 0.224, 0.225]) | |
| ]) | |
| style_tensors = [] | |
| # Create multiple augmented versions | |
| for style_img in style_images: | |
| for _ in range(5): # 5 augmented versions per style | |
| style_tensor = style_transform(style_img).unsqueeze(0).to(system.device) | |
| style_tensors.append(style_tensor) | |
| # Prepare content dataset - LARGER SIZE | |
| content_transform = transforms.Compose([ | |
| transforms.Resize(600), | |
| transforms.RandomCrop(512), | |
| transforms.RandomHorizontalFlip(), | |
| transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.05), | |
| transforms.ToTensor(), | |
| transforms.Normalize(mean=[0.485, 0.456, 0.406], | |
| std=[0.229, 0.224, 0.225]) | |
| ]) | |
| dataset = StyleTransferDataset(temp_content_dir, transform=content_transform) | |
| dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0) | |
| # Multi-layer loss network | |
| loss_network = MultiLayerVGG().to(system.device).eval() | |
| mse_loss = nn.MSELoss() | |
| # Training loop | |
| model.train() | |
| model.encoder.eval() | |
| total_steps = 0 | |
| # Multiply style weight for better results | |
| actual_style_weight = style_weight * 10 | |
| for epoch in range(epochs): | |
| epoch_loss = 0 | |
| epoch_content_loss = 0 | |
| epoch_style_loss = 0 | |
| for batch_idx, content_batch in enumerate(dataloader): | |
| content_batch = content_batch.to(system.device) | |
| # Randomly select style images | |
| batch_style = [] | |
| for _ in range(content_batch.size(0)): | |
| style_idx = np.random.randint(0, len(style_tensors)) | |
| batch_style.append(style_tensors[style_idx]) | |
| batch_style = torch.cat(batch_style, dim=0) | |
| # Forward pass | |
| output = model(content_batch, batch_style) | |
| # Multi-layer loss | |
| with torch.no_grad(): | |
| content_feats = loss_network(content_batch) | |
| style_feats = loss_network(batch_style) | |
| output_feats = loss_network(output) | |
| # Content loss from relu4_1 | |
| content_loss = mse_loss(output_feats[-1], content_feats[-1]) | |
| # Style loss from multiple layers | |
| style_loss = 0 | |
| style_weights = [0.2, 0.3, 0.5, 1.0] | |
| def gram_matrix(feat): | |
| b, c, h, w = feat.size() | |
| feat = feat.view(b, c, h * w) | |
| gram = torch.bmm(feat, feat.transpose(1, 2)) | |
| return gram / (c * h * w) | |
| for i, (output_feat, style_feat, weight) in enumerate(zip(output_feats, style_feats, style_weights)): | |
| output_gram = gram_matrix(output_feat) | |
| style_gram = gram_matrix(style_feat) | |
| style_loss += weight * mse_loss(output_gram, style_gram) | |
| style_loss /= len(style_weights) | |
| # Total loss | |
| loss = content_weight * content_loss + actual_style_weight * style_loss | |
| # Backward pass | |
| optimizer.zero_grad() | |
| loss.backward() | |
| torch.nn.utils.clip_grad_norm_(model.decoder.parameters(), max_norm=5.0) | |
| optimizer.step() | |
| epoch_loss += loss.item() | |
| epoch_content_loss += content_loss.item() | |
| epoch_style_loss += style_loss.item() | |
| total_steps += 1 | |
| # Progress callback | |
| if progress_callback and total_steps % 10 == 0: | |
| progress = (epoch + (batch_idx + 1) / len(dataloader)) / epochs | |
| progress_callback(progress, | |
| f"Epoch {epoch+1}/{epochs}, Loss: {loss.item():.4f} " | |
| f"(C: {content_loss.item():.4f}, S: {style_loss.item():.4f})") | |
| # Step scheduler | |
| scheduler.step() | |
| # Print epoch stats | |
| avg_loss = epoch_loss / len(dataloader) | |
| print(f"Epoch {epoch+1}: Loss={avg_loss:.4f}, " | |
| f"Content={epoch_content_loss/len(dataloader):.4f}, " | |
| f"Style={epoch_style_loss/len(dataloader):.4f}") | |
| # Save checkpoint | |
| if (epoch + 1) % save_interval == 0: | |
| checkpoint_path = f'{system.models_dir}/{model_name}_epoch_{epoch+1}.pth' | |
| torch.save({ | |
| 'epoch': epoch + 1, | |
| 'model_state_dict': model.state_dict(), | |
| 'optimizer_state_dict': optimizer.state_dict(), | |
| 'scheduler_state_dict': scheduler.state_dict(), | |
| 'loss': avg_loss, | |
| 'model_type': 'adain' | |
| }, checkpoint_path) | |
| print(f"Saved checkpoint: {checkpoint_path}") | |
| # Save final model | |
| final_path = f'{system.models_dir}/{model_name}_final.pth' | |
| torch.save({ | |
| 'model_state_dict': model.state_dict(), | |
| 'model_type': 'adain' | |
| }, final_path) | |
| # Cleanup | |
| shutil.rmtree(temp_content_dir) | |
| if model: | |
| st.session_state['trained_adain_model'] = model | |
| st.session_state['trained_style_images'] = style_images | |
| st.session_state['model_path'] = final_path | |
| st.success("AdaIN training complete! 🎉") | |
| # Add to system's models | |
| system.lightweight_models[model_name] = model | |
| progress_bar.empty() | |
| status_text.empty() | |
| else: | |
| st.error("Please upload both style and content images") | |
    # ------------------------------------------------------------------
    # Testing section with regional application: only rendered once a
    # model has been trained in this session (stored in session state).
    # ------------------------------------------------------------------
    if 'trained_adain_model' in st.session_state:
        st.markdown("---")
        st.header("Test Your AdaIN Model")
        # Whole-image transfer vs. painting a mask on a canvas.
        application_mode = st.radio("Application Mode",
                                    ["Whole Image", "Paint Region"],
                                    horizontal=True,
                                    help="Choose whether to apply style to entire image or paint specific regions")
        test_col1, test_col2, test_col3 = st.columns([1, 1, 1])
        with test_col1:
            st.subheader("Test Options")
            test_source = st.radio("Test Image Source",
                                   ["Use Content Image", "Upload New", "Use Unsplash Image"],
                                   horizontal=True)
            test_image = None
            if test_source == "Use Content Image" and st.session_state.content_images_list:
                # Reuse one of the training-content uploads as the test image.
                content_idx = st.selectbox("Select content image",
                                           range(len(st.session_state.content_images_list)),
                                           format_func=lambda x: f"Content Image {x+1}")
                test_image = Image.open(st.session_state.content_images_list[content_idx]).convert('RGB')
            elif test_source == "Use Unsplash Image":
                # Only available if an Unsplash image was chosen in the main tab.
                if 'current_image' in st.session_state and st.session_state.get('image_source', '').startswith('Unsplash'):
                    test_image = st.session_state['current_image']
                    st.success("Using Unsplash image")
                else:
                    st.info("Please search and select an image from the Style Transfer tab first")
            else:
                test_upload = st.file_uploader("Upload test image",
                                               type=['png', 'jpg', 'jpeg'],
                                               key="test_adain")
                if test_upload:
                    test_image = Image.open(test_upload).convert('RGB')
            # Remember the chosen test image across reruns.
            if test_image:
                st.session_state['adain_test_image'] = test_image
            # Choose which trained style to test with (if several were used).
            if 'trained_style_images' in st.session_state and len(st.session_state['trained_style_images']) > 1:
                style_idx = st.selectbox("Select style",
                                         range(len(st.session_state['trained_style_images'])),
                                         format_func=lambda x: f"Style {x+1}")
                test_style = st.session_state['trained_style_images'][style_idx]
            elif 'trained_style_images' in st.session_state:
                test_style = st.session_state['trained_style_images'][0]
                st.info("Using the single trained style")
            else:
                test_style = None
            # Alpha blending control (values above 1 exaggerate the style).
            alpha = st.slider("Style Strength (Alpha)", 0.0, 2.0, 1.2, 0.1,
                              help="0 = original content, 1 = full style transfer, >1 = stronger style")
            use_tiling = st.checkbox("Use Tiled Processing",
                                     value=True,
                                     help="Process images in tiles for better quality. Recommended for ALL images.")
            # Defaults so these names always exist, even in Whole Image mode.
            brush_size = 30
            drawing_mode = "freedraw"
            feather_radius = 10
            # Regional painting controls (only shown in paint mode).
            if application_mode == "Paint Region":
                st.markdown("---")
                st.subheader("Painting Options")
                brush_size = st.slider("Brush Size", 5, 100, 30)
                drawing_mode = st.selectbox("Drawing Tool",
                                            ["freedraw", "line", "rect", "circle", "polygon"],
                                            index=0)
                feather_radius = st.slider("Edge Softness", 0, 50, 10,
                                           help="Blur mask edges for smoother transitions")
                col_btn1, col_btn2 = st.columns(2)
                with col_btn1:
                    if st.button("Clear Canvas", use_container_width=True):
                        st.session_state['adain_canvas_result'] = None
                        st.rerun()
                with col_btn2:
                    if st.button("Reset Result", use_container_width=True):
                        if 'adain_styled_result' in st.session_state:
                            del st.session_state['adain_styled_result']
                        st.rerun()
        with test_col2:
            st.subheader("Canvas / Original")
            if application_mode == "Paint Region" and test_image:
                # The canvas works on a downscaled copy; the painted mask is
                # rescaled back to full resolution in apply_adain_regional().
                display_img = resize_image_for_display(test_image, max_width=400, max_height=400)
                canvas_width, canvas_height = display_img.size
                st.info("Paint the areas where you want to apply the style")
                canvas_result = st_canvas(
                    fill_color="rgba(255, 0, 0, 0.3)",
                    stroke_width=brush_size,
                    stroke_color="rgba(255, 0, 0, 0.5)",
                    background_image=display_img,
                    update_streamlit=True,
                    height=canvas_height,
                    width=canvas_width,
                    drawing_mode=drawing_mode,
                    display_toolbar=True,
                    # Key includes brush/tool so changing either resets the canvas.
                    key=f"adain_canvas_{brush_size}_{drawing_mode}"
                )
                # Persist the canvas result so the Apply button can use it.
                if canvas_result:
                    st.session_state['adain_canvas_result'] = canvas_result
                if test_style:
                    st.markdown("---")
                    st.image(test_style, caption="Style Image", use_column_width=True)
            else:
                # Whole-image mode: just show the inputs.
                if test_image:
                    st.image(test_image, caption="Content Image", use_column_width=True)
                if test_style:
                    st.image(test_style, caption="Style Image", use_column_width=True)
        with test_col3:
            st.subheader("Result")
            apply_button = st.button("Apply Style", type="primary", use_container_width=True)
            if apply_button and test_image and test_style:
                with st.spinner("Applying style..."):
                    if application_mode == "Whole Image":
                        result = system.apply_adain_style(
                            test_image,
                            test_style,
                            st.session_state['trained_adain_model'],
                            alpha=alpha,
                            use_tiling=use_tiling
                        )
                    else:
                        # NOTE(review): apply_adain_regional is defined further
                        # down the file, after this call site; on a
                        # top-to-bottom script run this name may not exist yet
                        # when the button fires — confirm the definition is
                        # moved above the tabs.
                        result = apply_adain_regional(
                            test_image,
                            test_style,
                            st.session_state['trained_adain_model'],
                            st.session_state.get('adain_canvas_result'),
                            alpha=alpha,
                            feather_radius=feather_radius,
                            use_tiling=use_tiling
                        )
                    if result:
                        st.session_state['adain_styled_result'] = result
            # Show the latest result (persists across reruns).
            if 'adain_styled_result' in st.session_state:
                st.image(st.session_state['adain_styled_result'],
                         caption="Styled Result",
                         use_column_width=True)
                with st.expander("💡 Tips for Better Quality"):
                    st.markdown("""
                    - **Always use tiling** for best quality
                    - Try **alpha > 1.0** (1.2-1.5) for stronger style
                    - Use **multiple style images** when training
                    - Train for **50+ epochs** for best results
                    - If quality is still poor, retrain with **style weight = 200-500**
                    """)
                # PNG download of the latest result.
                buf = io.BytesIO()
                st.session_state['adain_styled_result'].save(buf, format='PNG')
                st.download_button(
                    label="Download Result",
                    data=buf.getvalue(),
                    file_name=f"adain_styled_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.png",
                    mime="image/png"
                )
    # Model download section: offer the final .pth once it exists on disk.
    st.markdown("---")
    if 'model_path' in st.session_state and os.path.exists(st.session_state['model_path']):
        col_dl1, col_dl2 = st.columns(2)
        with col_dl1:
            with open(st.session_state['model_path'], 'rb') as f:
                st.download_button(
                    label="Download Trained AdaIN Model",
                    data=f.read(),
                    file_name=f"{model_name}_final.pth",
                    mime="application/octet-stream",
                    use_container_width=True
                )
        with col_dl2:
            st.info("This model can be loaded and used for real-time style transfer")
# Helper used by the AdaIN test section above.
# NOTE(review): this definition sits below its call site in the tab code;
# consider moving it above the tabs so the name exists when the button fires.
def apply_adain_regional(content_image, style_image, model, canvas_result, alpha=1.0, feather_radius=10, use_tiling=False):
    """Apply AdaIN style transfer only inside the region painted on the canvas.

    Args:
        content_image: Source PIL image (RGB), or None.
        style_image: Style reference PIL image, or None.
        model: Trained AdaIN model forwarded to ``system.apply_adain_style``.
        canvas_result: streamlit-drawable-canvas result whose RGBA
            ``image_data`` alpha channel marks the painted region; when
            absent the style is applied to the whole image.
        alpha: Style strength forwarded to the AdaIN pass.
        feather_radius: Gaussian sigma used to soften mask edges (0 = hard).
        use_tiling: Forwarded to ``system.apply_adain_style``.

    Returns:
        Blended PIL image, or None on missing inputs / failure.
    """
    # Guard clause: nothing to do without all three core inputs.
    if content_image is None or style_image is None or model is None:
        return None
    try:
        # No painted mask available -> fall back to whole-image transfer.
        if canvas_result is None or canvas_result.image_data is None:
            return system.apply_adain_style(content_image, style_image, model, alpha, use_tiling=use_tiling)
        # The canvas alpha channel records where the user painted.
        painted = canvas_result.image_data[:, :, 3] > 0
        # The canvas works at display resolution; rescale the mask to the
        # full-resolution image if the two differ (PIL size is (w, h)).
        target_size = content_image.size
        canvas_size = (canvas_result.image_data.shape[1], canvas_result.image_data.shape[0])
        if target_size != canvas_size:
            resized = Image.fromarray((painted * 255).astype(np.uint8), mode='L').resize(target_size, Image.NEAREST)
            painted = np.array(resized) > 128
        # Build a soft mask: feather edges with a Gaussian for smooth blending.
        if feather_radius > 0:
            from scipy.ndimage import gaussian_filter
            soft_mask = np.clip(gaussian_filter(painted.astype(np.float32), sigma=feather_radius), 0, 1)
        else:
            soft_mask = painted.astype(np.float32)
        # Stylize the entire image once, then composite only the masked area.
        styled = system.apply_adain_style(content_image, style_image, model, alpha, use_tiling=use_tiling)
        if styled is None:
            return None
        base = np.array(content_image, dtype=np.float32)
        overlay = np.array(styled, dtype=np.float32)
        weights = soft_mask[:, :, None]  # broadcast over the 3 RGB channels
        blended = base * (1 - weights) + overlay * weights
        return Image.fromarray(np.clip(blended, 0, 255).astype(np.uint8))
    except Exception as e:
        print(f"Error applying regional AdaIN style: {e}")
        traceback.print_exc()
        return None
# ------------------------------------------------------------------
# TAB 5: Batch Processing — apply CycleGAN blends or a custom trained
# model to many images and download the results as a ZIP.
# ------------------------------------------------------------------
with tab5:
    st.header("Batch Processing")
    col1, col2 = st.columns(2)
    with col1:
        # Image source: multi-upload, or reuse the Unsplash pick from tab 1.
        batch_source = st.radio("Image Source", ["Upload Multiple", "Use Current Unsplash Image"], horizontal=True, key="batch_source")
        batch_files = []
        if batch_source == "Upload Multiple":
            batch_files = st.file_uploader("Upload Images", type=['png', 'jpg', 'jpeg'],
                                           accept_multiple_files=True, key="batch_upload")
        else:
            # Only available if an Unsplash image was chosen in the main tab.
            if 'current_image' in st.session_state and st.session_state.get('image_source', '').startswith('Unsplash'):
                batch_files = [st.session_state['current_image']]
                st.success("Using current Unsplash image for batch processing")
            else:
                st.info("Please search and select an image from the Style Transfer tab first")
        processing_type = st.radio("Processing Type", ["CycleGAN", "Custom Trained Model"])
        if processing_type == "CycleGAN":
            # Up to three (style, intensity) pairs, blended together.
            batch_styles = []
            for i in range(3):
                with st.expander(f"Style {i+1}", expanded=(i==0)):
                    style = st.selectbox(f"Select style", style_choices, key=f"batch_style_{i}")
                    intensity = st.slider(f"Intensity", 0.0, 2.0, 1.0, 0.1, key=f"batch_intensity_{i}")
                    if style and intensity > 0:
                        # Map the display name back to the model's registry key.
                        model_key = None
                        for key, info in system.cyclegan_models.items():
                            if info['name'] == style:
                                model_key = key
                                break
                        if model_key:
                            batch_styles.append(('cyclegan', model_key, intensity))
            batch_blend_mode = st.selectbox("Blend Mode",
                                            ["additive", "average", "maximum", "overlay", "screen"],
                                            index=0, key="batch_blend")
        else:
            # Custom model path: a user-trained .pth checkpoint.
            custom_model_file = st.file_uploader("Upload Trained Model (.pth)", type=['pth'])
        if st.button("Process Batch", type="primary", use_container_width=True):
            if batch_files:
                with st.spinner("Processing batch..."):
                    progress_bar = st.progress(0)
                    processed_images = []
                    if processing_type == "CycleGAN" and batch_styles:
                        for idx, file in enumerate(batch_files):
                            progress_bar.progress((idx + 1) / len(batch_files))
                            # batch_files holds either upload objects or a PIL
                            # image (the Unsplash path); handle both.
                            if isinstance(file, Image.Image):
                                image = file
                            else:
                                image = Image.open(file).convert('RGB')
                            result = system.blend_styles(image, batch_styles, batch_blend_mode)
                            processed_images.append(result)
                    elif processing_type == "Custom Trained Model" and custom_model_file:
                        # Write the uploaded checkpoint to a temp file so the
                        # loader can read it from a path.
                        temp_model = tempfile.NamedTemporaryFile(delete=False, suffix='.pth')
                        temp_model.write(custom_model_file.read())
                        temp_model.close()
                        model = system.load_lightweight_model(temp_model.name)
                        if model:
                            for idx, file in enumerate(batch_files):
                                progress_bar.progress((idx + 1) / len(batch_files))
                                # Same dual handling as the CycleGAN branch.
                                if isinstance(file, Image.Image):
                                    image = file
                                else:
                                    image = Image.open(file).convert('RGB')
                                result = system.apply_lightweight_style(image, model)
                                if result:
                                    processed_images.append(result)
                        os.unlink(temp_model.name)
                    if processed_images:
                        # Bundle all results into an in-memory ZIP for download.
                        zip_buffer = io.BytesIO()
                        with zipfile.ZipFile(zip_buffer, 'w') as zf:
                            for idx, img in enumerate(processed_images):
                                img_buffer = io.BytesIO()
                                img.save(img_buffer, format='PNG')
                                zf.writestr(f"styled_{idx+1:03d}.png", img_buffer.getvalue())
                        st.session_state['batch_results'] = processed_images
                        st.session_state['batch_zip'] = zip_buffer.getvalue()
                    progress_bar.empty()
    with col2:
        if 'batch_results' in st.session_state:
            st.header("Results")
            # Gallery preview capped at the first 8 images.
            cols = st.columns(4)
            for idx, img in enumerate(st.session_state['batch_results'][:8]):
                cols[idx % 4].image(img, use_column_width=True)
            if len(st.session_state['batch_results']) > 8:
                st.info(f"Showing 8 of {len(st.session_state['batch_results'])} processed images")
            st.download_button(
                label="Download All (ZIP)",
                data=st.session_state['batch_zip'],
                file_name=f"batch_styled_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.zip",
                mime="application/zip"
            )
# ------------------------------------------------------------------
# TAB 6: Documentation — a static help page. The model list is built
# from the registered CycleGAN models so it stays in sync with what
# is actually loaded at runtime.
# ------------------------------------------------------------------
with tab6:
    st.markdown(f"""
## Style Transfer System Documentation
### Available CycleGAN Models
This system includes pre-trained bidirectional CycleGAN models:
{chr(10).join([f'- **{info["name"]}**' for key, info in sorted(system.cyclegan_models.items(), key=lambda item: item[1]["name"])])}
### Features
#### Style Transfer
- Apply multiple styles simultaneously
- Adjustable intensity for each style
- Multiple blending modes for creative effects
- **NEW**: Search and use images from Unsplash
#### Regional Transform
- Paint specific regions to apply different styles
- Support for multiple regions with different styles
- Adjustable brush size and drawing tools
- Base style + regional overlays
- Persistent brush strokes across regions
- Optimized display for large images
#### Video Processing
- Frame-by-frame style transfer
- Maintains temporal consistency
- Supports all style combinations and blend modes
- Enhanced codec compatibility
#### Custom Training
- Train on any artistic style with minimal data (1-50 images)
- Automatic data augmentation for small datasets
- Adjustable model complexity (3-12 residual blocks)
### Model Architecture
- **CycleGAN models**: 9-12 residual blocks for high-quality transformations
- **Lightweight models**: 3-12 residual blocks (customizable during training)
- **Training approach**: Unpaired image-to-image translation
### Technical Details
- **Framework**: PyTorch
- **GPU Support**: CUDA acceleration when available
- **Image Formats**: JPG, PNG, BMP
- **Video Formats**: MP4, AVI, MOV
- **Model Size**: ~45MB (CycleGAN), 5-15MB (Lightweight)
### Unsplash Integration
To use Unsplash image search:
1. Get a free API key from [Unsplash Developers](https://unsplash.com/developers)
2. Add it to your HuggingFace Space secrets as `UNSPLASH_ACCESS_KEY`
3. Search for images directly in the app
4. Automatic attribution for photographers
### Usage Tips
1. **For best results**: Use high-quality input images
2. **Style intensity**: Start with 1.0, adjust to taste
3. **Blending modes**:
   - 'Additive' for bold effects
   - 'Average' for subtle blends
   - 'Overlay' for dramatic contrasts
4. **Regional painting**:
   - Use larger brush for smooth transitions
   - Multiple thin layers work better than one thick layer
   - Previous regions remain visible as you paint new ones
5. **Custom training**: More diverse content images = better generalization
6. **Video processing**: Keep videos under 30 seconds for faster processing
### Regional Transform Guide
The regional transform feature allows you to:
1. Define multiple regions by painting on the canvas
2. Assign different styles to each region
3. Control intensity per region
4. Apply an optional base style to the entire image
5. Blend regions using various modes
**Tips for Regional Transform:**
- Start with a base style for overall coherence
- Use semi-transparent brushes for smoother transitions
- Overlap regions for interesting blend effects
- Experiment with different blend modes per region
- All regions are visible while painting for better control
""")
# Footer: rendered after all tabs, at the bottom of every page view.
st.markdown("---")
st.markdown("Style transfer system with CycleGAN models and regional painting capabilities.")