style-transfer / app.py
dannyroxas's picture
Update app.py
d731f27 verified
#!/usr/bin/env python
"""
STYLE TRANSFER APP - Streamlit Version with Regional Transformations
All existing features preserved + new local painting capabilities + Unsplash integration
"""
import os
os.environ['MPLCONFIGDIR'] = '/tmp/matplotlib'
os.environ['TORCH_HOME'] = '/tmp/torch_cache'
os.environ['HF_HOME'] = '/tmp/hf_cache'
os.makedirs('/tmp/torch_cache', exist_ok=True)
os.makedirs('/tmp/hf_cache', exist_ok=True)
import streamlit as st
from streamlit_drawable_canvas import st_canvas
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.models as models
from torch.utils.data import Dataset, DataLoader
from PIL import Image, ImageDraw, ImageFont
import numpy as np
import glob
import datetime
import traceback
import uuid
import warnings
import zipfile
import io
import json
import time
import shutil
import requests
import scipy
try:
import cv2
VIDEO_PROCESSING_AVAILABLE = True
except ImportError:
VIDEO_PROCESSING_AVAILABLE = False
print("OpenCV not available - video processing disabled")
import tempfile
from pathlib import Path
import colorsys
warnings.filterwarnings("ignore")
# Set page config
# Configures the browser tab title, a full-width layout and an initially
# expanded sidebar for the Streamlit page.
# NOTE(review): Streamlit documents that set_page_config should run before
# other Streamlit commands - keep it ahead of the st.markdown call below.
st.set_page_config(
    page_title="Style Transfer Studio",
    # page_icon="",
    layout="wide",
    initial_sidebar_state="expanded"
)
# Custom CSS for better UI
# The stylesheet is injected as raw HTML (unsafe_allow_html=True). It spaces
# and sizes the tab bar, adds top padding to the main container, lets one
# container use the full width, centres the drawable canvas and gives the
# Unsplash thumbnails rounded corners plus a hover zoom.
# NOTE(review): ".st-emotion-cache-1y4p8pa" looks like an auto-generated
# Streamlit class name and may change between Streamlit releases - verify.
st.markdown("""
<style>
.stTabs [data-baseweb="tab-list"] {
gap: 24px;
}
.stTabs [data-baseweb="tab"] {
height: 50px;
padding-left: 20px;
padding-right: 20px;
}
.main > div {
padding-top: 2rem;
}
.st-emotion-cache-1y4p8pa {
max-width: 100%;
}
/* Fix canvas container */
.stDrawableCanvas {
margin: 0 auto;
}
/* Unsplash grid styling */
.unsplash-grid img {
border-radius: 8px;
cursor: pointer;
transition: transform 0.2s;
}
.unsplash-grid img:hover {
transform: scale(1.05);
}
</style>
""", unsafe_allow_html=True)
# Pick the compute device once; every CUDA-specific step below keys off it.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if device.type == 'cuda':
    # Pin all work to the first GPU.
    torch.cuda.set_device(0)
    # Enable cuDNN benchmark mode. NOTE: this switches on cuDNN's algorithm
    # autotuner; it is a speed setting, not a determinism setting.
    torch.backends.cudnn.benchmark = True
    print("CUDA device set")
    print(f"CUDA version: {torch.version.cuda}")
    print(f"PyTorch version: {torch.__version__}")
# GPU SETUP
print(f"Using device: {device}")
if device.type == 'cuda':
    # Report the card, its total memory and what is currently allocated.
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
    print(f"Current GPU memory usage: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
# ===========================
# UNSPLASH API INTEGRATION
# ===========================
class UnsplashAPI:
    """Minimal client for the Unsplash REST API.

    The access key is taken from the ``access_key`` argument or, when that is
    empty, from the ``UNSPLASH_ACCESS_KEY`` environment variable. The JSON
    endpoints return a ``(data, error_message)`` pair in which exactly one
    element is ``None``, so callers never have to catch network exceptions.
    """

    def __init__(self, access_key=None):
        # Use environment variable only - no secrets.toml warning
        self.access_key = access_key or os.environ.get("UNSPLASH_ACCESS_KEY")
        self.base_url = "https://api.unsplash.com"

    def _auth_headers(self):
        """Return the Client-ID authorization header for the configured key."""
        return {"Authorization": f"Client-ID {self.access_key}"}

    def _get_json(self, endpoint, params, error_prefix):
        """GET ``endpoint`` and return ``(parsed_json, None)`` or ``(None, message)``."""
        if not self.access_key:
            return None, "No Unsplash API key configured"
        try:
            response = requests.get(
                f"{self.base_url}{endpoint}",
                headers=self._auth_headers(),
                params=params,
                timeout=10
            )
            response.raise_for_status()
            return response.json(), None
        except (requests.exceptions.RequestException, ValueError) as e:
            # ValueError covers an invalid JSON body on requests versions
            # whose JSON decode error is not a RequestException.
            return None, f"{error_prefix}: {str(e)}"

    def search_photos(self, query, per_page=20, page=1, orientation=None):
        """Search photos on Unsplash.

        Args:
            query: Search terms.
            per_page: Number of results per page.
            page: Page number to request.
            orientation: Optional filter: "landscape", "portrait" or "squarish".

        Returns:
            ``(response_json, None)`` on success, ``(None, error_message)`` otherwise.
        """
        params = {
            "query": query,
            "per_page": per_page,
            "page": page
        }
        if orientation:
            params["orientation"] = orientation
        return self._get_json("/search/photos", params, "Error searching Unsplash")

    def get_random_photos(self, count=12, collections=None, query=None):
        """Get random photos from Unsplash.

        Args:
            count: Number of photos to request.
            collections: Optional collection filter, sent as given.
            query: Optional search terms narrowing the random selection.

        Returns:
            ``(response_json, None)`` on success, ``(None, error_message)`` otherwise.
        """
        params = {"count": count}
        if collections:
            params["collections"] = collections
        if query:
            params["query"] = query
        return self._get_json("/photos/random", params, "Error getting random photos")

    def download_photo(self, photo_url, size="regular"):
        """Download ``photo_url`` and return it as an RGB PIL image.

        ``size`` is accepted for interface compatibility but is not used; the
        caller picks the size by choosing which URL to pass in. Failures are
        reported with ``st.error`` and produce ``None``.
        """
        try:
            # Add fm=jpg&q=80 for consistent format and quality
            separator = "&" if "?" in photo_url else "?"
            photo_url = f"{photo_url}{separator}fm=jpg&q=80"
            response = requests.get(photo_url, timeout=30)
            response.raise_for_status()
            return Image.open(io.BytesIO(response.content)).convert('RGB')
        except Exception as e:
            # UI boundary: network and image-decoding errors are shown to the user.
            st.error(f"Error downloading image: {str(e)}")
            return None

    def trigger_download(self, download_location):
        """Trigger download event (required by Unsplash API).

        Best effort: does nothing without a key or location, and ignores
        network failures so that tracking can never break a download.
        """
        if not self.access_key or not download_location:
            return
        try:
            requests.get(download_location, headers=self._auth_headers(), timeout=5)
        except requests.exceptions.RequestException:
            pass  # Don't fail if tracking fails
# ===========================
# MODEL ARCHITECTURES
# ===========================
class LightweightResidualBlock(nn.Module):
    """Residual block built from a depthwise 3x3 conv followed by a 1x1 conv.

    Both stages keep the channel count and spatial size, so the result can be
    added straight back onto the input.
    """

    def __init__(self, channels):
        super().__init__()
        # Depthwise stage: groups == channels gives one 3x3 filter per channel.
        depthwise_layers = [
            nn.ReflectionPad2d(1),
            nn.Conv2d(channels, channels, 3, groups=channels),
            nn.InstanceNorm2d(channels, affine=True),
            nn.ReLU(inplace=True),
        ]
        # Pointwise stage: 1x1 conv mixes information across channels.
        pointwise_layers = [
            nn.Conv2d(channels, channels, 1),
            nn.InstanceNorm2d(channels, affine=True),
        ]
        self.depthwise = nn.Sequential(*depthwise_layers)
        self.pointwise = nn.Sequential(*pointwise_layers)

    def forward(self, x):
        residual = self.pointwise(self.depthwise(x))
        return x + residual
class ResidualBlock(nn.Module):
    """Two-conv residual block used inside the CycleGAN generator.

    Layout of ``self.block``: pad, conv, norm, ReLU, pad, conv, norm. The
    second conv has no activation; the output is added onto the input.
    """

    def __init__(self, in_features):
        super().__init__()
        layers = []
        for with_activation in (True, False):
            layers.append(nn.ReflectionPad2d(1))
            layers.append(nn.Conv2d(in_features, in_features, 3))
            layers.append(nn.InstanceNorm2d(in_features, affine=True))
            if with_activation:
                layers.append(nn.ReLU(inplace=True))
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        residual = self.block(x)
        return x + residual
class Generator(nn.Module):
    """CycleGAN-style generator: 7x7 stem, two stride-2 downsampling convs,
    ``n_residual_blocks`` residual blocks, two transposed-conv upsampling
    steps and a 7x7 Tanh output layer.

    All layers live in one flat ``self.model`` Sequential, so checkpoint keys
    have the form ``model.<index>....``.
    """

    def __init__(self, input_nc=3, output_nc=3, n_residual_blocks=9):
        super().__init__()
        # Stem: reflection-padded 7x7 conv up to 64 channels.
        layers = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_nc, 64, 7),
            nn.InstanceNorm2d(64, affine=True),
            nn.ReLU(inplace=True),
        ]
        channels = 64

        # Downsampling: channels double at each stride-2 conv (64 -> 128 -> 256).
        for _ in range(2):
            doubled = channels * 2
            layers.extend([
                nn.Conv2d(channels, doubled, 3, stride=2, padding=1),
                nn.InstanceNorm2d(doubled, affine=True),
                nn.ReLU(inplace=True),
            ])
            channels = doubled

        # Residual trunk at the bottleneck resolution.
        layers.extend(ResidualBlock(channels) for _ in range(n_residual_blocks))

        # Upsampling: channels halve at each transposed conv (256 -> 128 -> 64).
        for _ in range(2):
            halved = channels // 2
            layers.extend([
                nn.ConvTranspose2d(channels, halved, 3, stride=2, padding=1, output_padding=1),
                nn.InstanceNorm2d(halved, affine=True),
                nn.ReLU(inplace=True),
            ])
            channels = halved

        # Output layer: back to image channels, squashed to [-1, 1] by Tanh.
        layers.extend([
            nn.ReflectionPad2d(3),
            nn.Conv2d(64, output_nc, 7),
            nn.Tanh(),
        ])
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)
class LightweightStyleNet(nn.Module):
    """Lightweight network for fast style transfer training.

    Encoder (9x9 stem + two stride-2 convs), a stack of
    ``LightweightResidualBlock`` modules at 128 channels, and a mirrored
    decoder ending in Tanh, so outputs lie in [-1, 1].

    The 9x9 convolutions are padded by 4 pixels. With the previous padding of
    3 each one trimmed 2 pixels per dimension, so a 256x256 input came out as
    254x254. Padding layers hold no parameters, so existing checkpoints load
    unchanged. Output size equals input size when height and width are
    multiples of 4.
    """

    def __init__(self, n_residual_blocks=5):
        super(LightweightStyleNet, self).__init__()
        # Encoder
        self.encoder = nn.Sequential(
            nn.ReflectionPad2d(4),  # (9 - 1) // 2 keeps the spatial size
            nn.Conv2d(3, 32, 9, stride=1),
            nn.InstanceNorm2d(32, affine=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 64, 3, stride=2, padding=1),
            nn.InstanceNorm2d(64, affine=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 128, 3, stride=2, padding=1),
            nn.InstanceNorm2d(128, affine=True),
            nn.ReLU(inplace=True)
        )
        # Residual blocks
        self.res_blocks = nn.Sequential(
            *(LightweightResidualBlock(128) for _ in range(n_residual_blocks))
        )
        # Decoder
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(128, 64, 3, stride=2, padding=1, output_padding=1),
            nn.InstanceNorm2d(64, affine=True),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, 32, 3, stride=2, padding=1, output_padding=1),
            nn.InstanceNorm2d(32, affine=True),
            nn.ReLU(inplace=True),
            nn.ReflectionPad2d(4),  # (9 - 1) // 2 keeps the spatial size
            nn.Conv2d(32, 3, 9, stride=1),
            nn.Tanh()
        )

    def forward(self, x):
        h = self.encoder(x)
        h = self.res_blocks(h)
        h = self.decoder(h)
        return h
class SimpleVGGFeatures(nn.Module):
    """Extract features from VGG19 for perceptual loss calculation.

    Wraps the first 21 modules of torchvision's pretrained VGG19 feature
    stack and freezes their parameters.
    """

    def __init__(self):
        super(SimpleVGGFeatures, self).__init__()
        try:
            vgg = models.vgg19(weights=models.VGG19_Weights.DEFAULT).features
        except (AttributeError, TypeError):
            # Older torchvision: no VGG19_Weights enum / no ``weights`` keyword.
            # Other failures (e.g. a failed download) propagate instead of
            # triggering a second identical download.
            vgg = models.vgg19(pretrained=True).features
        self.features = nn.Sequential(*list(vgg.children())[:21])
        # The extractor is fixed; only the network being trained gets gradients.
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, x):
        return self.features(x)
# ===========================
# ADAIN ARCHITECTURE
# ===========================
class AdaIN(nn.Module):
    """Adaptive Instance Normalization: re-scales content features so that
    each channel takes on the mean and standard deviation of the style
    features."""

    def __init__(self):
        super().__init__()

    def calc_mean_std(self, feat, eps=1e-5):
        """Return per-sample, per-channel ``(mean, std)``, each shaped (N, C, 1, 1).

        ``eps`` is added to the variance before the square root so that the
        std is never exactly zero.
        """
        dims = feat.size()
        assert (len(dims) == 4)
        batch, channels = dims[:2]
        flat = feat.view(batch, channels, -1)
        std = (flat.var(dim=2) + eps).sqrt().view(batch, channels, 1, 1)
        mean = flat.mean(dim=2).view(batch, channels, 1, 1)
        return mean, std

    def forward(self, content_feat, style_feat):
        shape = content_feat.size()
        style_mean, style_std = self.calc_mean_std(style_feat)
        content_mean, content_std = self.calc_mean_std(content_feat)
        # Whiten the content statistics, then apply the style statistics.
        whitened = (content_feat - content_mean.expand(shape)) / content_std.expand(shape)
        return whitened * style_std.expand(shape) + style_mean.expand(shape)
class VGGEncoder(nn.Module):
    """VGG-based encoder for AdaIN.

    Splits the first 21 modules of pretrained VGG19 into four frozen stages
    (``enc1`` .. ``enc4``) so that intermediate activations can be returned.
    """

    def __init__(self):
        super(VGGEncoder, self).__init__()
        # Load pretrained VGG19
        try:
            vgg = models.vgg19(weights=models.VGG19_Weights.DEFAULT).features
        except (AttributeError, TypeError):
            # Older torchvision: no VGG19_Weights enum / no ``weights`` keyword.
            vgg = models.vgg19(pretrained=True).features
        layers = list(vgg.children())
        # Encoder uses layers up to relu4_1
        self.enc1 = nn.Sequential(*layers[:2])     # conv1_1, relu1_1
        self.enc2 = nn.Sequential(*layers[2:7])    # up to relu2_1
        self.enc3 = nn.Sequential(*layers[7:12])   # up to relu3_1
        self.enc4 = nn.Sequential(*layers[12:21])  # up to relu4_1
        # Freeze encoder weights
        for param in self.parameters():
            param.requires_grad = False

    def encode(self, x):
        """Get only the final features for AdaIN"""
        return self.forward(x)[0]

    def forward(self, x):
        """Return ``(h4, h3, h2, h1)``: the final features first, followed by
        the earlier stage outputs from deepest to shallowest."""
        h1 = self.enc1(x)
        h2 = self.enc2(h1)
        h3 = self.enc3(h2)
        h4 = self.enc4(h3)
        return h4, h3, h2, h1  # Return intermediate features for skip connections
class AdaINDecoder(nn.Module):
    """Decoder for AdaIN style transfer.

    Maps 512-channel features back to a 3-channel image through four stages
    (``dec4`` -> ``dec1``). The first three stages each end in a 2x
    nearest-neighbour upsample, so the output is 8x the input resolution.
    """

    @staticmethod
    def _conv3x3(in_channels, out_channels, activate=True):
        """Return [reflection pad, 3x3 conv] plus a ReLU when ``activate`` is set."""
        layers = [
            nn.ReflectionPad2d((1, 1, 1, 1)),
            nn.Conv2d(in_channels, out_channels, (3, 3)),
        ]
        if activate:
            layers.append(nn.ReLU())
        return layers

    @staticmethod
    def _upsample():
        return nn.Upsample(scale_factor=2, mode='nearest')

    def __init__(self):
        super().__init__()
        conv = self._conv3x3
        # Stages are created in the same order in which forward() applies them.
        self.dec4 = nn.Sequential(*conv(512, 256), self._upsample())
        self.dec3 = nn.Sequential(*conv(256, 256), *conv(256, 128), self._upsample())
        self.dec2 = nn.Sequential(*conv(128, 128), *conv(128, 64), self._upsample())
        # Final stage: no activation after the conv that produces the image.
        self.dec1 = nn.Sequential(*conv(64, 64), *conv(64, 3, activate=False))

    def forward(self, x):
        for stage in (self.dec4, self.dec3, self.dec2, self.dec1):
            x = stage(x)
        return x
class AdaINStyleTransfer(nn.Module):
    """Full AdaIN pipeline: frozen VGG encoder -> AdaIN -> trainable decoder."""

    def __init__(self):
        super().__init__()
        self.encoder = VGGEncoder()
        self.decoder = AdaINDecoder()
        self.adain = AdaIN()
        # The encoder is a fixed feature extractor: eval mode, no gradients.
        self.encoder.eval()
        for frozen in self.encoder.parameters():
            frozen.requires_grad = False

    def encode(self, x):
        """Return the encoder's final feature map for ``x``."""
        return self.encoder.encode(x)

    def forward(self, content, style, alpha=1.0):
        """Stylize ``content`` with the feature statistics of ``style``.

        ``alpha`` below 1.0 mixes the stylized features with the plain
        content features before decoding.
        """
        content_feat = self.encode(content)
        style_feat = self.encode(style)
        stylized = self.adain(content_feat, style_feat)
        if not alpha < 1.0:
            return self.decoder(stylized)
        # Alpha blending in feature space
        mixed = alpha * stylized + (1 - alpha) * content_feat
        return self.decoder(mixed)
# ===========================
# DATASET AND LOSS FUNCTIONS
# ===========================
class StyleTransferDataset(Dataset):
    """Dataset for training style transfer models with augmentation support.

    Collects the image files directly inside ``content_dir`` and repeats the
    list ``augment_factor`` times; variety between repeats comes from the
    (random) ``transform`` applied on access.

    Args:
        content_dir: Directory containing the content images.
        transform: Optional callable applied to each loaded RGB PIL image.
        augment_factor: How many times every image appears per epoch.
    """

    # Matched case-insensitively against the file suffix.
    IMAGE_SUFFIXES = frozenset({'.jpg', '.jpeg', '.png', '.bmp'})

    def __init__(self, content_dir, transform=None, augment_factor=1):
        self.content_dir = Path(content_dir)
        self.transform = transform
        self.augment_factor = augment_factor
        # One pass with a case-insensitive suffix test: each file is listed
        # exactly once (no double counting on case-insensitive filesystems),
        # mixed-case suffixes such as ".Jpg" are found, and sorting makes the
        # sample order reproducible. A missing directory yields no images.
        self.images = sorted(
            path for path in self.content_dir.glob('*')
            if path.is_file() and path.suffix.lower() in self.IMAGE_SUFFIXES
        )
        print(f"Found {len(self.images)} content images")
        self.augmented_images = self.images * self.augment_factor
        if self.augment_factor > 1:
            print(f"Dataset augmented {self.augment_factor}x to {len(self.augmented_images)} samples")

    def __len__(self):
        return len(self.augmented_images)

    def __getitem__(self, idx):
        img_path = self.augmented_images[idx % len(self.images)]
        # convert() returns a fully loaded copy, so the file can be closed at once.
        with Image.open(img_path) as source:
            image = source.convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image
class PerceptualLoss(nn.Module):
    """Content + style loss computed on features from a supplied extractor.

    Content loss is the MSE between generated and content features; style
    loss is the MSE between the Gram matrices of generated and style features.
    """

    def __init__(self, vgg_features):
        super().__init__()
        self.vgg = vgg_features
        self.mse = nn.MSELoss()

    def gram_matrix(self, features):
        """Return the (B, C, C) channel-correlation matrix, divided by C*H*W."""
        batch, channels, height, width = features.size()
        flat = features.view(batch, channels, height * width)
        correlations = torch.bmm(flat, flat.transpose(1, 2))
        return correlations / (channels * height * width)

    def forward(self, generated, content, style, content_weight=1.0, style_weight=1e5):
        """Return ``(total_loss, content_loss, style_loss)``."""
        gen_feat = self.vgg(generated)
        content_feat = self.vgg(content)
        style_feat = self.vgg(style)

        content_loss = self.mse(gen_feat, content_feat)
        style_loss = self.mse(self.gram_matrix(gen_feat), self.gram_matrix(style_feat))

        weighted_sum = content_weight * content_loss + style_weight * style_loss
        return weighted_sum, content_loss, style_loss
# ===========================
# VIDEO PROCESSING
# ===========================
class VideoProcessor:
    """Process videos frame by frame with style transfer.

    Every frame is passed to ``system.blend_styles`` and the results are
    written to a temporary video file whose path is returned.
    NOTE(review): ``blend_styles`` is assumed to return a PIL RGB image (its
    result is fed to ``np.array`` / ``.save``) - confirm against its definition.
    """

    # (fourcc, container extension) pairs tried in order; 0 means uncompressed.
    _CODEC_CANDIDATES = (
        ('mp4v', '.mp4'),  # MPEG-4 - most compatible MP4 codec
        ('MP4V', '.mp4'),  # Alternative case
        ('FMP4', '.mp4'),  # Another MPEG-4 variant
        ('DIVX', '.mp4'),  # DivX (sometimes works with MP4)
        ('XVID', '.avi'),  # If MP4 fails, try AVI
        ('MJPG', '.avi'),  # Motion JPEG - fallback
        ('DIV3', '.avi'),  # DivX3
        (0, '.avi'),       # Uncompressed (last resort)
    )
    _FALLBACK_FPS = 30.0       # used when the container reports no frame rate
    _MIN_OUTPUT_BYTES = 1000   # anything smaller is treated as a failed encode
    _PROGRESS_INTERVAL = 10    # report progress every N frames

    def __init__(self, system):
        self.system = system

    @staticmethod
    def _remove_quietly(path):
        """Delete ``path`` if it exists; cleanup must never mask the real result."""
        try:
            os.unlink(path)
        except OSError:
            pass

    def _report_progress(self, progress_callback, frame_count, total_frames):
        """Invoke the callback every ``_PROGRESS_INTERVAL`` frames."""
        if not progress_callback or frame_count % self._PROGRESS_INTERVAL != 0:
            return
        # Some containers report 0 (or a wrong) frame count: avoid dividing by
        # zero and never report more than 100%.
        progress = min(frame_count / total_frames, 1.0) if total_frames > 0 else 0.0
        progress_callback(progress, f"Processing frame {frame_count}/{total_frames}")

    def _stylize_frame(self, frame, style_configs, blend_mode):
        """Convert a BGR OpenCV frame to a PIL RGB image and stylize it."""
        rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        return self.system.blend_styles(Image.fromarray(rgb_frame), style_configs, blend_mode)

    def _open_writer(self, base_path, fps, width, height):
        """Find a codec that can actually write to disk.

        Returns ``(writer, output_path, codec)`` for the first working codec,
        or ``(None, None, None)`` if none works. Files created for failed
        non-MP4 candidates are removed again.
        """
        for codec, ext in self._CODEC_CANDIDATES:
            # Only the extension is swapped (a plain str.replace could also
            # rewrite a directory name containing ".mp4").
            candidate = base_path if ext == '.mp4' else os.path.splitext(base_path)[0] + ext
            writer = None
            try:
                if codec == 0:
                    fourcc = 0
                    print("Using uncompressed video (larger file size)")
                else:
                    fourcc = cv2.VideoWriter_fourcc(*codec)
                    print(f"Trying codec: {codec} for {ext}")
                writer = cv2.VideoWriter(candidate, fourcc, fps, (width, height), isColor=True)
                if writer.isOpened():
                    # Test write a black frame to ensure it really works
                    writer.write(np.zeros((height, width, 3), dtype=np.uint8))
                    writer.release()
                    # Re-open for actual writing
                    writer = cv2.VideoWriter(candidate, fourcc, fps, (width, height), isColor=True)
                    if writer.isOpened():
                        print(f"✓ Successfully using codec: {codec} with {ext}")
                        return writer, candidate, codec
                writer.release()
            except Exception as e:
                print(f"Failed with codec {codec}: {e}")
                if writer is not None:
                    writer.release()
            if candidate != base_path:
                self._remove_quietly(candidate)
        return None, None, None

    def process_video(self, video_path, style_configs, blend_mode, progress_callback=None):
        """Process a video file with style transfer.

        Args:
            video_path: Path string, or an object with a ``name`` attribute
                holding the path (e.g. an uploaded/temporary file).
            style_configs: Passed through to ``system.blend_styles``.
            blend_mode: Passed through to ``system.blend_styles``.
            progress_callback: Optional ``callback(fraction, message)``.

        Returns:
            Path of the written video, or ``None`` on failure.
        """
        if not VIDEO_PROCESSING_AVAILABLE:
            print("Video processing requires OpenCV (cv2) - please install it")
            return None
        cap = None
        out = None
        base_path = None
        output_path = None
        try:
            # Handle both string path and file object
            if hasattr(video_path, 'name'):
                video_path = video_path.name
            # Open video
            cap = cv2.VideoCapture(video_path)
            if not cap.isOpened():
                print(f"Could not open video: {video_path}")
                return None
            # Keep fractional rates such as 29.97 (int() used to truncate them)
            # and fall back to a default when the container reports none.
            fps = cap.get(cv2.CAP_PROP_FPS)
            if not fps > 0:
                fps = self._FALLBACK_FPS
            width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

            # Create temporary output file - always start with mp4
            temp_output = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False)
            temp_output.close()
            base_path = temp_output.name

            out, output_path, used_codec = self._open_writer(base_path, fps, width, height)
            if out is None:
                # Last resort: save frames as images and create video differently
                print("Standard codecs failed. Trying alternative approach...")
                self._remove_quietly(base_path)
                base_path = None
                return self._process_with_frame_saving(
                    cap, style_configs, blend_mode, fps, width, height,
                    total_frames, progress_callback)

            # Process frames
            frame_count = 0
            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                styled_frame = self._stylize_frame(frame, style_configs, blend_mode)
                # Convert back to BGR for video
                bgr_frame = cv2.cvtColor(np.array(styled_frame), cv2.COLOR_RGB2BGR)
                # Ensure frame is correct size and type
                if bgr_frame.shape[:2] != (height, width):
                    bgr_frame = cv2.resize(bgr_frame, (width, height))
                out.write(bgr_frame)
                frame_count += 1
                self._report_progress(progress_callback, frame_count, total_frames)

            # Release before inspecting the file so the container is finalized.
            cap.release()
            out.release()

            # Clean up original temp file if different
            if output_path != base_path:
                self._remove_quietly(base_path)

            # Verify the output file
            size = os.path.getsize(output_path) if os.path.exists(output_path) else 0
            if size < self._MIN_OUTPUT_BYTES:
                print(f"Output file is too small or doesn't exist (size: {size} bytes)")
                self._remove_quietly(output_path)
                return None
            print(f"Video successfully saved to: {output_path}")
            print(f"File size: {size / 1024 / 1024:.2f} MB")
            print(f"Format: {os.path.splitext(output_path)[1]}")
            print(f"Codec used: {used_codec}")
            return output_path
        except Exception as e:
            print(f"Error processing video: {e}")
            traceback.print_exc()
            # Do not leave partial outputs behind.
            for leftover in (base_path, output_path):
                if leftover:
                    self._remove_quietly(leftover)
            return None
        finally:
            # release() is safe to call again after the explicit releases above.
            if cap is not None:
                cap.release()
            if out is not None:
                out.release()

    def _process_with_frame_saving(self, cap, style_configs, blend_mode, fps, width, height, total_frames, progress_callback):
        """Alternative processing method: save frames then combine.

        Stylized frames are written as PNG files to a temporary directory and
        then encoded with the ``mp4v`` codec. ``width`` and ``height`` are
        accepted for signature compatibility; the output size is read from the
        first saved frame. Returns the video path or ``None``.
        """
        temp_dir = None
        output_path = None
        try:
            print("Using frame-saving fallback method...")
            temp_dir = tempfile.mkdtemp()
            frame_count = 0
            frame_paths = []
            # Process and save frames
            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                styled_frame = self._stylize_frame(frame, style_configs, blend_mode)
                frame_path = os.path.join(temp_dir, f"frame_{frame_count:06d}.png")
                styled_frame.save(frame_path)
                frame_paths.append(frame_path)
                frame_count += 1
                self._report_progress(progress_callback, frame_count, total_frames)
            if not frame_paths:
                return None

            # Try to create video from frames
            temp_output = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False)
            temp_output.close()  # only the name is needed; do not leak the handle
            output_path = temp_output.name
            # Read first frame to get size
            first_frame = cv2.imread(frame_paths[0])
            h, w = first_frame.shape[:2]
            # Try simple mp4v codec
            fourcc = cv2.VideoWriter_fourcc(*'mp4v')
            out = cv2.VideoWriter(output_path, fourcc, fps, (w, h))
            try:
                if out.isOpened():
                    for frame_path in frame_paths:
                        out.write(cv2.imread(frame_path))
            finally:
                out.release()

            if os.path.exists(output_path) and os.path.getsize(output_path) > self._MIN_OUTPUT_BYTES:
                print(f"Successfully created video using frame-saving method")
                return output_path
            self._remove_quietly(output_path)
            return None
        except Exception as e:
            print(f"Frame-saving method failed: {e}")
            if output_path:
                self._remove_quietly(output_path)
            return None
        finally:
            cap.release()
            # Clean up frames
            if temp_dir:
                shutil.rmtree(temp_dir, ignore_errors=True)
# ===========================
# MAIN STYLE TRANSFER SYSTEM
# ===========================
class StyleTransferSystem:
def __init__(self):
    """Set up registries, image transforms, model discovery and helpers."""
    self.device = device

    # Registries: discovered CycleGAN checkpoints, generators already loaded
    # into memory, and lightweight models.
    self.cyclegan_models = {}
    self.loaded_generators = {}
    self.lightweight_models = {}

    # Generator input transform: tensor in [0, 1] -> [-1, 1].
    half = (0.5, 0.5, 0.5)
    self.transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(half, half),
    ])
    # Inverse of the above: (x + 1) / 2 maps [-1, 1] back to [0, 1] before
    # converting to a PIL image.
    self.inverse_transform = transforms.Compose([
        transforms.Normalize((-1, -1, -1), (2, 2, 2)),
        transforms.ToPILImage(),
    ])
    # Input transform for the VGG-based networks (the channel statistics
    # commonly used with ImageNet-pretrained torchvision models).
    self.vgg_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    self.discover_cyclegan_models()

    self.models_dir = '/tmp/trained_models'
    os.makedirs(self.models_dir, exist_ok=True)

    # The video helper only exists when OpenCV could be imported.
    if VIDEO_PROCESSING_AVAILABLE:
        self.video_processor = VideoProcessor(self)

    # Test GPU functionality
    if self.device.type == 'cuda':
        self._test_gpu()
def _test_gpu(self):
"""Test if GPU is working correctly"""
try:
print("\nTesting GPU functionality...")
with torch.no_grad():
# Create a small test tensor
test_tensor = torch.randn(1, 3, 256, 256).to(self.device)
print(f"Test tensor device: {test_tensor.device}")
# Perform a simple operation
result = test_tensor * 2.0
print(f"Result device: {result.device}")
# Check memory usage
print(f"GPU memory after test: {torch.cuda.memory_allocated() / 1e9:.3f} GB")
# Clean up
del test_tensor, result
torch.cuda.empty_cache()
print("✓ GPU test successful!\n")
except Exception as e:
print(f"✗ GPU test failed: {e}\n")
traceback.print_exc()
def discover_cyclegan_models(self):
"""Find all available CycleGAN models including both AB and BA directions"""
print("\nDiscovering CycleGAN models...")
# Updated patterns to match your directory structure
patterns = [
'./models/*_best_*/*generator_*.pth',
'./models/*_best_*/*.pth',
'./models/*/*generator*.pth',
'./models/*/*.pth'
]
all_files = set()
for pattern in patterns:
files = glob.glob(pattern)
if files:
print(f"Found in {pattern}: {len(files)} items")
all_files.update(files)
# Also check if models directory exists and list contents
if os.path.exists('./models'):
print(f"\nModels directory contents:")
for folder in os.listdir('./models'):
folder_path = os.path.join('./models', folder)
if os.path.isdir(folder_path):
print(f" {folder}/")
for file in os.listdir(folder_path):
print(f" - {file}")
if file.endswith('.pth'):
all_files.add(os.path.join(folder_path, file))
# Group files by base model name
model_files = {}
for path in all_files:
# Skip normal models
if 'normal' in path.lower():
continue
filename = os.path.basename(path)
folder_name = os.path.basename(os.path.dirname(path))
# Extract base name from folder name
if '_best_' in folder_name:
base_name = folder_name.split('_best_')[0]
else:
base_name = folder_name
if base_name not in model_files:
model_files[base_name] = {'AB': None, 'BA': None}
# Check filename for direction
if 'generator_AB' in filename or 'g_AB' in filename or 'G_AB' in filename:
model_files[base_name]['AB'] = path
elif 'generator_BA' in filename or 'g_BA' in filename or 'G_BA' in filename:
model_files[base_name]['BA'] = path
elif 'generator' in filename.lower() and not any(x in filename for x in ['AB', 'BA']):
# If no direction specified, assume it's AB
if model_files[base_name]['AB'] is None:
model_files[base_name]['AB'] = path
# Create display names for models
model_display_map = {
'photo_bokeh': ('Bokeh', 'Sharp'),
'photo_golden': ('Golden Hour', 'Normal Light'),
'photo_monet': ('Monet Style', 'Photo'),
'photo_seurat': ('Seurat Style', 'Photo'),
'day_night': ('Night', 'Day'),
'summer_winter': ('Winter', 'Summer'),
'foggy_clear': ('Clear', 'Foggy')
}
# Register available models
for base_name, files in model_files.items():
clean_name = base_name.lower().replace('-', '_')
if clean_name in model_display_map:
style_from, style_to = model_display_map[clean_name]
# Register AB direction if available
if files['AB']:
display_name = f"{style_to} to {style_from}"
model_key = f"{clean_name}_AB"
self.cyclegan_models[model_key] = {
'path': files['AB'],
'name': display_name,
'base_name': base_name,
'direction': 'AB'
}
print(f"Registered: {display_name} ({model_key}) -> {files['AB']}")
# Register BA direction if available
if files['BA']:
display_name = f"{style_from} to {style_to}"
model_key = f"{clean_name}_BA"
self.cyclegan_models[model_key] = {
'path': files['BA'],
'name': display_name,
'base_name': base_name,
'direction': 'BA'
}
print(f"Registered: {display_name} ({model_key}) -> {files['BA']}")
if not self.cyclegan_models:
print("No CycleGAN models found!")
print("Make sure your model files are in the ./models directory")
else:
print(f"\nFound {len(self.cyclegan_models)} CycleGAN models\n")
def detect_architecture(self, state_dict):
    """Infer how many residual blocks a CycleGAN checkpoint contains.

    Residual-block parameters have keys containing ``model.<index>`` and
    ``.block.``; the number of distinct indices is the block count. Returns
    9 when no such index can be found.
    """
    block_ids = set()
    for name in state_dict.keys():
        if 'model.' not in name or '.block.' not in name:
            continue
        tokens = name.split('.')
        # Record the number following the first "model" token, if any.
        for current, following in zip(tokens, tokens[1:]):
            if current == 'model' and following.isdigit():
                block_ids.add(int(following))
                break
    return len(block_ids) or 9
def load_cyclegan_model(self, model_key):
"""Load a CycleGAN model"""
if model_key in self.loaded_generators:
# Ensure cached model is on the correct device
model = self.loaded_generators[model_key]
model = model.to(self.device)
return model
if model_key not in self.cyclegan_models:
print(f"Model {model_key} not found!")
return None
model_info = self.cyclegan_models[model_key]
try:
print(f"Loading {model_info['name']} from {model_info['path']}...")
print(f"Target device: {self.device}")
# Load with explicit map_location to ensure it goes to the right device
state_dict = torch.load(model_info['path'], map_location=self.device)
if 'generator' in state_dict:
state_dict = state_dict['generator']
n_blocks = self.detect_architecture(state_dict)
print(f"Detected {n_blocks} residual blocks")
generator = Generator(n_residual_blocks=n_blocks)
try:
generator.load_state_dict(state_dict, strict=True)
print(f"Loaded with strict=True")
except:
generator.load_state_dict(state_dict, strict=False)
print(f"Loaded with strict=False")
# Move to device BEFORE any precision changes
generator = generator.to(self.device)
generator.eval()
# Check if model is actually on GPU
print(f"Model device after .to(): {next(generator.parameters()).device}")
# Skip half precision to ensure GPU usage - half precision can cause issues
if self.device.type == 'cuda':
print(f"Using full precision (fp32) on GPU")
# Test GPU usage
with torch.no_grad():
test_input = torch.randn(1, 3, 256, 256).to(self.device)
_ = generator(test_input)
print(f"GPU test successful")
torch.cuda.empty_cache()
self.loaded_generators[model_key] = generator
print(f"Successfully loaded {model_info['name']} on {self.device}")
return generator
except Exception as e:
print(f"Failed to load {model_info['name']}: {e}")
traceback.print_exc()
return None
def apply_cyclegan_style(self, image, model_key, intensity=1.0):
"""Apply a CycleGAN style to an image"""
if image is None or model_key not in self.cyclegan_models:
return None
model_info = self.cyclegan_models[model_key]
generator = self.load_cyclegan_model(model_key)
if generator is None:
print(f"Could not load model for {model_info['name']}")
return None
try:
# Ensure model is on GPU
generator = generator.to(self.device)
# Debug GPU usage
if self.device.type == 'cuda':
print(f"GPU Memory before style transfer: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
original_size = image.size
w, h = image.size
new_w = ((w + 31) // 32) * 32
new_h = ((h + 31) // 32) * 32
max_size = 1024 if self.device.type == 'cuda' else 512
if new_w > max_size or new_h > max_size:
ratio = min(max_size / new_w, max_size / new_h)
new_w = int(new_w * ratio)
new_h = int(new_h * ratio)
new_w = ((new_w + 31) // 32) * 32
new_h = ((new_h + 31) // 32) * 32
image_resized = image.resize((new_w, new_h), Image.LANCZOS)
img_tensor = self.transform(image_resized).unsqueeze(0).to(self.device)
with torch.no_grad():
# Skip half precision for now to ensure GPU usage
if self.device.type == 'cuda':
torch.cuda.synchronize() # Ensure GPU operations complete
torch.cuda.empty_cache()
output = generator(img_tensor)
# Ensure output is on the same device
output = output.to(self.device)
if self.device.type == 'cuda':
torch.cuda.synchronize() # Wait for GPU to finish
output_img = self.inverse_transform(output.squeeze(0).cpu())
output_img = output_img.resize(original_size, Image.LANCZOS)
if self.device.type == 'cuda':
print(f"GPU Memory after style transfer: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
torch.cuda.empty_cache()
if intensity < 1.0:
output_array = np.array(output_img, dtype=np.float32)
original_array = np.array(image, dtype=np.float32)
blended = original_array * (1 - intensity) + output_array * intensity
output_img = Image.fromarray(blended.astype(np.uint8))
return output_img
except Exception as e:
print(f"Error applying style {model_info['name']}: {e}")
print(f"Device: {self.device}")
print(f"Model device: {next(generator.parameters()).device}")
traceback.print_exc()
return None
def train_lightweight_model(self, style_image, content_dir, model_name,
                            epochs=30, batch_size=4, lr=1e-3,
                            save_interval=5, style_weight=1e5, content_weight=1.0,
                            n_residual_blocks=5, progress_callback=None):
    """Train a lightweight style transfer model.

    Args:
        style_image: Style reference image; it is resized/centre-cropped and
            then passed through ``self.vgg_transform``.
        content_dir: Directory whose entries are used as content images.
        model_name: Base name for checkpoint files and the key under which
            the trained model is registered in ``self.lightweight_models``.
        epochs: Number of passes over the (augmented) dataset.
        batch_size: Requested batch size; reduced for single-image training
            and capped at the dataset length.
        lr: Learning rate for the Adam optimizer.
        save_interval: A checkpoint is written every this many epochs.
        style_weight: Style weight forwarded to the perceptual loss.
        content_weight: Content weight forwarded to the perceptual loss.
        n_residual_blocks: Residual block count of ``LightweightStyleNet``.
        progress_callback: Optional ``callback(progress, message)`` invoked
            every 10 optimisation steps; ``progress`` is a 0..1 fraction.

    Returns:
        The trained model, switched to eval mode.

    Raises:
        ValueError: If the dataset built from ``content_dir`` is empty.
    """
    model = LightweightStyleNet(n_residual_blocks=n_residual_blocks).to(self.device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    print(f"Model architecture: {n_residual_blocks} residual blocks")
    print(f"Training device: {self.device}")
    # Verify model is on GPU
    if self.device.type == 'cuda':
        print(f"Model on GPU: {next(model.parameters()).device}")
        print(f"GPU memory before training: {torch.cuda.memory_allocated() / 1e9:.2f} GB")

    # Fewer source images -> more augmented repetitions per image
    num_content_images = len(list(Path(content_dir).glob('*')))
    if num_content_images < 5:
        augment_factor = 20
    elif num_content_images < 10:
        augment_factor = 10
    elif num_content_images < 20:
        augment_factor = 5
    else:
        augment_factor = 1

    # Create dataset with augmentation
    if num_content_images < 10:
        transform = transforms.Compose([
            transforms.RandomResizedCrop(256, scale=(0.7, 1.2)),
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(15),
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        print(f"Using heavy augmentation due to limited images ({num_content_images} provided)")
    else:
        transform = transforms.Compose([
            transforms.Resize(286),
            transforms.RandomCrop(256),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
    dataset = StyleTransferDataset(content_dir, transform=transform, augment_factor=augment_factor)
    if len(dataset) == 0:
        # Fail with a clear message instead of an obscure DataLoader error
        raise ValueError(f"No training images found in {content_dir}")

    requested_batch = int(batch_size)
    print(f"Training configuration:")
    print(f" - Original images: {num_content_images}")
    print(f" - Augmentation factor: {augment_factor}x")
    print(f" - Total training samples: {len(dataset)}")
    print(f" - Residual blocks: {n_residual_blocks}")
    print(f" - Batch size: {requested_batch}")
    print(f" - Epochs: {epochs}")

    # Adjust batch size for small datasets
    actual_batch_size = min(requested_batch, len(dataset))
    if num_content_images == 1:
        if n_residual_blocks >= 9 and requested_batch > 1:
            actual_batch_size = 1
            print(f"Reduced batch size to 1 for single image + {n_residual_blocks} blocks")
        elif requested_batch > 2:
            actual_batch_size = 2
            print(f"Reduced batch size to 2 for single image training")
    dataloader = DataLoader(dataset, batch_size=actual_batch_size, shuffle=True,
                            num_workers=0 if num_content_images < 10 else 2)

    # Prepare style image
    style_transform = transforms.Compose([
        transforms.Resize(800),
        transforms.CenterCrop(768)
    ])
    style_pil = style_transform(style_image)
    style_tensor = self.vgg_transform(style_pil).unsqueeze(0).to(self.device)

    # Create VGG features extractor for loss
    vgg_features = SimpleVGGFeatures().to(self.device).eval()
    print(f"VGG features on device: {next(vgg_features.parameters()).device}")
    perceptual_loss = PerceptualLoss(vgg_features)

    # Everything below is loop-invariant, so build it once.
    target_size = (256, 256)
    vgg_mean = torch.tensor([0.485, 0.456, 0.406], device=self.device).view(1, 3, 1, 1)
    vgg_std = torch.tensor([0.229, 0.224, 0.225], device=self.device).view(1, 3, 1, 1)
    style_for_loss = style_tensor
    if style_for_loss.shape[2:] != target_size:
        style_for_loss = F.interpolate(style_for_loss, size=target_size,
                                       mode='bilinear', align_corners=False)

    def to_vgg_space(batch):
        """Map a [-1, 1] image batch to VGG-normalised tensors at target_size."""
        batch = batch * 0.5 + 0.5  # denormalize from [-1, 1] to [0, 1]
        if batch.shape[2:] != target_size:
            batch = F.interpolate(batch, size=target_size, mode='bilinear', align_corners=False)
        return (batch - vgg_mean) / vgg_std

    # Guard against a zero interval, which would raise ZeroDivisionError below
    save_every = max(1, int(save_interval))

    # Training loop
    model.train()
    total_steps = 0
    for epoch in range(epochs):
        epoch_loss = 0
        for batch_idx, content_batch in enumerate(dataloader):
            content_batch = content_batch.to(self.device)
            # Forward pass
            output = model(content_batch)
            output_vgg = to_vgg_space(output)
            content_vgg = to_vgg_space(content_batch)
            # Repeat the single style image across the batch (no copy)
            style_vgg = style_for_loss.expand(output_vgg.size(0), -1, -1, -1)
            # Calculate loss
            loss, content_loss, style_loss = perceptual_loss(
                output_vgg, content_vgg, style_vgg,
                content_weight=content_weight, style_weight=style_weight
            )
            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            total_steps += 1
            # Progress callback
            if progress_callback and total_steps % 10 == 0:
                progress = (epoch + (batch_idx + 1) / len(dataloader)) / epochs
                # "original count -> augmented sample count"
                aug_info = f" (aug {num_content_images}->{len(dataset)})" if num_content_images < 20 else ""
                blocks_info = f", {n_residual_blocks} blocks"
                progress_callback(progress, f"Epoch {epoch+1}/{epochs}{aug_info}{blocks_info}, Loss: {loss.item():.4f}")
        # Save checkpoint
        if (epoch + 1) % save_every == 0:
            checkpoint_path = f'{self.models_dir}/{model_name}_epoch_{epoch+1}.pth'
            torch.save({
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': epoch_loss / len(dataloader),
                'n_residual_blocks': n_residual_blocks
            }, checkpoint_path)
            print(f"Saved checkpoint: {checkpoint_path}")

    # Save final model
    final_path = f'{self.models_dir}/{model_name}_final.pth'
    torch.save({
        'model_state_dict': model.state_dict(),
        'n_residual_blocks': n_residual_blocks
    }, final_path)
    print(f"Training complete! Model saved to: {final_path}")

    # Hand the model back in inference mode, like load_lightweight_model does
    model.eval()
    # Add to lightweight models
    self.lightweight_models[model_name] = model
    return model
def load_lightweight_model(self, model_path):
    """Load a trained lightweight model from a checkpoint file.

    The checkpoint may be either a raw state dict or a dict holding
    ``'model_state_dict'`` (and optionally ``'n_residual_blocks'``). When the
    block count is not stored it is inferred from the ``res_blocks.*`` weight
    keys, defaulting to 5. If building with that architecture fails, a
    second attempt is made with 5 residual blocks.

    Args:
        model_path: Path of the ``.pth`` checkpoint.

    Returns:
        The model in eval mode on ``self.device``, or ``None`` if the
        checkpoint could not be read or matched to an architecture.
    """
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from trusted sources.
    try:
        checkpoint = torch.load(model_path, map_location=self.device)
    except Exception as e:
        print(f"Error loading lightweight model: {e}")
        return None

    wrapped = isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint
    model_state = checkpoint['model_state_dict'] if wrapped else checkpoint

    def build(n_blocks):
        """Instantiate the network with n_blocks and load the saved weights."""
        net = LightweightStyleNet(n_residual_blocks=n_blocks).to(self.device)
        net.load_state_dict(model_state)
        net.eval()
        return net

    try:
        if isinstance(checkpoint, dict) and 'n_residual_blocks' in checkpoint:
            n_blocks = checkpoint['n_residual_blocks']
            print(f"Found saved architecture: {n_blocks} residual blocks")
        else:
            # Keys look like "res_blocks.<index>....weight"; count distinct indices
            block_ids = {
                key.split('.')[1]
                for key in model_state.keys()
                if key.startswith('res_blocks') and 'weight' in key
            }
            n_blocks = len(block_ids) or 5
            print(f"Detected {n_blocks} residual blocks from model structure")
        model = build(n_blocks)
        # Verify model is on correct device
        print(f"Lightweight model loaded on: {next(model.parameters()).device}")
        return model
    except Exception as e:
        print(f"Error loading lightweight model: {e}")

    # Try with default 5 blocks
    try:
        print("Attempting to load with default 5 residual blocks...")
        model = build(5)
        print(f"Fallback model loaded on: {next(model.parameters()).device}")
        return model
    except Exception as e:
        # Report the reason instead of swallowing it silently
        print(f"Fallback load failed: {e}")
        return None
# Tile-blending weight helpers (methods of the StyleTransferSystem class)
def _create_linear_weight(self, width, height, overlap):
    """Create linear blending weights for tile edges.

    The weight is 1 in the tile interior and ramps down linearly towards
    each edge over ``overlap`` pixels. The ramp never reaches 0, so a pixel
    covered by only one tile (e.g. on the image border) still gets a usable
    weight. The ramp width is clamped to half the tile so opposite ramps do
    not overlap and oversized ``overlap`` values cannot index out of range.

    Args:
        width: Tile width in pixels.
        height: Tile height in pixels.
        overlap: Width of the blending band in pixels; ``<= 0`` disables it.

    Returns:
        float32 array of shape ``(height, width, 1)`` with values in (0, 1].
    """
    weight = np.ones((height, width, 1), dtype=np.float32)
    if overlap <= 0:
        return weight

    def edge_profile(length):
        """1-D profile: ramp up, flat at 1, ramp down."""
        band = min(int(overlap), length // 2)
        profile = np.ones(length, dtype=np.float32)
        if band > 0:
            # Steps 1/(band+1) .. band/(band+1): strictly positive, below 1
            ramp = np.arange(1, band + 1, dtype=np.float32) / (band + 1)
            profile[:band] = ramp
            profile[length - band:] = ramp[::-1]
        return profile

    # Rows and columns multiply, so corners are attenuated by both ramps
    weight *= edge_profile(height)[:, None, None]
    weight *= edge_profile(width)[None, :, None]
    return weight
def _create_gaussian_weight(self, width, height, overlap):
    """Create Gaussian blending weights for smoother tile transitions.

    Pixels at least ``overlap`` pixels away from every tile edge get weight
    1. Closer pixels fall off as a Gaussian of their depth into the band,
    with sigma ``overlap / 3``, so the outermost pixel has weight
    ``exp(-4.5)``: small but non-zero.

    Args:
        width: Tile width in pixels.
        height: Tile height in pixels.
        overlap: Width of the blending band in pixels; ``<= 0`` yields
            uniform weights.

    Returns:
        Float array of shape ``(height, width, 1)``.
    """
    if overlap <= 0:
        # No blending band; also avoids dividing by a zero sigma below
        return np.ones((height, width, 1), dtype=np.float64)

    rows, cols = np.ogrid[:height, :width]
    # Distance (in pixels) to the nearest tile edge
    edge_dist = np.minimum(
        np.minimum(rows, height - 1 - rows),
        np.minimum(cols, width - 1 - cols)
    )
    sigma = overlap / 3
    falloff = np.exp(-0.5 * ((overlap - edge_dist) / sigma) ** 2)
    weight = np.where(edge_dist < overlap, falloff, 1.0)
    return weight.reshape(height, width, 1)
def apply_lightweight_style(self, image, model, intensity=1.0):
    """Apply style using a lightweight model.

    The image is processed at 256x256 and the result is resized back to the
    input size. The whole frame is scaled (not centre-cropped) so the styled
    output stays pixel-aligned with the original, which matters when the
    two are blended for ``intensity < 1``.

    Args:
        image: PIL image to stylise; non-RGB images are converted to RGB.
        model: Lightweight style network taking a single image batch.
        intensity: Blend factor; below 1.0 the output is mixed with the
            original as ``original * (1 - intensity) + styled * intensity``.

    Returns:
        The stylised PIL image, or ``None`` if an input is missing or
        processing fails.
    """
    if image is None or model is None:
        return None
    try:
        # Ensure model is on the correct device
        model = model.to(self.device)
        model.eval()
        if image.mode != 'RGB':
            # Normalisation and blending below assume three channels
            image = image.convert('RGB')
        original_size = image.size
        transform = transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        img_tensor = transform(image).unsqueeze(0).to(self.device)
        with torch.no_grad():
            if self.device.type == 'cuda':
                torch.cuda.synchronize()
            output = model(img_tensor)
            if self.device.type == 'cuda':
                torch.cuda.synchronize()
        output_img = self.inverse_transform(output.squeeze(0).cpu())
        output_img = output_img.resize(original_size, Image.LANCZOS)
        if intensity < 1.0:
            output_array = np.array(output_img, dtype=np.float32)
            original_array = np.array(image, dtype=np.float32)
            blended = original_array * (1 - intensity) + output_array * intensity
            # Clip so out-of-range intensities cannot wrap around in uint8
            output_img = Image.fromarray(np.clip(blended, 0, 255).astype(np.uint8))
        return output_img
    except Exception as e:
        print(f"Error applying lightweight style: {e}")
        print(f"Device: {self.device}")
        print(f"Model device: {next(model.parameters()).device}")
        return None
def blend_styles(self, image, style_configs, blend_mode="additive"):
    """Run several style models over an image and merge their outputs.

    Args:
        image: Source image.
        style_configs: Iterable of ``(style_type, model_key, intensity)``
            tuples. ``style_type`` is ``'cyclegan'`` or ``'lightweight'``;
            entries with ``intensity <= 0``, an unknown type, or an
            unregistered lightweight key are skipped.
        blend_mode: ``"average"``, ``"additive"``, ``"maximum"`` or
            ``"overlay"``; any other value is treated as screen blending.

    Returns:
        The blended PIL image, or ``image`` unchanged when there is nothing
        to apply.
    """
    if not (image and style_configs):
        return image

    base = np.array(image, dtype=np.float32)

    # Collect (styled array, weight) pairs; each style is rendered at full strength
    layers = []
    for kind, key, strength in style_configs:
        if strength <= 0:
            continue
        if kind == 'cyclegan':
            rendered = self.apply_cyclegan_style(image, key, 1.0)
        elif kind == 'lightweight' and key in self.lightweight_models:
            rendered = self.apply_lightweight_style(image, self.lightweight_models[key], 1.0)
        else:
            continue
        if rendered:
            layers.append((np.array(rendered, dtype=np.float32), strength))

    if not layers:
        return image

    if blend_mode == "average":
        # Weighted mean of the styled images
        weight_sum = sum(strength for _, strength in layers)
        canvas = np.zeros_like(base)
        for layer, strength in layers:
            canvas += layer * (strength / weight_sum)
    elif blend_mode == "additive":
        # Sum every style's weighted change on top of the original
        canvas = base.copy()
        for layer, strength in layers:
            canvas = canvas + (layer - base) * strength
    elif blend_mode == "maximum":
        # Per element, keep whichever style moves furthest from the original
        canvas = base.copy()
        for layer, strength in layers:
            delta = (layer - base) * strength
            stronger = np.abs(delta) > np.abs(canvas - base)
            canvas[stronger] = base[stronger] + delta[stronger]
    elif blend_mode == "overlay":
        canvas = base.copy()
        for layer, strength in layers:
            dark = canvas < 128
            bright = ~dark
            mixed = np.zeros_like(canvas)
            mixed[dark] = 2 * layer[dark] * canvas[dark] / 255.0
            mixed[bright] = 255 - 2 * (255 - layer[bright]) * (255 - canvas[bright]) / 255.0
            canvas = canvas * (1 - strength) + mixed * strength
    else:  # "screen" mode
        canvas = base.copy()
        for layer, strength in layers:
            screened = 255 - ((255 - canvas) * (255 - layer) / 255.0)
            if strength > 1.0:
                # Weights above 1 extrapolate past the screened result
                canvas = canvas + (screened - canvas) * strength
            else:
                canvas = canvas * (1 - strength) + screened * strength

    return Image.fromarray(np.clip(canvas, 0, 255).astype(np.uint8))
def apply_regional_styles(self, image, combined_mask, regions, base_style_configs=None, blend_mode="additive"):
    """Apply different styles to painted regions using a combined mask.

    Args:
        image: Source PIL image.
        combined_mask: Integer label array in which region ``i`` (0-based
            position in ``regions``) is marked with value ``i + 1``; resized
            with nearest-neighbour sampling if it does not match the image.
            May be ``None``, in which case only the base style is applied.
        regions: Sequence of dicts with ``'style'`` (CycleGAN display name
            or ``None``) and ``'intensity'``.
        base_style_configs: Optional style configs for ``blend_styles``
            applied to the whole image before the regions are composited.
        blend_mode: Blend mode forwarded to ``blend_styles``.

    Returns:
        The composited PIL image (or ``image`` itself when there are neither
        regions nor base styles).
    """
    if not regions:
        if base_style_configs:
            return self.blend_styles(image, base_style_configs, blend_mode)
        return image

    original_size = image.size
    # Start from the base-styled image when a base style is given
    if base_style_configs:
        base_styled = self.blend_styles(image, base_style_configs, blend_mode)
        result = np.array(base_styled, dtype=np.float32)
    else:
        result = np.array(image, dtype=np.float32)

    if combined_mask is None:
        # Without a mask no region can be composited, so skip the style passes
        return Image.fromarray(np.clip(result, 0, 255).astype(np.uint8))

    # Resize mask to match original image if needed
    if combined_mask.shape[:2] != (original_size[1], original_size[0]):
        combined_mask_pil = Image.fromarray(combined_mask.astype(np.uint8))
        # NEAREST keeps label values intact (no interpolated labels)
        combined_mask = np.array(combined_mask_pil.resize(original_size, Image.NEAREST))

    # Display name -> model key, built once; first match wins as before
    name_to_key = {}
    for key, info in self.cyclegan_models.items():
        name_to_key.setdefault(info['name'], key)

    # Apply each region
    for i, region in enumerate(regions):
        style_name = region['style']
        if style_name is None:
            continue
        model_key = name_to_key.get(style_name)
        if not model_key:
            continue
        # Region masks are identified by their 1-based index in the label map
        region_mask = (combined_mask == (i + 1)).astype(np.float32)
        if not region_mask.any():
            # Nothing painted for this region: skip the expensive style pass
            continue
        style_configs = [('cyclegan', model_key, region['intensity'])]
        styled = self.blend_styles(image, style_configs, blend_mode)
        styled_array = np.array(styled, dtype=np.float32)
        if region_mask.ndim == 2:
            # Add a channel axis so the mask broadcasts over the colour channels
            region_mask = region_mask[:, :, None]
        result = result * (1 - region_mask) + styled_array * region_mask

    return Image.fromarray(np.clip(result, 0, 255).astype(np.uint8))
def train_adain_model(self, style_images, content_dir, model_name,
                      epochs=30, batch_size=4, lr=1e-4,
                      save_interval=5, style_weight=10.0, content_weight=1.0,
                      progress_callback=None):
    """Train an AdaIN-based style transfer model.

    Only the decoder is optimised; the encoder is kept in eval mode. The
    loss combines a content term on the deepest VGG slice with Gram-matrix
    style terms on four VGG slices.

    Args:
        style_images: Non-empty sequence of style images; each is expanded
            into 5 randomly cropped/flipped 512x512 samples.
        content_dir: Directory of content images for ``StyleTransferDataset``.
        model_name: Base name for checkpoint files and the key under which
            the model is stored in ``self.lightweight_models``.
        epochs: Number of passes over the content dataset.
        batch_size: Content batch size.
        lr: Adam learning rate for the decoder.
        save_interval: A checkpoint is written every this many epochs.
        style_weight: Style loss weight (multiplied by 10 internally).
        content_weight: Content loss weight.
        progress_callback: Optional ``callback(progress, message)`` invoked
            every 10 optimisation steps; ``progress`` is a 0..1 fraction.

    Returns:
        The trained model, switched to eval mode.

    Raises:
        ValueError: If ``style_images`` is empty.
    """
    if not style_images:
        # Fail early; otherwise sampling a style index fails mid-training
        raise ValueError("train_adain_model requires at least one style image")
    # Normalise numeric UI inputs the same way train_lightweight_model does
    batch_size = int(batch_size)
    save_every = max(1, int(save_interval))

    model = AdaINStyleTransfer().to(self.device)
    optimizer = torch.optim.Adam(model.decoder.parameters(), lr=lr)
    # Add learning rate scheduler
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.8)
    print(f"Training AdaIN model")
    print(f"Training device: {self.device}")
    # Verify model is on GPU
    if self.device.type == 'cuda':
        print(f"Model on GPU: {next(model.decoder.parameters()).device}")
        print(f"GPU memory before training: {torch.cuda.memory_allocated() / 1e9:.2f} GB")

    # Prepare style images
    style_transform = transforms.Compose([
        transforms.Resize(600),
        transforms.RandomCrop(512),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
    # Generate 5 augmented versions per style image
    style_tensors = [
        style_transform(style_img).unsqueeze(0).to(self.device)
        for style_img in style_images
        for _ in range(5)
    ]
    print(f"Created {len(style_tensors)} augmented style samples from {len(style_images)} images")

    # Prepare content dataset
    content_transform = transforms.Compose([
        transforms.Resize(600),
        transforms.RandomCrop(512),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.05),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
    dataset = StyleTransferDataset(content_dir, transform=content_transform)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0)
    print(f"Training configuration:")
    print(f" - Style images: {len(style_tensors)}")
    print(f" - Content images: {len(dataset)}")
    print(f" - Batch size: {batch_size}")
    print(f" - Epochs: {epochs}")
    print(f" - Training resolution: 512x512")

    # Loss network (VGG for perceptual loss) - uses multiple layers
    class MultiLayerVGG(nn.Module):
        """Frozen VGG19 feature extractor returning four intermediate maps."""

        def __init__(self):
            super().__init__()
            vgg = models.vgg19(weights=models.VGG19_Weights.DEFAULT).features
            layers = list(vgg.children())
            self.slice1 = nn.Sequential(*layers[:2])     # relu1_1
            self.slice2 = nn.Sequential(*layers[2:7])    # relu2_1
            self.slice3 = nn.Sequential(*layers[7:12])   # relu3_1
            self.slice4 = nn.Sequential(*layers[12:21])  # relu4_1
            for param in self.parameters():
                param.requires_grad = False

        def forward(self, x):
            h1 = self.slice1(x)
            h2 = self.slice2(h1)
            h3 = self.slice3(h2)
            h4 = self.slice4(h3)
            return [h1, h2, h3, h4]

    loss_network = MultiLayerVGG().to(self.device).eval()
    mse_loss = nn.MSELoss()

    def gram_matrix(feat):
        """Batch Gram matrix normalised by channels * height * width."""
        b, c, h, w = feat.size()
        feat = feat.view(b, c, h * w)
        gram = torch.bmm(feat, feat.transpose(1, 2))
        return gram / (c * h * w)

    # Give more weight to higher layers
    style_weights = [0.2, 0.3, 0.5, 1.0]
    # Multiply by 10 for better style transfer
    actual_style_weight = style_weight * 10

    # Training loop
    model.train()
    model.encoder.eval()  # Keep encoder frozen
    total_steps = 0
    for epoch in range(epochs):
        epoch_loss = 0
        for batch_idx, content_batch in enumerate(dataloader):
            content_batch = content_batch.to(self.device)
            # Randomly select one style sample per content image
            style_indices = [
                np.random.randint(0, len(style_tensors))
                for _ in range(content_batch.size(0))
            ]
            batch_style = torch.cat([style_tensors[idx] for idx in style_indices], dim=0)
            # Forward pass
            output = model(content_batch, batch_style)
            # Targets need no gradients; the output features do
            with torch.no_grad():
                content_feats = loss_network(content_batch)
                style_feats = loss_network(batch_style)
            output_feats = loss_network(output)
            # Content loss - only from relu4_1
            content_loss = mse_loss(output_feats[-1], content_feats[-1])
            # Style loss - from multiple layers
            style_loss = 0
            for output_feat, style_feat, weight in zip(output_feats, style_feats, style_weights):
                style_loss += weight * mse_loss(gram_matrix(output_feat), gram_matrix(style_feat))
            style_loss /= len(style_weights)
            # Total loss
            loss = content_weight * content_loss + actual_style_weight * style_loss
            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            # Gradient clipping for stability
            torch.nn.utils.clip_grad_norm_(model.decoder.parameters(), max_norm=5.0)
            optimizer.step()
            epoch_loss += loss.item()
            total_steps += 1
            # Progress callback
            if progress_callback and total_steps % 10 == 0:
                progress = (epoch + (batch_idx + 1) / len(dataloader)) / epochs
                progress_callback(progress,
                                  f"Epoch {epoch+1}/{epochs}, Loss: {loss.item():.4f} "
                                  f"(Content: {content_loss.item():.4f}, Style: {style_loss.item():.4f})")
        # Step scheduler
        scheduler.step()
        # Save checkpoint
        if (epoch + 1) % save_every == 0:
            checkpoint_path = f'{self.models_dir}/{model_name}_epoch_{epoch+1}.pth'
            torch.save({
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'loss': epoch_loss / len(dataloader),
                'model_type': 'adain'
            }, checkpoint_path)
            print(f"Saved checkpoint: {checkpoint_path}")

    # Save final model
    final_path = f'{self.models_dir}/{model_name}_final.pth'
    torch.save({
        'model_state_dict': model.state_dict(),
        'model_type': 'adain'
    }, final_path)
    print(f"Training complete! Model saved to: {final_path}")

    # Hand the model back in inference mode
    model.eval()
    # Add to lightweight models
    self.lightweight_models[model_name] = model
    return model
# AdaIN inference methods of the StyleTransferSystem class (single-pass and tiled)
def apply_adain_style(self, content_image, style_image, model, alpha=1.0, use_tiling=False):
    """Apply AdaIN-based style transfer with optional tiling.

    Args:
        content_image: PIL image to stylise.
        style_image: PIL image providing the style.
        model: AdaIN model called as ``model(content, style, alpha=alpha)``.
        alpha: Style strength forwarded to the model.
        use_tiling: When true and either side of the content image exceeds
            768 px, processing is delegated to ``apply_adain_style_tiled``.

    Returns:
        The stylised PIL image at the content image's size, or ``None`` if
        an input is missing or processing fails.
    """
    # Validate before touching any attribute of the inputs
    if content_image is None or style_image is None or model is None:
        return None
    # Use tiling for large images to maintain quality
    if use_tiling and (content_image.width > 768 or content_image.height > 768):
        return self.apply_adain_style_tiled(
            content_image, style_image, model, alpha,
            tile_size=512,
            overlap=64,
            blend_mode='gaussian'
        )
    try:
        model = model.to(self.device)
        model.eval()
        # The VGG normalisation below assumes three channels
        if content_image.mode != 'RGB':
            content_image = content_image.convert('RGB')
        if style_image.mode != 'RGB':
            style_image = style_image.convert('RGB')
        original_size = content_image.size
        # Cap the longer side at max_dim while maintaining aspect ratio
        max_dim = 768
        w, h = content_image.size
        if w > h:
            new_w = min(w, max_dim)
            new_h = int(h * new_w / w)
        else:
            new_h = min(h, max_dim)
            new_w = int(w * new_h / h)
        # Round down to a multiple of 8, but never below 8 so that tiny or
        # extremely elongated images do not produce a zero-sized resize
        new_w = max(8, (new_w // 8) * 8)
        new_h = max(8, (new_h // 8) * 8)
        # Transform for AdaIN (VGG normalization)
        transform = transforms.Compose([
            transforms.Resize((new_h, new_w)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])
        content_tensor = transform(content_image).unsqueeze(0).to(self.device)
        style_tensor = transform(style_image).unsqueeze(0).to(self.device)
        with torch.no_grad():
            output = model(content_tensor, style_tensor, alpha=alpha)
        # Denormalize
        output = output.squeeze(0).cpu()
        output = output * torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
        output = output + torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
        output = torch.clamp(output, 0, 1)
        # Convert to PIL and restore the input resolution
        output_img = transforms.ToPILImage()(output)
        return output_img.resize(original_size, Image.LANCZOS)
    except Exception as e:
        print(f"Error applying AdaIN style: {e}")
        traceback.print_exc()
        return None
def apply_adain_style_tiled(self, content_image, style_image, model, alpha=1.0,
                            tile_size=256, overlap=32, blend_mode='linear'):
    """
    Apply AdaIN style transfer using tiling for high-quality results.

    The image is processed in overlapping tiles which are accumulated with
    Gaussian blending weights and normalised by the summed weights.

    NOTE: for backward compatibility the ``tile_size`` and ``overlap``
    arguments are overridden with 512 and 64, and Gaussian blending is used
    regardless of ``blend_mode``.

    Args:
        content_image: PIL image to stylise.
        style_image: PIL image providing the style (resized to tile size).
        model: AdaIN model called as ``model(content, style, alpha=alpha)``.
        alpha: Style strength forwarded to the model.
        tile_size: Currently ignored (see note).
        overlap: Currently ignored (see note).
        blend_mode: Currently ignored (see note).

    Returns:
        The stylised PIL image at the content image's size, or the result of
        the non-tiled path when the image fits in one tile or tiling fails;
        ``None`` if an input is missing.
    """
    if content_image is None or style_image is None or model is None:
        return None
    try:
        model = model.to(self.device)
        model.eval()
        tile_size = 512  # Override input to use 512
        overlap = 64     # Increase overlap proportionally

        w, h = content_image.size
        # If the image fits in a single tile, just process normally
        if w <= tile_size and h <= tile_size:
            return self.apply_adain_style(content_image, style_image, model, alpha, use_tiling=False)

        # The VGG normalisation below assumes three channels
        if content_image.mode != 'RGB':
            content_image = content_image.convert('RGB')
        if style_image.mode != 'RGB':
            style_image = style_image.convert('RGB')

        transform = transforms.Compose([
            transforms.Resize((tile_size, tile_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])
        # Process style image once (at tile size)
        style_tensor = transform(style_image).unsqueeze(0).to(self.device)

        stride = tile_size - overlap

        def tile_origins(length):
            """Tile start offsets along one axis, covering the full length."""
            origins = list(range(0, length - tile_size + 1, stride))
            if not origins or origins[-1] + tile_size < length:
                origins.append(max(0, length - tile_size))
            return origins

        tiles_x = tile_origins(w)
        tiles_y = tile_origins(h)
        print(f"Processing {len(tiles_x) * len(tiles_y)} tiles of size {tile_size}x{tile_size}")

        # Initialize output and weight arrays
        output_array = np.zeros((h, w, 3), dtype=np.float32)
        weight_array = np.zeros((h, w, 1), dtype=np.float32)

        # Loop-invariant helpers: blending mask and denormalisation constants
        weight = self._create_gaussian_weight(tile_size, tile_size, overlap)
        denorm_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
        denorm_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

        with torch.no_grad():
            for y in tiles_y:
                for x in tiles_x:
                    # A crop box that extends past the image yields a padded tile
                    tile = content_image.crop((x, y, x + tile_size, y + tile_size))
                    tile_tensor = transform(tile).unsqueeze(0).to(self.device)
                    styled_tensor = model(tile_tensor, style_tensor, alpha=alpha)
                    # Denormalize
                    styled_tensor = styled_tensor.squeeze(0).cpu() * denorm_std + denorm_mean
                    styled_tensor = torch.clamp(styled_tensor, 0, 1)
                    styled_tile = styled_tensor.permute(1, 2, 0).numpy() * 255
                    # Only the part of the tile inside the image is accumulated;
                    # this matters when one image side is shorter than a tile
                    valid_h = min(tile_size, h - y)
                    valid_w = min(tile_size, w - x)
                    weighted_tile = styled_tile * weight
                    output_array[y:y + valid_h, x:x + valid_w] += weighted_tile[:valid_h, :valid_w]
                    weight_array[y:y + valid_h, x:x + valid_w] += weight[:valid_h, :valid_w]

        # Normalize by the accumulated weights
        output_array = output_array / (weight_array + 1e-8)
        output_array = np.clip(output_array, 0, 255).astype(np.uint8)
        return Image.fromarray(output_array)
    except Exception as e:
        print(f"Error in tiled AdaIN processing: {e}")
        traceback.print_exc()
        # Fallback to standard processing
        return self.apply_adain_style(content_image, style_image, model, alpha, use_tiling=False)
# ===========================
# HELPER FUNCTIONS
# ===========================
def resize_image_for_display(image, max_width=800, max_height=600):
    """Resize image for display while maintaining aspect ratio.

    Images are only ever scaled down. Images that already fit, or that have
    a zero-sized side, are returned unchanged.

    Args:
        image: Image exposing ``size`` as ``(width, height)`` and ``resize``.
        max_width: Maximum display width in pixels.
        max_height: Maximum display height in pixels.

    Returns:
        The original image, or a LANCZOS-resampled copy that fits the bounds.
    """
    width, height = image.size
    if width <= 0 or height <= 0:
        # Nothing to scale, and the ratios below would divide by zero
        return image
    scale = min(max_width / width, max_height / height)
    # Only scale down, not up
    if scale >= 1:
        return image
    # Keep both sides at least 1 px for extreme aspect ratios
    new_size = (max(1, int(width * scale)), max(1, int(height * scale)))
    return image.resize(new_size, Image.LANCZOS)
def combine_region_masks(canvas_results, canvas_size):
    """Merge per-region canvas drawings into one label map.

    Each canvas result's alpha channel marks the painted pixels of one
    region. Painted pixels receive the region's 1-based position in
    ``canvas_results``; later regions overwrite earlier ones and unpainted
    pixels stay 0. Entries that are ``None`` or carry no ``image_data`` are
    skipped but still consume their index.

    Args:
        canvas_results: Sequence of canvas results exposing ``image_data``
            as an array whose channel 3 is read as alpha.
        canvas_size: Shape whose first two entries give the label map size.

    Returns:
        uint8 array of shape ``canvas_size[:2]``.
    """
    label_map = np.zeros(canvas_size[:2], dtype=np.uint8)
    for region_id, canvas in enumerate(canvas_results, start=1):
        pixels = getattr(canvas, 'image_data', None)
        if pixels is None:
            continue
        # Any non-zero alpha counts as painted
        painted = pixels[:, :, 3] > 0
        label_map[painted] = region_id
    return label_map
def apply_adain_regional(content_image, style_image, model, canvas_result, alpha=1.0, feather_radius=10, use_tiling=False):
    """Apply AdaIN style transfer to a painted region only.

    The whole image is stylised via the module-level ``system`` and then
    blended with the original using the canvas alpha channel as a mask.

    Args:
        content_image: PIL image to stylise; non-RGB images are converted.
        style_image: PIL image providing the style.
        model: AdaIN model forwarded to ``system.apply_adain_style``.
        canvas_result: Canvas result whose ``image_data`` alpha channel
            marks the painted region; ``None`` (or missing image data)
            stylises the whole image.
        alpha: Style strength forwarded to the model.
        feather_radius: Sigma of the Gaussian blur used to soften the mask
            edge; ``0`` keeps a hard edge.
        use_tiling: Forwarded to ``system.apply_adain_style``.

    Returns:
        The blended PIL image, or ``None`` if an input is missing or
        processing fails.
    """
    if content_image is None or style_image is None or model is None:
        return None
    try:
        # The 3-channel blend below requires an RGB original
        if content_image.mode != 'RGB':
            content_image = content_image.convert('RGB')

        if canvas_result is None or canvas_result.image_data is None:
            # No mask painted, apply to whole image
            return system.apply_adain_style(content_image, style_image, model, alpha, use_tiling=use_tiling)

        # Painted pixels are those with non-zero alpha
        canvas_pixels = canvas_result.image_data
        mask = canvas_pixels[:, :, 3] > 0
        if not mask.any():
            # Empty mask: the blend would reproduce the original, so skip
            # the full style pass
            return content_image.copy()

        # Resize mask to match original image size
        original_size = content_image.size
        display_size = (canvas_pixels.shape[1], canvas_pixels.shape[0])
        if original_size != display_size:
            mask_pil = Image.fromarray((mask * 255).astype(np.uint8), mode='L')
            mask_pil = mask_pil.resize(original_size, Image.NEAREST)
            mask = np.array(mask_pil) > 128

        # Apply feathering to mask edges if requested
        mask_float = mask.astype(np.float32)
        if feather_radius > 0:
            from scipy.ndimage import gaussian_filter
            mask_float = np.clip(gaussian_filter(mask_float, sigma=feather_radius), 0, 1)

        # Apply style to entire image with tiling option
        styled_full = system.apply_adain_style(content_image, style_image, model, alpha, use_tiling=use_tiling)
        if styled_full is None:
            return None

        # Blend original and styled based on mask (broadcast over channels)
        original_array = np.array(content_image, dtype=np.float32)
        styled_array = np.array(styled_full, dtype=np.float32)
        mask_3ch = mask_float[:, :, None]
        result_array = original_array * (1 - mask_3ch) + styled_array * mask_3ch
        result_array = np.clip(result_array, 0, 255).astype(np.uint8)
        return Image.fromarray(result_array)
    except Exception as e:
        print(f"Error applying regional AdaIN style: {e}")
        traceback.print_exc()
        return None
# ===========================
# INITIALIZE SYSTEM AND API
# ===========================
@st.cache_resource
def load_system():
    """Construct the StyleTransferSystem (cached via st.cache_resource)."""
    style_system = StyleTransferSystem()
    return style_system
@st.cache_resource
def get_unsplash_api():
    """Construct the UnsplashAPI client (cached via st.cache_resource)."""
    api_client = UnsplashAPI()
    return api_client
# Shared singletons used by every tab below
unsplash = get_unsplash_api()
system = load_system()
# Sorted display names of all registered CycleGAN models (used by the style pickers)
style_choices = sorted(model_info['name'] for model_info in system.cyclegan_models.values())
# ===========================
# STREAMLIT APP
# ===========================
# Page heading and tagline
st.title("Style Transfer")
st.markdown("Image and video style transfer with CycleGAN and custom training capabilities")
# Sidebar for global settings: hardware status, a quick feature guide,
# Unsplash API status and an optional debug readout.
with st.sidebar:
st.header("Settings")
# GPU status
if torch.cuda.is_available():
# Properties of CUDA device 0; its name and total_memory are read below
gpu_info = torch.cuda.get_device_properties(0)
st.success(f"GPU: {gpu_info.name}")
# Show memory usage
# Both figures are divided by 1e9, i.e. shown in decimal gigabytes
total_memory = gpu_info.total_memory / 1e9
# "Used" is whatever torch.cuda.memory_allocated() reports, so "free" is
# total minus that figure rather than a driver-level measurement
used_memory = torch.cuda.memory_allocated() / 1e9
free_memory = total_memory - used_memory
col1, col2 = st.columns(2)
with col1:
st.metric("Total Memory", f"{total_memory:.2f} GB")
with col2:
st.metric("Used Memory", f"{used_memory:.2f} GB")
# Memory usage bar
memory_percentage = (used_memory / total_memory) * 100
# The percentage is scaled back to a 0-1 fraction for the progress bar
st.progress(memory_percentage / 100)
st.caption(f"Free: {free_memory:.2f} GB ({100-memory_percentage:.1f}%)")
# Check if GPU is actually being used
# Less than 0.1 GB allocated is treated as "nothing loaded on the GPU yet"
if used_memory < 0.1:
st.warning("GPU detected but not in use. Models may be running on CPU.")
# Force GPU button
# Empties the CUDA cache, synchronizes, then reruns the script
if st.button("Force GPU Reset"):
torch.cuda.empty_cache()
torch.cuda.synchronize()
st.rerun()
else:
st.warning("Running on CPU (GPU not available)")
st.caption("For faster processing, use a GPU-enabled environment")
st.markdown("---")
st.markdown("### Quick Guide")
st.markdown("""
- **Style Transfer**: Apply artistic styles to images
- **Regional Transform**: Paint areas for local effects
- **Video Processing**: Apply styles to videos
- **Train Custom**: Create your own style models
- **Batch Process**: Process multiple images
""")
# Unsplash API status
# A truthy unsplash.access_key is taken to mean the API is usable
st.markdown("---")
if unsplash.access_key:
st.success("Unsplash API Connected")
else:
st.info("Add Unsplash API key for image search")
# Debug mode
# When ticked, prints device / CUDA / PyTorch details and the number of
# entries in system.loaded_generators
st.markdown("---")
if st.checkbox("🐛 Debug Mode"):
st.code(f"""
Device: {device}
CUDA Available: {torch.cuda.is_available()}
CUDA Version: {torch.version.cuda if torch.cuda.is_available() else 'N/A'}
PyTorch Version: {torch.__version__}
Models Loaded: {len(system.loaded_generators)}
""")
# Main tabs: one label per feature area, unpacked in display order
tab_labels = [
    "Style Transfer",
    "Regional Transform",
    "Video Processing",
    "Train Custom Style",
    "Batch Processing",
    "Documentation",
]
tab1, tab2, tab3, tab4, tab5, tab6 = st.tabs(tab_labels)
# TAB 1: Style Transfer (with Unsplash integration)
# Layout: an Unsplash search expander on top, then two columns - input image
# and style controls on the left, styled result and download on the right.
# Session keys written here: 'unsplash_results', 'current_image',
# 'image_source', 'unsplash_photo', 'last_result', 'last_style_configs'.
with tab1:
    # Unsplash Search Section
    with st.expander("Search Unsplash for Images", expanded=False):
        if not unsplash.access_key:
            st.info("""
To enable Unsplash search:
1. Get a free API key from [Unsplash Developers](https://unsplash.com/developers)
2. Add it to your HuggingFace Space secrets as `UNSPLASH_ACCESS_KEY`
""")
        else:
            search_col1, search_col2, search_col3 = st.columns([3, 1, 1])
            with search_col1:
                search_query = st.text_input("Search for images", placeholder="e.g., landscape, portrait, abstract art")
            with search_col2:
                orientation = st.selectbox("Orientation", ["all", "landscape", "portrait", "squarish"])
            with search_col3:
                search_button = st.button("Search", use_container_width=True)
            # Random photos button
            if st.button("Get Random Photos"):
                with st.spinner("Loading random photos..."):
                    results, error = unsplash.get_random_photos(count=12)
                if error:
                    st.error(f"Error: {error}")
                elif results:
                    # Handle both single photo and array of photos
                    photos = results if isinstance(results, list) else [results]
                    st.session_state['unsplash_results'] = photos
                    st.success(f"Loaded {len(photos)} random photos")
            # Search functionality
            if search_button and search_query:
                with st.spinner(f"Searching for '{search_query}'..."):
                    orientation_param = None if orientation == "all" else orientation
                    results, error = unsplash.search_photos(search_query, per_page=12, orientation=orientation_param)
                if error:
                    st.error(f"Error: {error}")
                elif results and results.get('results'):
                    found = results['results']
                    st.session_state['unsplash_results'] = found
                    # Fall back to the page length if the payload carries no 'total'
                    st.success(f"Found {results.get('total', len(found))} images")
                else:
                    st.info("No images found. Try a different search term.")
            # Display results (kept in session state so the grid survives reruns)
            stored_photos = st.session_state.get('unsplash_results')
            if stored_photos:
                st.markdown("### Search Results")
                # Display in a 4-column grid
                cols = st.columns(4)
                for idx, photo in enumerate(stored_photos[:12]):
                    with cols[idx % 4]:
                        # Show thumbnail
                        st.image(photo['urls']['thumb'], use_column_width=True)
                        # Photo info
                        st.caption(f"By {photo['user']['name']}")
                        # Use button
                        if st.button("Use This", key=f"use_unsplash_{photo['id']}"):
                            with st.spinner("Loading image..."):
                                # Download regular size
                                img = unsplash.download_photo(photo['urls']['regular'])
                            if img:
                                # Store in session state
                                st.session_state['current_image'] = img
                                st.session_state['image_source'] = f"Unsplash: {photo['user']['name']}"
                                st.session_state['unsplash_photo'] = photo
                                # Trigger download tracking (required by Unsplash)
                                if 'links' in photo and 'download_location' in photo['links']:
                                    unsplash.trigger_download(photo['links']['download_location'])
                                st.success("Image loaded!")
                                st.rerun()
                            else:
                                # Previously a failed download was silently ignored
                                st.error("Could not download this image. Please try another one.")
    col1, col2 = st.columns(2)
    with col1:
        st.header("Input")
        # Image source selection
        image_source = st.radio("Image Source", ["Upload", "Unsplash"], horizontal=True)
        # Stays None until an upload or an Unsplash pick provides an image
        input_image = None
        if image_source == "Upload":
            uploaded_file = st.file_uploader("Choose an image", type=['png', 'jpg', 'jpeg'])
            if uploaded_file:
                input_image = Image.open(uploaded_file).convert('RGB')
                st.session_state['current_image'] = input_image
                st.session_state['image_source'] = "Uploaded"
        elif 'current_image' in st.session_state and st.session_state.get('image_source', '').startswith('Unsplash'):
            # Handle Unsplash selection
            input_image = st.session_state['current_image']
        else:
            st.info("Search for an image above")
        if input_image is not None:
            # Display the image
            display_img = resize_image_for_display(input_image, max_width=600, max_height=400)
            st.image(display_img, caption=st.session_state.get('image_source', 'Image'), use_column_width=True)
            # Attribution for Unsplash images
            if 'unsplash_photo' in st.session_state and st.session_state.get('image_source', '').startswith('Unsplash'):
                photo = st.session_state['unsplash_photo']
                st.markdown(f"Photo by [{photo['user']['name']}]({photo['user']['links']['html']}) on [Unsplash]({photo['links']['html']})")
        st.subheader("Style Configuration")
        # Up to 3 styles; each entry is ('cyclegan', model_key, intensity)
        num_styles = st.number_input("Number of styles to apply", 1, 3, 1)
        style_configs = []
        for i in range(num_styles):
            with st.expander(f"Style {i+1}", expanded=(i == 0)):
                style = st.selectbox("Select style", style_choices, key=f"style_{i}")
                intensity = st.slider("Intensity", 0.0, 2.0, 1.0, 0.1, key=f"intensity_{i}")
                if style and intensity > 0:
                    # First registered model whose display name matches the selection
                    model_key = next(
                        (key for key, info in system.cyclegan_models.items() if info['name'] == style),
                        None,
                    )
                    if model_key:
                        style_configs.append(('cyclegan', model_key, intensity))
        blend_mode = st.selectbox("Blend Mode",
                                  ["additive", "average", "maximum", "overlay", "screen"],
                                  index=0)
        if st.button("Apply Styles", type="primary", use_container_width=True):
            if input_image is None:
                # Guard: without this, blend_styles would be called with None
                st.warning("Please upload or select an image first")
            elif not style_configs:
                st.warning("Please select at least one style")
            else:
                with st.spinner("Applying styles..."):
                    progress_bar = st.progress(0)
                    status_text = st.empty()
                    # Indicative progress only: the actual work happens in the
                    # single blend_styles call below
                    for i, (_, key, intensity) in enumerate(style_configs):
                        model_name = system.cyclegan_models[key]['name']
                        progress_bar.progress((i + 1) / len(style_configs))
                        status_text.text(f"Applying {model_name}...")
                    result = system.blend_styles(input_image, style_configs, blend_mode)
                    progress_bar.empty()
                    status_text.empty()
                if result is None:
                    # Never store a missing result; the result column would crash on it
                    st.error("Style transfer failed. Please try a different image or style.")
                else:
                    st.session_state['last_result'] = result
                    st.session_state['last_style_configs'] = style_configs
    with col2:
        st.header("Result")
        last_result = st.session_state.get('last_result')
        if last_result is not None:
            st.image(last_result, caption="Styled Image", use_column_width=True)
            # Download button: the image is encoded to PNG in memory
            buf = io.BytesIO()
            last_result.save(buf, format='PNG')
            st.download_button(
                label="Download Result",
                data=buf.getvalue(),
                file_name=f"styled_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.png",
                mime="image/png"
            )
# TAB 2: Regional Transform
# The user defines regions (style + intensity), paints a mask for each one on
# a canvas, and applies the styles per region, optionally over a base style.
# Session keys used: 'regions', 'canvas_results' (region index -> canvas
# result), 'regional_image_original', 'canvas_ready', 'last_applied_regions',
# 'canvas_key_base' and 'regional_result'.
with tab2:
    st.header("Regional Style Transform")
    st.markdown("Paint different regions to apply different styles locally")
    # Initialize session state
    if 'regions' not in st.session_state:
        st.session_state.regions = []
    if 'canvas_results' not in st.session_state:
        st.session_state.canvas_results = {}
    if 'regional_image_original' not in st.session_state:
        st.session_state.regional_image_original = None
    if 'canvas_ready' not in st.session_state:
        st.session_state.canvas_ready = True
    if 'last_applied_regions' not in st.session_state:
        st.session_state.last_applied_regions = None
    if 'canvas_key_base' not in st.session_state:
        st.session_state.canvas_key_base = 0
    col1, col2 = st.columns([2, 3])
    # Defaults so the Apply handler in col2 can read these even when the
    # widgets that normally set them were not rendered
    use_base = False
    base_style = None
    base_intensity = 1.0
    regional_blend_mode = "additive"
    with col1:
        # Image source selection
        regional_image_source = st.radio("Image Source", ["Upload", "Unsplash"], horizontal=True, key="regional_image_source")
        if regional_image_source == "Upload":
            uploaded_regional = st.file_uploader("Choose an image", type=['png', 'jpg', 'jpeg'], key="regional_upload")
            if uploaded_regional:
                # Load and store original image
                regional_image_original = Image.open(uploaded_regional).convert('RGB')
                st.session_state.regional_image_original = regional_image_original
        elif 'current_image' in st.session_state and st.session_state.get('image_source', '').startswith('Unsplash'):
            # Use Unsplash image if available
            st.session_state.regional_image_original = st.session_state['current_image']
            st.success("Using Unsplash image")
        else:
            st.info("Please search and select an image from the Style Transfer tab first")
        if st.session_state.regional_image_original:
            # Display the original image
            display_img = resize_image_for_display(st.session_state.regional_image_original, max_width=400, max_height=300)
            st.image(display_img, caption="Original Image", use_column_width=True)
            st.subheader("Define Regions")
            # Base style (optional)
            with st.expander("Base Style (Optional)", expanded=False):
                use_base = st.checkbox("Apply base style to entire image")
                if use_base:
                    base_style = st.selectbox("Base style", style_choices, key="base_style")
                    base_intensity = st.slider("Base intensity", 0.0, 2.0, 1.0, key="base_intensity")
            # Region management
            col_btn1, col_btn2, col_btn3 = st.columns(3)
            with col_btn1:
                if st.button("Add Region", use_container_width=True):
                    region_count = len(st.session_state.regions)
                    st.session_state.regions.append({
                        'id': region_count,
                        # Stable identity for widget keys: unlike the list
                        # position it does not shift when a region is removed
                        'uid': uuid.uuid4().hex,
                        'style': style_choices[0] if style_choices else None,
                        'intensity': 1.0,
                        'color': f"hsla({region_count * 60}, 70%, 50%, 0.5)"
                    })
                    st.session_state.canvas_ready = True
                    st.rerun()
            with col_btn2:
                if st.button("Clear All", use_container_width=True):
                    st.session_state.regions = []
                    st.session_state.canvas_results = {}
                    if 'regional_result' in st.session_state:
                        del st.session_state['regional_result']
                    st.session_state.canvas_ready = True
                    st.session_state.canvas_key_base = 0
                    st.rerun()
            with col_btn3:
                if st.button("Reset Result", use_container_width=True):
                    if 'regional_result' in st.session_state:
                        del st.session_state['regional_result']
                    st.session_state.canvas_ready = True
                    st.rerun()
            # Configure each region
            for i, region in enumerate(st.session_state.regions):
                # Widgets are keyed by the region's uid. Keying by position let
                # a removed region's widget state overwrite the region that
                # moved into its slot. Regions without a uid fall back to i.
                widget_id = region.get('uid', i)
                with st.expander(f"Region {i+1} - {region.get('style', 'None')}", expanded=(i == len(st.session_state.regions) - 1)):
                    col_a, col_b = st.columns(2)
                    with col_a:
                        region['style'] = st.selectbox(
                            "Style",
                            style_choices,
                            key=f"region_style_{widget_id}",
                            index=style_choices.index(region['style']) if region['style'] in style_choices else 0
                        )
                    with col_b:
                        region['intensity'] = st.slider(
                            "Intensity",
                            0.0, 2.0,
                            region.get('intensity', 1.0),
                            key=f"region_intensity_{widget_id}"
                        )
                    if st.button(f"Remove Region {i+1}", key=f"remove_region_{widget_id}"):
                        # Remove the region
                        st.session_state.regions.pop(i)
                        # Re-index saved canvases: drop slot i, shift later slots down by one
                        st.session_state.canvas_results = {
                            (old_idx if old_idx < i else old_idx - 1): stored
                            for old_idx, stored in st.session_state.canvas_results.items()
                            if old_idx != i
                        }
                        st.session_state.canvas_ready = True
                        st.session_state.canvas_key_base += 1
                        st.rerun()
            # Blend mode
            regional_blend_mode = st.selectbox("Blend Mode",
                                               ["additive", "average", "maximum", "overlay", "screen"],
                                               index=0, key="regional_blend")
    with col2:
        if st.session_state.regions and st.session_state.regional_image_original:
            st.subheader("Paint Regions")
            has_result = st.session_state.get('regional_result') is not None
            # Show workflow status
            if has_result:
                if st.session_state.canvas_ready:
                    st.success("Edit Mode - Paint your regions and click 'Apply Regional Styles' when ready")
                else:
                    st.info("Preview Mode - Click 'Continue Editing' to modify regions")
            else:
                st.info("Paint on the canvas below to define regions for each style")
            # Check if we're in edit mode
            if not st.session_state.canvas_ready and has_result:
                # Show a preview of the painted regions
                st.subheader("Current Result")
                result_display = resize_image_for_display(st.session_state['regional_result'], max_width=600, max_height=400)
                st.image(result_display, caption="Applied Styles", use_column_width=True)
            # Create display image
            display_image = resize_image_for_display(st.session_state.regional_image_original, max_width=600, max_height=400)
            display_width, display_height = display_image.size
            # Info message
            st.info(f"Image resized to {display_width}x{display_height} for display. Original resolution will be used for processing.")
            # Get current region
            current_region_idx = st.selectbox(
                "Select region to paint",
                range(len(st.session_state.regions)),
                format_func=lambda x: f"Region {x+1}: {st.session_state.regions[x].get('style', 'None')}"
            )
            current_region = st.session_state.regions[current_region_idx]
            col_draw1, col_draw2, col_draw3 = st.columns(3)
            with col_draw1:
                brush_size = st.slider("Brush Size", 1, 50, 15)
            with col_draw2:
                drawing_mode = st.selectbox("Tool", ["freedraw", "line", "rect", "circle"])
            with col_draw3:
                if st.button("Clear This Region"):
                    if current_region_idx in st.session_state.canvas_results:
                        del st.session_state.canvas_results[current_region_idx]
                    st.session_state.canvas_ready = True
                    # Changing the canvas key makes the component start from a
                    # blank drawing; with an unchanged key it would report the
                    # old strokes again and they would be saved right back
                    st.session_state.canvas_key_base += 1
                    st.rerun()
            # Create combined background with all previous regions
            background_with_regions = display_image.copy()
            for i, region in enumerate(st.session_state.regions):
                canvas_data = st.session_state.canvas_results.get(i)
                image_data = getattr(canvas_data, 'image_data', None)
                if image_data is None:
                    continue
                # Painted pixels are those with a non-zero alpha channel
                mask = image_data[:, :, 3] > 0
                # Only the hue is parsed from the stored "hsla(h, 70%, 50%, a)"
                # string; saturation 0.7 and lightness 0.5 are assumed
                hue = int(region['color'].replace('hsla(', '').replace(')', '').split(',')[0])
                r, g, b = colorsys.hls_to_rgb(hue / 360, 0.5, 0.7)
                color = (int(r * 255), int(g * 255), int(b * 255))
                # The region being edited is drawn more opaque than the others
                opacity = 200 if i == current_region_idx else 128
                # Composite the whole overlay in one paste instead of a
                # per-pixel Python loop
                alpha_mask = Image.fromarray(np.where(mask, opacity, 0).astype(np.uint8))
                if alpha_mask.size != background_with_regions.size:
                    alpha_mask = alpha_mask.resize(background_with_regions.size, Image.NEAREST)
                background_with_regions.paste(color, (0, 0), alpha_mask)
            # Canvas for current region
            stroke_color = current_region['color'].replace('0.5)', '0.8)')
            # Get initial drawing for current region
            initial_drawing = None
            stored_canvas = st.session_state.canvas_results.get(current_region_idx)
            if stored_canvas is not None and hasattr(stored_canvas, 'json_data'):
                initial_drawing = stored_canvas.json_data
            canvas_result = st_canvas(
                fill_color=stroke_color,
                stroke_width=brush_size,
                stroke_color=stroke_color,
                background_image=background_with_regions,
                update_streamlit=True,
                height=display_height,
                width=display_width,
                drawing_mode=drawing_mode,
                display_toolbar=True,
                initial_drawing=initial_drawing,
                # canvas_key_base is bumped whenever the canvas must be reset
                key=f"regional_canvas_{st.session_state.canvas_key_base}_{current_region_idx}_{brush_size}_{drawing_mode}"
            )
            # Save canvas result, but never let a result without image data
            # replace a previously saved drawing
            if canvas_result is not None and canvas_result.image_data is not None:
                st.session_state.canvas_results[current_region_idx] = canvas_result
            # Apply button
            if st.button("Apply Regional Styles", type="primary", use_container_width=True):
                with st.spinner("Applying regional styles..."):
                    # Create combined mask from all canvas results
                    combined_mask = combine_region_masks(
                        [st.session_state.canvas_results.get(i) for i in range(len(st.session_state.regions))],
                        (display_height, display_width)
                    )
                    # Prepare base style configs if enabled
                    base_configs = None
                    if use_base and base_style:
                        base_key = next(
                            (key for key, info in system.cyclegan_models.items() if info['name'] == base_style),
                            None,
                        )
                        if base_key:
                            base_configs = [('cyclegan', base_key, base_intensity)]
                    # Apply regional styles using original image
                    result = system.apply_regional_styles(
                        st.session_state.regional_image_original,  # Use original resolution
                        combined_mask,
                        st.session_state.regions,
                        base_configs,
                        regional_blend_mode
                    )
                if result is None:
                    # Keep any earlier result rather than storing a missing one
                    st.error("Regional style transfer failed. Please try again.")
                else:
                    st.session_state['regional_result'] = result
        # Show result with fixed size
        regional_result = st.session_state.get('regional_result')
        if regional_result is not None:
            st.subheader("Result")
            # Add display size control
            display_size = st.slider("Display Size", 300, 800, 600, 50, key="regional_display_size")
            # Fixed size display
            result_display = resize_image_for_display(
                regional_result,
                max_width=display_size,
                max_height=display_size
            )
            st.image(result_display, caption="Regional Styled Image")
            # Show actual dimensions
            st.caption(f"Original size: {regional_result.size[0]}x{regional_result.size[1]} pixels")
            # Download button: the image is encoded to PNG in memory
            buf = io.BytesIO()
            regional_result.save(buf, format='PNG')
            st.download_button(
                label="Download Result",
                data=buf.getvalue(),
                file_name=f"regional_styled_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.png",
                mime="image/png"
            )
# TAB 3: Video Processing
# Left column: upload a video, choose up to two styles and process it.
# Right column: playback/preview and download of the processed video, which is
# kept as raw bytes in session state ('video_result_bytes', 'video_result_ext',
# 'video_is_mp4', 'video_result_available').
with tab3:
    st.header("Video Processing")
    if not VIDEO_PROCESSING_AVAILABLE:
        st.warning("""
Video processing requires OpenCV to be installed.
To enable video processing, add `opencv-python` to your requirements.txt
""")
    else:
        col1, col2 = st.columns(2)
        with col1:
            video_file = st.file_uploader("Upload Video", type=['mp4', 'avi', 'mov'])
            if video_file:
                st.video(video_file)
            st.subheader("Style Configuration")
            # Style selection (up to 2 for videos)
            video_styles = []
            for i in range(2):
                with st.expander(f"Style {i+1}", expanded=(i == 0)):
                    style = st.selectbox("Select style", style_choices, key=f"video_style_{i}")
                    intensity = st.slider("Intensity", 0.0, 2.0, 1.0, 0.1, key=f"video_intensity_{i}")
                    if style and intensity > 0:
                        # First registered model whose display name matches the selection
                        model_key = next(
                            (key for key, info in system.cyclegan_models.items() if info['name'] == style),
                            None,
                        )
                        if model_key:
                            video_styles.append(('cyclegan', model_key, intensity))
            video_blend_mode = st.selectbox("Blend Mode",
                                            ["additive", "average", "maximum", "overlay", "screen"],
                                            index=0, key="video_blend")
            if st.button("Process Video", type="primary", use_container_width=True):
                if video_file is None:
                    # Guard: the code below reads video_file.name and its bytes
                    st.warning("Please upload a video first")
                elif not video_styles:
                    st.warning("Please select at least one style")
                else:
                    with st.spinner("Processing video..."):
                        progress_bar = st.progress(0)
                        status_text = st.empty()

                        def progress_callback(p, msg):
                            """Forward processing progress to the bar and status line."""
                            progress_bar.progress(p)
                            status_text.text(msg)

                        # Save uploaded file temporarily; getvalue() returns the
                        # full upload regardless of the current read position
                        input_suffix = os.path.splitext(video_file.name)[1]
                        with tempfile.NamedTemporaryFile(delete=False, suffix=input_suffix) as temp_input:
                            temp_input.write(video_file.getvalue())
                        output_path = None
                        try:
                            # Process video
                            output_path = system.video_processor.process_video(
                                temp_input.name, video_styles, video_blend_mode, progress_callback
                            )
                            if output_path and os.path.exists(output_path):
                                # Read the video file immediately
                                try:
                                    with open(output_path, 'rb') as f:
                                        video_bytes = f.read()
                                    # Determine file extension
                                    file_ext = os.path.splitext(output_path)[1].lower()
                                    # Store in session state
                                    st.session_state['video_result_bytes'] = video_bytes
                                    st.session_state['video_result_ext'] = file_ext
                                    st.session_state['video_result_available'] = True
                                    st.session_state['video_is_mp4'] = (file_ext == '.mp4')
                                    st.success(f"Video processing complete! Format: {file_ext.upper()}")
                                except Exception as e:
                                    st.error(f"Failed to read processed video: {str(e)}")
                                    st.session_state['video_result_available'] = False
                            else:
                                st.error("Failed to process video. Please try a different video or reduce the resolution.")
                                st.session_state['video_result_available'] = False
                        finally:
                            # Remove the temp input and output files even when
                            # processing raised; deletion itself is best-effort
                            for leftover in (temp_input.name, output_path):
                                if leftover:
                                    try:
                                        os.unlink(leftover)
                                    except OSError:
                                        pass
                            progress_bar.empty()
                            status_text.empty()
        with col2:
            st.header("Result")
            if st.session_state.get('video_result_available', False) and 'video_result_bytes' in st.session_state:
                try:
                    file_ext = st.session_state.get('video_result_ext', '.mp4')
                    video_bytes = st.session_state['video_result_bytes']
                    # File info
                    file_size_mb = len(video_bytes) / (1024 * 1024)
                    # Try to display video
                    if st.session_state.get('video_is_mp4', False):
                        # For MP4, should work in browser
                        st.video(video_bytes)
                        st.success(f"Video ready! Size: {file_size_mb:.2f} MB")
                    else:
                        # For non-MP4, show info and download
                        st.info(f"Video format ({file_ext}) may not play in browser. Please download to view.")
                        # Show preview image if possible (best-effort)
                        preview_path = None
                        cap = None
                        try:
                            # Try to extract a frame for preview
                            with tempfile.NamedTemporaryFile(delete=False, suffix=file_ext) as temp_preview:
                                temp_preview.write(video_bytes)
                            preview_path = temp_preview.name
                            cap = cv2.VideoCapture(preview_path)
                            ret, frame = cap.read()
                            if ret:
                                # Convert frame to RGB and display
                                rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                                st.image(rgb_frame, caption="Video Preview (First Frame)", use_column_width=True)
                        except Exception:
                            # A missing preview must not block the download below
                            st.caption("Preview unavailable for this video.")
                        finally:
                            # Always release the capture and delete the temp file
                            if cap is not None:
                                cap.release()
                            if preview_path:
                                try:
                                    os.unlink(preview_path)
                                except OSError:
                                    pass
                    # Always provide download button
                    mime_types = {
                        '.mp4': 'video/mp4',
                        '.avi': 'video/x-msvideo',
                        '.mov': 'video/quicktime'
                    }
                    mime_type = mime_types.get(file_ext, 'application/octet-stream')
                    col_dl1, col_dl2 = st.columns(2)
                    with col_dl1:
                        st.download_button(
                            label=f"Download Video ({file_ext.upper()})",
                            data=video_bytes,
                            file_name=f"styled_video_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}{file_ext}",
                            mime=mime_type,
                            use_container_width=True
                        )
                    with col_dl2:
                        if st.button("Clear Result", use_container_width=True):
                            del st.session_state['video_result_bytes']
                            st.session_state['video_result_available'] = False
                            if 'video_result_ext' in st.session_state:
                                del st.session_state['video_result_ext']
                            if 'video_is_mp4' in st.session_state:
                                del st.session_state['video_is_mp4']
                            st.rerun()
                    # Info about playback
                    if not st.session_state.get('video_is_mp4', False):
                        st.caption("For best compatibility, download and use VLC or another video player.")
                except Exception as e:
                    st.error(f"Error displaying video: {str(e)}")
                    # Emergency download button
                    if 'video_result_bytes' in st.session_state:
                        # Use the stored extension so non-MP4 output is not mislabelled
                        fallback_ext = st.session_state.get('video_result_ext', '.mp4')
                        st.download_button(
                            label="📥 Download Video (Error occurred)",
                            data=st.session_state['video_result_bytes'],
                            file_name=f"styled_video_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}{fallback_ext}",
                            mime="application/octet-stream"
                        )
            elif st.session_state.get('video_result_available', False):
                st.warning("Video data not found. Please process the video again.")
                if st.button("Clear State"):
                    st.session_state['video_result_available'] = False
                    st.rerun()
# TAB 4: Training with AdaIN and Regional Application
with tab4:
st.header("Train Custom Style with AdaIN")
st.markdown("Train your own style transfer model using Adaptive Instance Normalization")
# Initialize session state for content images
if 'content_images_list' not in st.session_state:
st.session_state.content_images_list = []
if 'adain_canvas_result' not in st.session_state:
st.session_state.adain_canvas_result = None
if 'adain_test_image' not in st.session_state:
st.session_state.adain_test_image = None
col1, col2, col3 = st.columns([1, 1, 1])
with col1:
st.subheader("Style Images")
style_imgs = st.file_uploader("Upload 1-5 style images", type=['png', 'jpg', 'jpeg'],
accept_multiple_files=True, key="train_style_adain")
if style_imgs:
st.markdown(f"**{len(style_imgs)} style image(s) uploaded**")
# Display style images in a grid
style_cols = st.columns(min(len(style_imgs), 3))
for idx, style_img in enumerate(style_imgs[:3]):
with style_cols[idx % 3]:
img = Image.open(style_img).convert('RGB')
st.image(img, caption=f"Style {idx+1}", use_column_width=True)
if len(style_imgs) > 3:
st.caption(f"... and {len(style_imgs) - 3} more")
with col2:
st.subheader("Content Images")
content_imgs = st.file_uploader("Upload content images (10-50 recommended)",
type=['png', 'jpg', 'jpeg'],
accept_multiple_files=True,
key="train_content_adain")
if content_imgs:
st.markdown(f"**{len(content_imgs)} content image(s) uploaded**")
# Store content images in session state for later use
st.session_state.content_images_list = content_imgs
# Display content images in a grid
content_cols = st.columns(min(len(content_imgs), 3))
for idx, content_img in enumerate(content_imgs[:3]):
with content_cols[idx % 3]:
img = Image.open(content_img).convert('RGB')
st.image(img, caption=f"Content {idx+1}", use_column_width=True)
if len(content_imgs) > 3:
st.caption(f"... and {len(content_imgs) - 3} more")
with col3:
st.subheader("Training Settings")
model_name = st.text_input("Model Name",
value=f"adain_style_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}")
# IMPROVED DEFAULT VALUES
epochs = st.slider("Training Epochs", 10, 100, 50, 5) # Increased default
batch_size = st.slider("Batch Size", 1, 8, 4)
learning_rate = st.number_input("Learning Rate", 0.00001, 0.001, 0.0001, format="%.5f")
with st.expander("Advanced Settings"):
# MUCH HIGHER STYLE WEIGHT BY DEFAULT
style_weight = st.number_input("Style Weight", 1.0, 1000.0, 100.0, 10.0)
content_weight = st.number_input("Content Weight", 0.1, 10.0, 1.0, 0.1)
save_interval = st.slider("Save Checkpoint Every N Epochs", 5, 20, 10, 5)
st.info("💡 **Pro tip**: For better quality, use Style Weight 100-500x higher than Content Weight")
st.markdown("---")
# Training button
if st.button("Start AdaIN Training", type="primary", use_container_width=True):
if style_imgs and content_imgs:
if len(content_imgs) < 10:
st.warning("For best results, use at least 10 content images")
with st.spinner("Training AdaIN model..."):
progress_bar = st.progress(0)
status_text = st.empty()
def progress_callback(p, msg):
progress_bar.progress(p)
status_text.text(msg)
# Create temp directory for content images
temp_content_dir = f'/tmp/content_images_{uuid.uuid4().hex}'
os.makedirs(temp_content_dir, exist_ok=True)
# Save content images
for idx, img_file in enumerate(content_imgs):
img = Image.open(img_file).convert('RGB')
img.save(os.path.join(temp_content_dir, f'content_{idx}.jpg'))
# Load style images
style_images = []
for style_file in style_imgs:
style_img = Image.open(style_file).convert('RGB')
style_images.append(style_img)
# IMPROVED TRAINING FUNCTION
# Multi-layer VGG loss for better quality
class MultiLayerVGG(nn.Module):
def __init__(self):
super().__init__()
vgg = models.vgg19(weights=models.VGG19_Weights.DEFAULT).features
self.slice1 = nn.Sequential(*list(vgg.children())[:2]) # relu1_1
self.slice2 = nn.Sequential(*list(vgg.children())[2:7]) # relu2_1
self.slice3 = nn.Sequential(*list(vgg.children())[7:12]) # relu3_1
self.slice4 = nn.Sequential(*list(vgg.children())[12:21]) # relu4_1
for param in self.parameters():
param.requires_grad = False
def forward(self, x):
h1 = self.slice1(x)
h2 = self.slice2(h1)
h3 = self.slice3(h2)
h4 = self.slice4(h3)
return [h1, h2, h3, h4]
# Create model
model = AdaINStyleTransfer().to(system.device)
optimizer = torch.optim.Adam(model.decoder.parameters(), lr=learning_rate)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.8)
print(f"Training AdaIN model at 512x512 resolution")
print(f"Training device: {system.device}")
# Prepare style images - LARGER SIZE
style_transform = transforms.Compose([
transforms.Resize(600), # Increased size
transforms.RandomCrop(512), # Larger crops
transforms.RandomHorizontalFlip(p=0.5),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
style_tensors = []
# Create multiple augmented versions
for style_img in style_images:
for _ in range(5): # 5 augmented versions per style
style_tensor = style_transform(style_img).unsqueeze(0).to(system.device)
style_tensors.append(style_tensor)
# Prepare content dataset - LARGER SIZE
content_transform = transforms.Compose([
transforms.Resize(600),
transforms.RandomCrop(512),
transforms.RandomHorizontalFlip(),
transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.05),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
dataset = StyleTransferDataset(temp_content_dir, transform=content_transform)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0)
# Multi-layer loss network
loss_network = MultiLayerVGG().to(system.device).eval()
mse_loss = nn.MSELoss()
# Training loop
model.train()
model.encoder.eval()
total_steps = 0
# Multiply style weight for better results
actual_style_weight = style_weight * 10
for epoch in range(epochs):
epoch_loss = 0
epoch_content_loss = 0
epoch_style_loss = 0
for batch_idx, content_batch in enumerate(dataloader):
content_batch = content_batch.to(system.device)
# Randomly select style images
batch_style = []
for _ in range(content_batch.size(0)):
style_idx = np.random.randint(0, len(style_tensors))
batch_style.append(style_tensors[style_idx])
batch_style = torch.cat(batch_style, dim=0)
# Forward pass
output = model(content_batch, batch_style)
# Multi-layer loss
with torch.no_grad():
content_feats = loss_network(content_batch)
style_feats = loss_network(batch_style)
output_feats = loss_network(output)
# Content loss from relu4_1
content_loss = mse_loss(output_feats[-1], content_feats[-1])
# Style loss from multiple layers
style_loss = 0
style_weights = [0.2, 0.3, 0.5, 1.0]
def gram_matrix(feat):
b, c, h, w = feat.size()
feat = feat.view(b, c, h * w)
gram = torch.bmm(feat, feat.transpose(1, 2))
return gram / (c * h * w)
for i, (output_feat, style_feat, weight) in enumerate(zip(output_feats, style_feats, style_weights)):
output_gram = gram_matrix(output_feat)
style_gram = gram_matrix(style_feat)
style_loss += weight * mse_loss(output_gram, style_gram)
style_loss /= len(style_weights)
# Total loss
loss = content_weight * content_loss + actual_style_weight * style_loss
# Backward pass
optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(model.decoder.parameters(), max_norm=5.0)
optimizer.step()
epoch_loss += loss.item()
epoch_content_loss += content_loss.item()
epoch_style_loss += style_loss.item()
total_steps += 1
# Progress callback
if progress_callback and total_steps % 10 == 0:
progress = (epoch + (batch_idx + 1) / len(dataloader)) / epochs
progress_callback(progress,
f"Epoch {epoch+1}/{epochs}, Loss: {loss.item():.4f} "
f"(C: {content_loss.item():.4f}, S: {style_loss.item():.4f})")
# Step scheduler
scheduler.step()
# Print epoch stats
avg_loss = epoch_loss / len(dataloader)
print(f"Epoch {epoch+1}: Loss={avg_loss:.4f}, "
f"Content={epoch_content_loss/len(dataloader):.4f}, "
f"Style={epoch_style_loss/len(dataloader):.4f}")
# Save checkpoint
if (epoch + 1) % save_interval == 0:
checkpoint_path = f'{system.models_dir}/{model_name}_epoch_{epoch+1}.pth'
torch.save({
'epoch': epoch + 1,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'scheduler_state_dict': scheduler.state_dict(),
'loss': avg_loss,
'model_type': 'adain'
}, checkpoint_path)
print(f"Saved checkpoint: {checkpoint_path}")
# Save final model
final_path = f'{system.models_dir}/{model_name}_final.pth'
torch.save({
'model_state_dict': model.state_dict(),
'model_type': 'adain'
}, final_path)
# Cleanup
shutil.rmtree(temp_content_dir)
if model:
st.session_state['trained_adain_model'] = model
st.session_state['trained_style_images'] = style_images
st.session_state['model_path'] = final_path
st.success("AdaIN training complete! 🎉")
# Add to system's models
system.lightweight_models[model_name] = model
progress_bar.empty()
status_text.empty()
else:
st.error("Please upload both style and content images")
# Testing section with regional application
# Lets the user try the AdaIN model stored in st.session_state on a test image,
# either on the whole image or only on a region painted on a canvas.
# NOTE(review): the nesting below was reconstructed from a flattened source;
# the placement of the Clear/Reset buttons (inside the paint options) and of the
# model download section (inside this `if`) should be confirmed.
# NOTE(review): this section uses `system`, `model_name`,
# `resize_image_for_display` and `apply_adain_regional` from the surrounding
# script. `apply_adain_regional` is defined further down in this file - confirm
# it is also defined before this code runs, otherwise "Paint Region" will raise
# NameError when "Apply Style" is pressed.
if 'trained_adain_model' in st.session_state:
    st.markdown("---")
    st.header("Test Your AdaIN Model")

    # Counter that is part of the canvas widget key. Bumping it makes Streamlit
    # mount a fresh, empty canvas; resetting session state alone leaves the
    # strokes in the widget, which would be saved again on the next run.
    if 'adain_canvas_version' not in st.session_state:
        st.session_state['adain_canvas_version'] = 0

    # Application mode selection
    application_mode = st.radio(
        "Application Mode",
        ["Whole Image", "Paint Region"],
        horizontal=True,
        help="Choose whether to apply style to entire image or paint specific regions",
    )

    test_col1, test_col2, test_col3 = st.columns([1, 1, 1])

    with test_col1:
        st.subheader("Test Options")

        # Test image selection
        test_source = st.radio(
            "Test Image Source",
            ["Use Content Image", "Upload New", "Use Unsplash Image"],
            horizontal=True,
        )
        test_image = None
        if test_source == "Use Content Image" and st.session_state.content_images_list:
            # Select from uploaded content images
            content_idx = st.selectbox(
                "Select content image",
                range(len(st.session_state.content_images_list)),
                format_func=lambda x: f"Content Image {x+1}",
            )
            test_image = Image.open(st.session_state.content_images_list[content_idx]).convert('RGB')
        elif test_source == "Use Unsplash Image":
            # Use current Unsplash image if available; tolerate a missing or
            # None 'image_source' entry instead of failing on .startswith.
            image_source = st.session_state.get('image_source') or ''
            if 'current_image' in st.session_state and image_source.startswith('Unsplash'):
                test_image = st.session_state['current_image']
                st.success("Using Unsplash image")
            else:
                st.info("Please search and select an image from the Style Transfer tab first")
        else:
            # Upload new image (also the fallback when no content images exist)
            test_upload = st.file_uploader(
                "Upload test image",
                type=['png', 'jpg', 'jpeg'],
                key="test_adain",
            )
            if test_upload is not None:
                test_image = Image.open(test_upload).convert('RGB')

        # Store test image in session state
        if test_image is not None:
            st.session_state['adain_test_image'] = test_image

        # Style selection for testing; an empty or missing list yields no style
        trained_styles = st.session_state.get('trained_style_images')
        style_count = len(trained_styles) if trained_styles is not None else 0
        if style_count > 1:
            style_idx = st.selectbox(
                "Select style",
                range(style_count),
                format_func=lambda x: f"Style {x+1}",
            )
            test_style = trained_styles[style_idx]
        elif style_count == 1:
            test_style = trained_styles[0]
            st.info("Using the single trained style")
        else:
            test_style = None

        # IMPROVED DEFAULTS
        # Alpha blending control
        alpha = st.slider(
            "Style Strength (Alpha)", 0.0, 2.0, 1.2, 0.1,
            help="0 = original content, 1 = full style transfer, >1 = stronger style",
        )
        # Tiling option - defaults to True
        use_tiling = st.checkbox(
            "Use Tiled Processing",
            value=True,
            help="Process images in tiles for better quality. Recommended for ALL images.",
        )

        # Defaults so the names exist in "Whole Image" mode as well
        brush_size = 30
        drawing_mode = "freedraw"
        feather_radius = 10

        # Regional painting options (only shown in paint mode)
        if application_mode == "Paint Region":
            st.markdown("---")
            st.subheader("Painting Options")
            brush_size = st.slider("Brush Size", 5, 100, 30)
            drawing_mode = st.selectbox(
                "Drawing Tool",
                ["freedraw", "line", "rect", "circle", "polygon"],
                index=0,
            )
            # Feather/blur the mask edges
            feather_radius = st.slider(
                "Edge Softness", 0, 50, 10,
                help="Blur mask edges for smoother transitions",
            )
            col_btn1, col_btn2 = st.columns(2)
            with col_btn1:
                if st.button("Clear Canvas", use_container_width=True):
                    st.session_state['adain_canvas_result'] = None
                    # New key => new, empty canvas widget on the rerun
                    st.session_state['adain_canvas_version'] += 1
                    st.rerun()
            with col_btn2:
                if st.button("Reset Result", use_container_width=True):
                    if 'adain_styled_result' in st.session_state:
                        del st.session_state['adain_styled_result']
                    st.rerun()

    with test_col2:
        st.subheader("Canvas / Original")
        if application_mode == "Paint Region" and test_image is not None:
            # Show canvas for painting on a downscaled copy of the test image
            display_img = resize_image_for_display(test_image, max_width=400, max_height=400)
            canvas_width, canvas_height = display_img.size
            st.info("Paint the areas where you want to apply the style")
            canvas_key = (
                f"adain_canvas_{brush_size}_{drawing_mode}"
                f"_{st.session_state['adain_canvas_version']}"
            )
            # Canvas for painting mask
            canvas_result = st_canvas(
                fill_color="rgba(255, 0, 0, 0.3)",  # Red with transparency
                stroke_width=brush_size,
                stroke_color="rgba(255, 0, 0, 0.5)",
                background_image=display_img,
                update_streamlit=True,
                height=canvas_height,
                width=canvas_width,
                drawing_mode=drawing_mode,
                display_toolbar=True,
                key=canvas_key,
            )
            # Save canvas result so the Apply step can read the painted mask
            if canvas_result is not None:
                st.session_state['adain_canvas_result'] = canvas_result
            # Show style image below canvas
            if test_style is not None:
                st.markdown("---")
                st.image(test_style, caption="Style Image", use_column_width=True)
        else:
            # Show original images
            if test_image is not None:
                st.image(test_image, caption="Content Image", use_column_width=True)
            if test_style is not None:
                st.image(test_style, caption="Style Image", use_column_width=True)

    with test_col3:
        st.subheader("Result")
        # Apply button
        apply_button = st.button("Apply Style", type="primary", use_container_width=True)
        if apply_button:
            if test_image is None or test_style is None:
                # Previously the click was silently ignored
                st.warning("Please select a test image and a style first")
            else:
                with st.spinner("Applying style..."):
                    if application_mode == "Whole Image":
                        # Apply to whole image
                        result = system.apply_adain_style(
                            test_image,
                            test_style,
                            st.session_state['trained_adain_model'],
                            alpha=alpha,
                            use_tiling=use_tiling,
                        )
                    else:
                        # Apply to painted region
                        result = apply_adain_regional(
                            test_image,
                            test_style,
                            st.session_state['trained_adain_model'],
                            st.session_state.get('adain_canvas_result'),
                            alpha=alpha,
                            feather_radius=feather_radius,
                            use_tiling=use_tiling,
                        )
                if result is not None:
                    st.session_state['adain_styled_result'] = result
                else:
                    st.error("Style application failed")

        # Show result if available
        if 'adain_styled_result' in st.session_state:
            styled_result = st.session_state['adain_styled_result']
            st.image(styled_result, caption="Styled Result", use_column_width=True)
            # Quality tips
            with st.expander("💡 Tips for Better Quality"):
                st.markdown("""
- **Always use tiling** for best quality
- Try **alpha > 1.0** (1.2-1.5) for stronger style
- Use **multiple style images** when training
- Train for **50+ epochs** for best results
- If quality is still poor, retrain with **style weight = 200-500**
""")
            # Download button
            buf = io.BytesIO()
            styled_result.save(buf, format='PNG')
            st.download_button(
                label="Download Result",
                data=buf.getvalue(),
                file_name=f"adain_styled_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.png",
                mime="image/png",
            )

    # Model download section
    st.markdown("---")
    if 'model_path' in st.session_state and os.path.exists(st.session_state['model_path']):
        col_dl1, col_dl2 = st.columns(2)
        with col_dl1:
            # Read the checkpoint first so the file handle is closed before
            # the widget is built.
            with open(st.session_state['model_path'], 'rb') as f:
                model_bytes = f.read()
            st.download_button(
                label="Download Trained AdaIN Model",
                data=model_bytes,
                file_name=f"{model_name}_final.pth",
                mime="application/octet-stream",
                use_container_width=True,
            )
        with col_dl2:
            st.info("This model can be loaded and used for real-time style transfer")
# NOTE(review): this helper is called from the AdaIN test section, which appears
# earlier in the script than this definition. Move it above that tab (or in with
# the other helper functions) so the name exists when "Apply Style" runs.
def apply_adain_regional(content_image, style_image, model, canvas_result, alpha=1.0, feather_radius=10, use_tiling=False):
    """Apply AdaIN style transfer to the painted region of an image only.

    The whole image is styled through ``system.apply_adain_style`` and then
    blended with the original using the canvas alpha channel as a mask.

    Args:
        content_image: PIL image to stylise. Non-RGB images are converted to
            RGB before blending.
        style_image: Style reference, forwarded to ``system.apply_adain_style``.
        model: Trained model, forwarded to ``system.apply_adain_style``.
        canvas_result: Canvas result whose ``image_data`` is an (H, W, 4)
            array; pixels with alpha > 0 are treated as painted. ``None``, or
            missing ``image_data``, means "style the whole image".
        alpha: Style strength, forwarded to ``system.apply_adain_style``.
        feather_radius: Sigma of the Gaussian blur applied to the mask edges,
            measured in pixels of the content image; 0 disables feathering.
        use_tiling: Forwarded to ``system.apply_adain_style``.

    Returns:
        A PIL image, or ``None`` when an input is missing, styling fails, or
        an unexpected error occurs (the error is printed with its traceback).
    """
    if content_image is None or style_image is None or model is None:
        return None
    try:
        # Get the mask source from the canvas
        image_data = getattr(canvas_result, 'image_data', None)
        if image_data is None:
            # No mask painted, apply to whole image
            return system.apply_adain_style(content_image, style_image, model, alpha, use_tiling=use_tiling)

        # The blend below works on 3 channels; RGBA/L/P inputs would otherwise
        # fail with a shape mismatch and silently yield None.
        content_rgb = content_image if content_image.mode == 'RGB' else content_image.convert('RGB')

        # Extract mask from the canvas alpha channel
        mask = np.asarray(image_data)[:, :, 3] > 0
        if not mask.any():
            # Nothing painted: the blend would reproduce the content image,
            # so skip the expensive model call.
            return content_rgb.copy()

        # Resize mask from canvas (display) size to the content image size
        mask_height, mask_width = mask.shape
        if content_rgb.size != (mask_width, mask_height):
            mask_pil = Image.fromarray(mask.astype(np.uint8) * 255)
            mask_pil = mask_pil.resize(content_rgb.size, Image.NEAREST)
            mask = np.asarray(mask_pil) > 128

        # Apply feathering to mask edges if requested
        mask_float = mask.astype(np.float32)
        if feather_radius > 0:
            from scipy.ndimage import gaussian_filter
            mask_float = np.clip(gaussian_filter(mask_float, sigma=feather_radius), 0, 1)

        # Apply style to entire image with tiling option
        styled_full = system.apply_adain_style(content_rgb, style_image, model, alpha, use_tiling=use_tiling)
        if styled_full is None:
            return None
        # Guard the blend against a styled output that differs in mode or size
        if styled_full.mode != 'RGB':
            styled_full = styled_full.convert('RGB')
        if styled_full.size != content_rgb.size:
            styled_full = styled_full.resize(content_rgb.size, Image.BICUBIC)

        # Blend original and styled based on mask (mask broadcast over channels)
        original_array = np.asarray(content_rgb, dtype=np.float32)
        styled_array = np.asarray(styled_full, dtype=np.float32)
        weights = mask_float[:, :, np.newaxis]
        result_array = original_array * (1 - weights) + styled_array * weights
        result_array = np.clip(result_array, 0, 255).astype(np.uint8)
        return Image.fromarray(result_array)
    except Exception as e:
        # Top-level UI boundary: report and return None rather than crash the app
        print(f"Error applying regional AdaIN style: {e}")
        traceback.print_exc()
        return None
# TAB 5: Batch Processing
def _open_batch_image(item):
    """Return an RGB PIL image for a batch entry (uploaded file or PIL image)."""
    if isinstance(item, Image.Image):
        return item if item.mode == 'RGB' else item.convert('RGB')
    return Image.open(item).convert('RGB')


def _style_batch(items, apply_style, progress_bar):
    """Run `apply_style` on every batch entry and return the non-None results."""
    styled = []
    total = len(items)
    for done, item in enumerate(items, start=1):
        result = apply_style(_open_batch_image(item))
        # Skip failed images so the ZIP step never sees None
        if result is not None:
            styled.append(result)
        progress_bar.progress(done / total)
    return styled


def _zip_png_images(images):
    """Pack PIL images into an in-memory ZIP of PNG files and return its bytes."""
    archive = io.BytesIO()
    with zipfile.ZipFile(archive, 'w') as zf:
        for number, img in enumerate(images, start=1):
            png_buffer = io.BytesIO()
            img.save(png_buffer, format='PNG')
            zf.writestr(f"styled_{number:03d}.png", png_buffer.getvalue())
    return archive.getvalue()


with tab5:
    st.header("Batch Processing")
    col1, col2 = st.columns(2)

    with col1:
        # Image source selection for batch
        batch_source = st.radio(
            "Image Source",
            ["Upload Multiple", "Use Current Unsplash Image"],
            horizontal=True,
            key="batch_source",
        )
        batch_files = []
        if batch_source == "Upload Multiple":
            batch_files = st.file_uploader(
                "Upload Images",
                type=['png', 'jpg', 'jpeg'],
                accept_multiple_files=True,
                key="batch_upload",
            )
        else:
            # Use current Unsplash image if available; tolerate a missing or
            # None 'image_source' entry instead of failing on .startswith.
            image_source = st.session_state.get('image_source') or ''
            if 'current_image' in st.session_state and image_source.startswith('Unsplash'):
                batch_files = [st.session_state['current_image']]
                st.success("Using current Unsplash image for batch processing")
            else:
                st.info("Please search and select an image from the Style Transfer tab first")

        processing_type = st.radio("Processing Type", ["CycleGAN", "Custom Trained Model"])

        if processing_type == "CycleGAN":
            # Style configuration: up to three (style, intensity) slots
            batch_styles = []
            for i in range(3):
                with st.expander(f"Style {i+1}", expanded=(i == 0)):
                    style = st.selectbox("Select style", style_choices, key=f"batch_style_{i}")
                    intensity = st.slider("Intensity", 0.0, 2.0, 1.0, 0.1, key=f"batch_intensity_{i}")
                if style and intensity > 0:
                    # Map the displayed style name back to its model key
                    model_key = next(
                        (key for key, info in system.cyclegan_models.items() if info['name'] == style),
                        None,
                    )
                    if model_key:
                        batch_styles.append(('cyclegan', model_key, intensity))
            batch_blend_mode = st.selectbox(
                "Blend Mode",
                ["additive", "average", "maximum", "overlay", "screen"],
                index=0,
                key="batch_blend",
            )
        else:
            # Custom model upload
            custom_model_file = st.file_uploader("Upload Trained Model (.pth)", type=['pth'])

        if st.button("Process Batch", type="primary", use_container_width=True):
            if not batch_files:
                st.warning("Please provide at least one image to process")
            else:
                with st.spinner("Processing batch..."):
                    progress_bar = st.progress(0)
                    processed_images = []
                    try:
                        if processing_type == "CycleGAN" and batch_styles:
                            processed_images = _style_batch(
                                batch_files,
                                lambda image: system.blend_styles(image, batch_styles, batch_blend_mode),
                                progress_bar,
                            )
                        elif processing_type == "Custom Trained Model" and custom_model_file:
                            # Write the upload to disk for the loader. getvalue()
                            # returns the full content regardless of the current
                            # read position of the uploaded file.
                            with tempfile.NamedTemporaryFile(delete=False, suffix='.pth') as temp_model:
                                temp_model.write(custom_model_file.getvalue())
                            try:
                                # NOTE(review): a user-supplied checkpoint is presumably
                                # unpickled by load_lightweight_model; confirm the loader
                                # is safe for untrusted files.
                                model = system.load_lightweight_model(temp_model.name)
                                if model:
                                    processed_images = _style_batch(
                                        batch_files,
                                        lambda image: system.apply_lightweight_style(image, model),
                                        progress_bar,
                                    )
                                else:
                                    st.error("Could not load the uploaded model")
                            finally:
                                # Always remove the temp file, even when loading or styling fails
                                os.unlink(temp_model.name)
                        else:
                            st.warning("Select at least one style or upload a trained model before processing")
                    finally:
                        progress_bar.empty()

                    if processed_images:
                        st.session_state['batch_results'] = processed_images
                        st.session_state['batch_zip'] = _zip_png_images(processed_images)

    with col2:
        if 'batch_results' in st.session_state:
            batch_results = st.session_state['batch_results']
            st.header("Results")
            # Show gallery (first 8 images, 4 per row)
            cols = st.columns(4)
            for idx, img in enumerate(batch_results[:8]):
                cols[idx % 4].image(img, use_column_width=True)
            if len(batch_results) > 8:
                st.info(f"Showing 8 of {len(batch_results)} processed images")
            # Download zip
            st.download_button(
                label="Download All (ZIP)",
                data=st.session_state['batch_zip'],
                file_name=f"batch_styled_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.zip",
                mime="application/zip",
            )
# TAB 6: Documentation
with tab6:
    # Bullet list of the available CycleGAN model names, sorted alphabetically
    cyclegan_names = sorted(info["name"] for info in system.cyclegan_models.values())
    model_bullets = "\n".join(f"- **{name}**" for name in cyclegan_names)
    st.markdown(f"""
## Style Transfer System Documentation
### Available CycleGAN Models
This system includes pre-trained bidirectional CycleGAN models:
{model_bullets}
### Features
#### Style Transfer
- Apply multiple styles simultaneously
- Adjustable intensity for each style
- Multiple blending modes for creative effects
- **NEW**: Search and use images from Unsplash
#### Regional Transform
- Paint specific regions to apply different styles
- Support for multiple regions with different styles
- Adjustable brush size and drawing tools
- Base style + regional overlays
- Persistent brush strokes across regions
- Optimized display for large images
#### Video Processing
- Frame-by-frame style transfer
- Maintains temporal consistency
- Supports all style combinations and blend modes
- Enhanced codec compatibility
#### Custom Training
- Train on any artistic style with minimal data (1-50 images)
- Automatic data augmentation for small datasets
- Adjustable model complexity (3-12 residual blocks)
### Model Architecture
- **CycleGAN models**: 9-12 residual blocks for high-quality transformations
- **Lightweight models**: 3-12 residual blocks (customizable during training)
- **Training approach**: Unpaired image-to-image translation
### Technical Details
- **Framework**: PyTorch
- **GPU Support**: CUDA acceleration when available
- **Image Formats**: JPG, PNG, BMP
- **Video Formats**: MP4, AVI, MOV
- **Model Size**: ~45MB (CycleGAN), 5-15MB (Lightweight)
### Unsplash Integration
To use Unsplash image search:
1. Get a free API key from [Unsplash Developers](https://unsplash.com/developers)
2. Add it to your HuggingFace Space secrets as `UNSPLASH_ACCESS_KEY`
3. Search for images directly in the app
4. Automatic attribution for photographers
### Usage Tips
1. **For best results**: Use high-quality input images
2. **Style intensity**: Start with 1.0, adjust to taste
3. **Blending modes**:
- 'Additive' for bold effects
- 'Average' for subtle blends
- 'Overlay' for dramatic contrasts
4. **Regional painting**:
- Use larger brush for smooth transitions
- Multiple thin layers work better than one thick layer
- Previous regions remain visible as you paint new ones
5. **Custom training**: More diverse content images = better generalization
6. **Video processing**: Keep videos under 30 seconds for faster processing
### Regional Transform Guide
The regional transform feature allows you to:
1. Define multiple regions by painting on the canvas
2. Assign different styles to each region
3. Control intensity per region
4. Apply an optional base style to the entire image
5. Blend regions using various modes
**Tips for Regional Transform:**
- Start with a base style for overall coherence
- Use semi-transparent brushes for smoother transitions
- Overlap regions for interesting blend effects
- Experiment with different blend modes per region
- All regions are visible while painting for better control
""")

# Footer
st.markdown("---")
st.markdown("Style transfer system with CycleGAN models and regional painting capabilities.")