# StyleForge / app.py
# Deployed by github-actions[bot] from GitHub - 2026-01-22 04:38:22
# Commit: 13b4a59
"""
StyleForge - Hugging Face Spaces Deployment
Real-time neural style transfer with custom CUDA kernels
Features:
- Pre-trained styles (Candy, Mosaic, Rain Princess, Udnie)
- Custom style training from uploaded images
- Region-based style application
- Real-time benchmark charts
- Style blending interpolation
- CUDA kernel acceleration
Based on Johnson et al. "Perceptual Losses for Real-Time Style Transfer"
https://arxiv.org/abs/1603.08155
"""
# ============================================================================
# PATCH gradio_client to fix bool schema bug
# ============================================================================
import sys
# First import the real module to get all its contents
import gradio_client.utils as _real_client_utils
# Save the original get_type function
_original_get_type = _real_client_utils.get_type
_original_json_schema_to_python_type = _real_client_utils.json_schema_to_python_type
def _patched_get_type(schema):
    """Patched ``gradio_client.utils.get_type`` that tolerates bool schemas.

    In JSON Schema, a bare ``True`` means "accept any instance" and ``False``
    means "accept nothing"; neither is a dict, so the upstream function
    crashes when it applies the ``in`` operator to it.

    Bug fix vs. the previous patch: a literal ``True`` schema denotes *any*
    type, not the ``bool`` type, so both booleans map to "Any" here.
    """
    if isinstance(schema, bool):
        return "Any"
    # Dict (and other) schemas are handled by the original implementation.
    return _original_get_type(schema)
def _patched_json_schema_to_python_type(schema, defs=None):
    """Patched version that handles bool schemas at the top level"""
    # Boolean schemas (True = "match anything", False = "match nothing") and
    # a missing schema all map to the printable placeholder "Any".
    if isinstance(schema, bool) or schema is None:
        return "Any"
    try:
        # Everything else is delegated to the original implementation.
        return _original_json_schema_to_python_type(schema, defs)
    except Exception:
        # The upstream function still has rough edges; degrade gracefully.
        return "Any"
# Replace the functions
_real_client_utils.get_type = _patched_get_type
_real_client_utils.json_schema_to_python_type = _patched_json_schema_to_python_type
# Now safe to import gradio
import gradio as gr
import torch
import torch.nn as nn
from PIL import Image, ImageDraw, ImageFont
import numpy as np
import time
import os
from pathlib import Path
from typing import Optional, Tuple, Dict, List, Any
from pydantic import BaseModel
from datetime import datetime
from collections import deque
import tempfile
import json
# Try to import plotly for charts
try:
import plotly.graph_objects as go
from plotly.subplots import make_subplots
PLOTLY_AVAILABLE = True
except ImportError:
PLOTLY_AVAILABLE = False
print("Plotly not available, charts will be disabled")
# Try to import spaces for ZeroGPU support
try:
from spaces import GPU
SPACES_AVAILABLE = True
except ImportError:
SPACES_AVAILABLE = False
print("HuggingFace spaces not available (running locally)")
# Try to import rembg for AI-based background/foreground segmentation
try:
from rembg import remove, new_session
REMBG_AVAILABLE = True
print("Rembg available for AI segmentation")
except ImportError:
REMBG_AVAILABLE = False
print("Rembg not available, using geometric masks only")
# Try to import tqdm for progress bars
try:
from tqdm import tqdm
TQDM_AVAILABLE = True
except ImportError:
TQDM_AVAILABLE = False
print("Tqdm not available")
# ============================================================================
# Configuration
# ============================================================================
# For ZeroGPU: Don't initialize CUDA at module level
# Device will be determined when needed within GPU tasks
_SPACES_ZERO_GPU = SPACES_AVAILABLE # From spaces import above
# Lazy device initialization for ZeroGPU compatibility
_device_cache = None
def get_device():
    """Return the torch device to run on (lazily resolved and cached).

    On ZeroGPU the CUDA runtime must not be touched from the main process,
    so we optimistically hand back ``cuda`` and let the GPU task context
    resolve it; otherwise availability is probed once and cached.
    """
    global _device_cache
    if _device_cache is not None:
        return _device_cache
    if _SPACES_ZERO_GPU:
        # Assume CUDA; actual initialization happens inside the GPU task.
        _device_cache = torch.device('cuda')
    else:
        kind = 'cuda' if torch.cuda.is_available() else 'cpu'
        _device_cache = torch.device(kind)
    return _device_cache
# For backwards compatibility, keep DEVICE as a property
class _DeviceProperty:
    """Backwards-compatible stand-in for the old module-level ``DEVICE``.

    Older code did ``str(DEVICE)`` / ``DEVICE.type`` on a ``torch.device``
    created at import time. Eager creation breaks ZeroGPU, so this proxy
    defers every access to :func:`get_device` instead.

    Bug fix: the previous version referenced ``_device``, a name that does
    not exist anywhere in this module (the lazy cache is ``_device_cache``
    behind ``get_device()``), so any access raised ``NameError``.
    """

    def __str__(self):
        return str(get_device())

    def __repr__(self):
        return repr(get_device())

    @property
    def type(self):
        # Mirrors torch.device.type ('cuda' or 'cpu').
        return get_device().type

    def __eq__(self, other):
        # Compare by string form so DEVICE == 'cuda' keeps working.
        # NOTE(review): defining __eq__ makes instances unhashable.
        return str(get_device()) == str(other)
DEVICE = _DeviceProperty()
if _SPACES_ZERO_GPU:
print(f"Device: Will use CUDA within GPU tasks (ZeroGPU mode)")
else:
# Only access device if not ZeroGPU to avoid CUDA init
print(f"Device: {get_device()}")
if SPACES_AVAILABLE:
print("ZeroGPU support enabled")
# Check CUDA kernels availability
try:
from kernels import check_cuda_kernels, get_fused_instance_norm, load_prebuilt_kernels
# On ZeroGPU: Uses pre-compiled kernels from prebuilt/ if available
# On local: JIT compiles kernels if prebuilt not found
CUDA_KERNELS_AVAILABLE = check_cuda_kernels()
if SPACES_AVAILABLE:
status = "Pre-compiled" if CUDA_KERNELS_AVAILABLE else "PyTorch GPU fallback (no prebuilt kernels)"
print(f"CUDA Kernels: {status}")
else:
print(f"CUDA Kernels: {'Available' if CUDA_KERNELS_AVAILABLE else 'Not Available (using PyTorch fallback)'}")
except Exception:
CUDA_KERNELS_AVAILABLE = False
print("CUDA Kernels: Not Available (using PyTorch fallback)")
# Available styles
STYLES = {
'candy': 'Candy',
'mosaic': 'Mosaic',
'rain_princess': 'Rain Princess',
'udnie': 'Udnie',
'la_muse': 'La Muse',
'starry_night': 'Starry Night',
'the_scream': 'The Scream',
'feathers': 'Feathers',
'composition_vii': 'Composition VII',
}
STYLE_DESCRIPTIONS = {
'candy': 'Bright, colorful transformation inspired by pop art',
'mosaic': 'Fragmented, tile-like artistic reconstruction',
'rain_princess': 'Moody, impressionistic with subtle textures',
'udnie': 'Bold, abstract expressionist style',
'la_muse': 'Romantic, soft colors with gentle brushstrokes',
'starry_night': 'Inspired by Van Gogh\'s iconic swirling masterpiece',
'the_scream': 'Edvard Munch\'s anxious expressionist style',
'feathers': 'Soft, delicate textures with flowing patterns',
'composition_vii': 'Kandinsky-inspired abstract geometric forms',
}
# Backend options
BACKENDS = {
'auto': 'Auto (CUDA if available)',
'cuda': 'CUDA Kernels (Fast)',
'pytorch': 'PyTorch Baseline',
}
# ============================================================================
# Performance Tracking with Live Charts
# ============================================================================
class PerformanceStats(BaseModel):
    """Pydantic model for performance stats - Gradio 5.x compatible"""
    # Aggregate latency over the rolling sample window, in milliseconds.
    avg_ms: float
    min_ms: float
    max_ms: float
    # Lifetime counter (not bounded by the rolling window).
    total_inferences: int
    uptime_hours: float
    # Per-backend rolling averages/counts; None until that backend was used.
    cuda_avg: Optional[float] = None
    cuda_count: Optional[int] = None
    pytorch_avg: Optional[float] = None
    pytorch_count: Optional[int] = None
class ChartData(BaseModel):
    """Pydantic model for chart data - Gradio 5.x compatible"""
    # Parallel lists: timestamps ('%H:%M:%S'), latency in ms, and the
    # backend name used for each recorded inference.
    timestamps: List[str]
    times: List[float]
    backends: List[str]
class PerformanceTracker:
    """Track and display Space performance metrics with backend comparison"""

    def __init__(self, max_samples=100):
        # Rolling windows: overall history plus smaller per-backend buckets.
        self.inference_times = deque(maxlen=max_samples)
        self.backend_times = {
            'cuda': deque(maxlen=50),
            'pytorch': deque(maxlen=50),
        }
        self.timestamps = deque(maxlen=max_samples)
        self.backends_used = deque(maxlen=max_samples)
        self.total_inferences = 0
        self.start_time = datetime.now()

    def record(self, elapsed_ms: float, backend: str):
        """Record one inference latency together with the backend used."""
        now = datetime.now()
        self.inference_times.append(elapsed_ms)
        self.timestamps.append(now)
        self.backends_used.append(backend)
        bucket = self.backend_times.get(backend)
        if bucket is not None:
            bucket.append(elapsed_ms)
        self.total_inferences += 1

    def get_stats(self) -> Optional[PerformanceStats]:
        """Summarize the rolling window; None when nothing was recorded."""
        if not self.inference_times:
            return None
        samples = list(self.inference_times)
        elapsed_s = (datetime.now() - self.start_time).total_seconds()

        def _avg_and_count(name):
            # Per-backend rolling average/count, or (None, None) when unused.
            vals = list(self.backend_times[name])
            if not vals:
                return None, None
            return sum(vals) / len(vals), len(vals)

        cuda_avg, cuda_count = _avg_and_count('cuda')
        pytorch_avg, pytorch_count = _avg_and_count('pytorch')
        return PerformanceStats(
            avg_ms=sum(samples) / len(samples),
            min_ms=min(samples),
            max_ms=max(samples),
            total_inferences=self.total_inferences,
            uptime_hours=elapsed_s / 3600,
            cuda_avg=cuda_avg,
            cuda_count=cuda_count,
            pytorch_avg=pytorch_avg,
            pytorch_count=pytorch_count,
        )

    def get_comparison(self) -> str:
        """Render a markdown table comparing CUDA vs. PyTorch latency."""
        cuda_times = list(self.backend_times['cuda']) if self.backend_times['cuda'] else []
        pytorch_times = list(self.backend_times['pytorch']) if self.backend_times['pytorch'] else []
        if not cuda_times or not pytorch_times:
            return "Run both backends to see comparison"
        cuda_avg = sum(cuda_times) / len(cuda_times)
        pytorch_avg = sum(pytorch_times) / len(pytorch_times)
        # Guard against a zero average to avoid division by zero.
        speedup = pytorch_avg / cuda_avg if cuda_avg > 0 else 1.0
        return f"""
| Backend | Avg Time | Samples |
|---------|----------|---------|
| **CUDA Kernels** | {cuda_avg:.1f} ms | {len(cuda_times)} |
| **PyTorch** | {pytorch_avg:.1f} ms | {len(pytorch_times)} |
### Speedup: {speedup:.2f}x faster with CUDA! 🚀
"""

    def get_chart_data(self) -> Optional[ChartData]:
        """Package the rolling window for the real-time chart, or None."""
        if not self.timestamps:
            return None
        return ChartData(
            timestamps=[stamp.strftime('%H:%M:%S') for stamp in self.timestamps],
            times=list(self.inference_times),
            backends=list(self.backends_used),
        )
# Global tracker
perf_tracker = PerformanceTracker()
# ============================================================================
# Custom Styles Storage
# ============================================================================
# Directory where user-trained styles are persisted as .pth checkpoints.
CUSTOM_STYLES_DIR = Path("custom_styles")
CUSTOM_STYLES_DIR.mkdir(exist_ok=True)

def get_custom_styles() -> List[str]:
    """Return the sorted names of all custom trained styles on disk."""
    if not CUSTOM_STYLES_DIR.exists():
        return []
    # A style's name is the checkpoint filename without its extension.
    return sorted(checkpoint.stem for checkpoint in CUSTOM_STYLES_DIR.glob("*.pth"))
# ============================================================================
# VGG Feature Extractor for Style Training
# ============================================================================
class VGGFeatureExtractor(nn.Module):
    """Frozen VGG19 trunk used to compute style/content features.

    Only the layers up to relu4_4 (index 29) are kept, and all weights are
    frozen, so the extractor behaves as a fixed perceptual-feature function
    for custom style training.
    """

    def __init__(self):
        super().__init__()
        import torchvision.models as models
        # NOTE: pretrained=True downloads ImageNet weights on first use.
        backbone = models.vgg19(pretrained=True)
        self.features = backbone.features[:29]  # Up to relu4_4
        for parameter in self.parameters():
            parameter.requires_grad = False
        # ImageNet channel statistics, shaped for NCHW broadcasting.
        self.mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
        self.std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)

    def forward(self, x):
        # Normalize with ImageNet statistics, then run the frozen trunk.
        normalized = (x - self.mean.to(x.device)) / self.std.to(x.device)
        return self.features(normalized)
# Global VGG extractor (lazy loaded)
_vgg_extractor = None
def get_vgg_extractor():
    """Lazy load VGG feature extractor (with ZeroGPU support)"""
    global _vgg_extractor
    if _vgg_extractor is None:
        # Build once, move to the active device, and freeze in eval mode.
        extractor = VGGFeatureExtractor().to(get_device())
        extractor.eval()
        _vgg_extractor = extractor
    return _vgg_extractor
def gram_matrix(features):
    """Compute the normalized Gram matrix used as a style representation.

    The feature map is flattened to (b*c, h*w) and multiplied with its own
    transpose; the result is normalized by the total element count so its
    magnitude does not depend on image size.
    """
    b, c, h, w = features.size()
    flattened = features.reshape(b * c, h * w)
    gram = flattened @ flattened.t()
    return gram.div_(b * c * h * w)
# ============================================================================
# Model Definition with CUDA Kernel Support
# ============================================================================
class ConvLayer(nn.Module):
    """ReflectionPad -> Conv -> InstanceNorm -> (optional) ReLU.

    When ``use_cuda`` is requested and custom kernels are available, the
    instance norm is swapped for the fused CUDA implementation; otherwise
    PyTorch's ``nn.InstanceNorm2d`` is used.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int,
        padding: int = 0,
        relu: bool = True,
        use_cuda: bool = False,
    ):
        super().__init__()
        self.pad = nn.ReflectionPad2d(padding)
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride)
        # Short-circuit: the global flag is consulted only when requested.
        self.use_cuda = use_cuda and CUDA_KERNELS_AVAILABLE
        self.norm, self._has_cuda = self._make_norm(out_channels)
        self.activation = nn.ReLU(inplace=True) if relu else None

    def _make_norm(self, out_channels):
        # Prefer the fused CUDA instance norm, falling back to PyTorch's.
        if self.use_cuda:
            try:
                return get_fused_instance_norm(out_channels, affine=True), True
            except Exception:
                pass
        return nn.InstanceNorm2d(out_channels, affine=True), False

    def forward(self, x):
        out = self.norm(self.conv(self.pad(x)))
        return self.activation(out) if self.activation else out
class ResidualBlock(nn.Module):
    """Two 3x3 conv layers with an identity skip connection.

    The second conv omits its ReLU so the sum with the skip path is not
    forced non-negative before the next block.
    """

    def __init__(self, channels: int, use_cuda: bool = False):
        super().__init__()
        self.conv1 = ConvLayer(channels, channels, kernel_size=3, stride=1, padding=1, use_cuda=use_cuda)
        self.conv2 = ConvLayer(channels, channels, kernel_size=3, stride=1, padding=1, relu=False, use_cuda=use_cuda)

    def forward(self, x):
        return x + self.conv2(self.conv1(x))
class UpsampleConvLayer(nn.Module):
    """Nearest-neighbor upsample -> ReflectionPad -> Conv -> InstanceNorm -> ReLU.

    Upsampling before the convolution (rather than using a transposed
    convolution) is the decoder style used throughout this network.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int,
        padding: int = 0,
        upsample: int = 2,
        use_cuda: bool = False,
    ):
        super().__init__()
        # upsample <= 1 disables the resize step entirely.
        self.upsample = nn.Upsample(scale_factor=upsample, mode='nearest') if upsample > 1 else None
        self.pad = nn.ReflectionPad2d(padding)
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride)
        # Short-circuit: the global flag is consulted only when requested.
        self.use_cuda = use_cuda and CUDA_KERNELS_AVAILABLE
        if not self.use_cuda:
            self.norm = nn.InstanceNorm2d(out_channels, affine=True)
            self._has_cuda = False
        else:
            try:
                # Fused CUDA instance norm when the kernels load cleanly.
                self.norm = get_fused_instance_norm(out_channels, affine=True)
                self._has_cuda = True
            except Exception:
                self.norm = nn.InstanceNorm2d(out_channels, affine=True)
                self._has_cuda = False
        self.activation = nn.ReLU(inplace=True)

    def forward(self, x):
        out = x if self.upsample is None else self.upsample(x)
        return self.activation(self.norm(self.conv(self.pad(out))))
class TransformerNet(nn.Module):
    """Fast Neural Style Transfer Network with backend selection.

    Johnson et al.-style feed-forward transformer: a strided encoder, a
    stack of residual blocks at the bottleneck, and an upsampling decoder.
    ``backend`` selects between custom CUDA kernels and plain PyTorch ops
    for the instance-norm layers.
    """
    def __init__(self, num_residual_blocks: int = 5, backend: str = 'auto'):
        super().__init__()
        # Determine if using CUDA kernels; 'auto' follows global availability.
        self.backend = backend
        if backend == 'auto':
            use_cuda = CUDA_KERNELS_AVAILABLE
        elif backend == 'cuda':
            use_cuda = True
        else:  # pytorch
            use_cuda = False
        # Even an explicit 'cuda' request degrades when kernels are absent.
        self.use_cuda = use_cuda and CUDA_KERNELS_AVAILABLE
        # Initial convolution layers (encoder): 3 -> 32 -> 64 -> 128 channels,
        # downsampling twice via stride-2 convolutions.
        self.conv1 = ConvLayer(3, 32, kernel_size=9, stride=1, padding=4, use_cuda=self.use_cuda)
        self.conv2 = ConvLayer(32, 64, kernel_size=3, stride=2, padding=1, use_cuda=self.use_cuda)
        self.conv3 = ConvLayer(64, 128, kernel_size=3, stride=2, padding=1, use_cuda=self.use_cuda)
        # Residual blocks operate at the bottleneck resolution.
        self.residual_blocks = nn.Sequential(
            *[ResidualBlock(128, use_cuda=self.use_cuda) for _ in range(num_residual_blocks)]
        )
        # Upsampling layers (decoder). The final conv intentionally has no
        # norm or activation, so outputs are unclamped (callers clamp).
        self.deconv1 = UpsampleConvLayer(128, 64, kernel_size=3, stride=1, padding=1, upsample=2, use_cuda=self.use_cuda)
        self.deconv2 = UpsampleConvLayer(64, 32, kernel_size=3, stride=1, padding=1, upsample=2, use_cuda=self.use_cuda)
        self.deconv3 = nn.Sequential(
            nn.ReflectionPad2d(4),
            nn.Conv2d(32, 3, kernel_size=9, stride=1)
        )
    def forward(self, x):
        """Stylize a batch.

        Args:
            x: Input image tensor (B, 3, H, W) in range [0, 1].

        Returns:
            Stylized tensor of the same shape (unclamped; callers clamp).
        """
        # Encoder
        out = self.conv1(x)
        out = self.conv2(out)
        out = self.conv3(out)
        # Residual blocks
        out = self.residual_blocks(out)
        # Decoder
        out = self.deconv1(out)
        out = self.deconv2(out)
        out = self.deconv3(out)
        return out
    def load_checkpoint(self, checkpoint_path: str) -> None:
        """Load pre-trained weights from checkpoint file.

        Handles checkpoints wrapped under a 'state_dict' or 'model' key and
        translates external layer names (e.g. "conv1.conv2d", "res1.in1")
        into this module's attribute paths. The naming scheme appears to
        match third-party fast-neural-style releases — TODO confirm against
        the actual checkpoint source.
        """
        # Load to CPU first for reliability, then move to device.
        # NOTE(review): torch.load without weights_only runs the pickle
        # machinery — only load trusted checkpoint files.
        state_dict = torch.load(checkpoint_path, map_location='cpu')
        if 'state_dict' in state_dict:
            state_dict = state_dict['state_dict']
        elif 'model' in state_dict:
            state_dict = state_dict['model']
        # Mapping for different naming conventions. Order matters: prefixes
        # are tried in insertion order and the first match wins.
        name_mapping = {
            "in1": "conv1.norm", "in2": "conv2.norm", "in3": "conv3.norm",
            "conv1.conv2d": "conv1.conv", "conv2.conv2d": "conv2.conv", "conv3.conv2d": "conv3.conv",
            "res1.conv1.conv2d": "residual_blocks.0.conv1.conv", "res1.in1": "residual_blocks.0.conv1.norm",
            "res1.conv2.conv2d": "residual_blocks.0.conv2.conv", "res1.in2": "residual_blocks.0.conv2.norm",
            "res2.conv1.conv2d": "residual_blocks.1.conv1.conv", "res2.in1": "residual_blocks.1.conv1.norm",
            "res2.conv2.conv2d": "residual_blocks.1.conv2.conv", "res2.in2": "residual_blocks.1.conv2.norm",
            "res3.conv1.conv2d": "residual_blocks.2.conv1.conv", "res3.in1": "residual_blocks.2.conv1.norm",
            "res3.conv2.conv2d": "residual_blocks.2.conv2.conv", "res3.in2": "residual_blocks.2.conv2.norm",
            "res4.conv1.conv2d": "residual_blocks.3.conv1.conv", "res4.in1": "residual_blocks.3.conv1.norm",
            "res4.conv2.conv2d": "residual_blocks.3.conv2.conv", "res4.in2": "residual_blocks.3.conv2.norm",
            "res5.conv1.conv2d": "residual_blocks.4.conv1.conv", "res5.in1": "residual_blocks.4.conv1.norm",
            "res5.conv2.conv2d": "residual_blocks.4.conv2.conv", "res5.in2": "residual_blocks.4.conv2.norm",
            "deconv1.conv2d": "deconv1.conv", "in4": "deconv1.norm",
            "deconv2.conv2d": "deconv2.conv", "in5": "deconv2.norm",
            "deconv3.conv2d": "deconv3.1",
        }
        mapped_state_dict = {}
        for old_name, v in state_dict.items():
            # Strip DataParallel's 'module.' prefix, then remap by prefix.
            name = old_name.replace('module.', '')
            mapped = False
            for prefix, new_name in name_mapping.items():
                if name.startswith(prefix):
                    # Keep the trailing part (e.g. ".weight" / ".bias").
                    suffix = name[len(prefix):]
                    mapped_key = new_name + suffix
                    mapped_state_dict[mapped_key] = v
                    mapped = True
                    break
            if not mapped:
                mapped_state_dict[name] = v
        # Filter out running_mean and running_var (BatchNorm params not needed for InstanceNorm)
        # Keep .weight and .bias as-is since InstanceNorm uses these names
        final_state_dict = {}
        for key, value in mapped_state_dict.items():
            if key.endswith('.running_mean') or key.endswith('.running_var'):
                continue  # Skip BatchNorm-specific parameters
            final_state_dict[key] = value
        # strict=False tolerates missing/unexpected keys across checkpoint
        # variants rather than raising.
        self.load_state_dict(final_state_dict, strict=False)
# ============================================================================
# Model Cache
# ============================================================================
# In-memory cache of loaded models, keyed by "style_backend".
MODEL_CACHE = {}
MODELS_DIR = Path("models")
MODELS_DIR.mkdir(exist_ok=True)

# All pre-trained checkpoints live in this GitHub release.
_WEIGHTS_BASE_URL = "https://github.com/yakhyo/fast-neural-style-transfer/releases/download/v1.0"
_DOWNLOADABLE_STYLES = (
    'candy', 'mosaic', 'udnie', 'rain_princess', 'la_muse',
    'starry_night', 'the_scream', 'feathers', 'composition_vii',
)

def get_model_path(style: str) -> Path:
    """Return the local checkpoint path for *style*, downloading on a miss."""
    model_path = MODELS_DIR / f"{style}.pth"
    if model_path.exists():
        return model_path
    if style not in _DOWNLOADABLE_STYLES:
        raise ValueError(f"Unknown style: {style}")
    # Release filenames use hyphens where style keys use underscores.
    url = f"{_WEIGHTS_BASE_URL}/{style.replace('_', '-')}.pth"
    import urllib.request
    print(f"Downloading {style} model...")
    urllib.request.urlretrieve(url, model_path)
    print(f"Downloaded {style} model to {model_path}")
    return model_path
def load_model(style: str, backend: str = 'auto') -> TransformerNet:
    """Return a cached TransformerNet for (style, backend), loading on a miss."""
    cache_key = f"{style}_{backend}"
    model = MODEL_CACHE.get(cache_key)
    if model is None:
        print(f"Loading {style} model with {backend} backend...")
        checkpoint = get_model_path(style)
        model = TransformerNet(num_residual_blocks=5, backend=backend).to(get_device())
        model.load_checkpoint(str(checkpoint))
        model.eval()
        MODEL_CACHE[cache_key] = model
        print(f"Loaded {style} model ({backend})")
    return model
# Preload models on startup
print("=" * 50)
print("StyleForge - Initializing...")
print("=" * 50)
if _SPACES_ZERO_GPU:
print("Device: CUDA (ZeroGPU mode - lazy initialization)")
else:
print(f"Device: {get_device().type.upper()}")
if SPACES_AVAILABLE:
status = "Pre-compiled" if CUDA_KERNELS_AVAILABLE else "PyTorch GPU fallback"
print(f"CUDA Kernels: {status}")
else:
print(f"CUDA Kernels: {'Available' if CUDA_KERNELS_AVAILABLE else 'Not Available (using PyTorch fallback)'}")
# Skip model preloading on ZeroGPU to avoid CUDA init in main process
if not _SPACES_ZERO_GPU:
print("Preloading models...")
for style in STYLES.keys():
try:
load_model(style, 'auto')
print(f" {STYLES[style]}: Ready")
except Exception as e:
print(f" {STYLES[style]}: Failed - {e}")
print("All models loaded!")
else:
print("ZeroGPU mode: Models will be loaded on-demand within GPU tasks")
print("=" * 50)
# ============================================================================
# Style Blending (Weight Interpolation)
# ============================================================================
def blend_models(style1: str, style2: str, alpha: float, backend: str = 'auto') -> TransformerNet:
    """Build a model whose weights linearly interpolate two styles.

    Args:
        style1: First style name (contributes with weight 1 - alpha).
        style2: Second style name (contributes with weight alpha).
        alpha: Blend factor (0=style1, 1=style2, 0.5=equal mix).
        backend: Backend to use.

    Returns:
        New model with blended weights.
    """
    sd1 = load_model(style1, backend).state_dict()
    sd2 = load_model(style2, backend).state_dict()
    # Fresh network to receive the interpolated parameters.
    blended = TransformerNet(num_residual_blocks=5, backend=backend).to(get_device())
    blended.eval()
    mixed = {}
    for key, weight1 in sd1.items():
        if key in sd2:
            # Linear interpolation between the two checkpoints.
            mixed[key] = alpha * sd2[key] + (1 - alpha) * weight1
        else:
            # Keys unique to style1 pass through unchanged.
            mixed[key] = weight1
    blended.load_state_dict(mixed)
    return blended
# Cache for blended models
# Cache for blended models, keyed by styles, alpha (2 dp) and backend.
BLENDED_CACHE = {}

def get_blended_model(style1: str, style2: str, alpha: float, backend: str = 'auto') -> TransformerNet:
    """Get or create blended model with caching."""
    # Alpha is rounded to 2 decimals in the key so near-equal sliders hit.
    cache_key = f"blend_{style1}_{style2}_{alpha:.2f}_{backend}"
    model = BLENDED_CACHE.get(cache_key)
    if model is None:
        model = blend_models(style1, style2, alpha, backend)
        BLENDED_CACHE[cache_key] = model
    return model
# ============================================================================
# Region-based Style Transfer
# ============================================================================
def apply_region_style_impl(
    image: Image.Image,
    mask: Image.Image,
    style1: str,
    style2: str,
    backend: str = 'auto'
) -> Image.Image:
    """Stylize an image with two styles, blended by a grayscale mask.

    White mask pixels receive ``style1``; black pixels receive ``style2``;
    gray values blend the two stylized outputs linearly.

    Args:
        image: Input image.
        mask: Grayscale mask (white=style1 region, black=style2 region).
        style1: Style for the white region.
        style2: Style for the black region.
        backend: Processing backend.

    Returns:
        The region-stylized PIL image.
    """
    # Normalize modes/sizes so the mask aligns pixel-for-pixel with the image.
    if image.mode != 'RGB':
        image = image.convert('RGB')
    if mask.mode != 'L':
        mask = mask.convert('L')
    if mask.size != image.size:
        mask = mask.resize(image.size, Image.NEAREST)

    model1 = load_model(style1, backend)
    model2 = load_model(style2, backend)

    import torchvision.transforms as transforms
    device = get_device()
    img_tensor = transforms.ToTensor()(image).unsqueeze(0).to(device)
    # Mask as a [1, 1, H, W] float tensor in [0, 1].
    mask_tensor = torch.from_numpy(np.array(mask)).float() / 255.0
    mask_tensor = mask_tensor.unsqueeze(0).unsqueeze(0).to(device)

    # Run both stylizations, timing the full pass for the tracker.
    start = time.perf_counter()
    with torch.no_grad():
        styled1 = model1(img_tensor)
        styled2 = model2(img_tensor)
        if device.type == 'cuda':
            # Force pending kernels to finish so the timing is honest.
            torch.cuda.synchronize()
    elapsed_ms = (time.perf_counter() - start) * 1000
    actual_backend = 'cuda' if (backend == 'cuda' or (backend == 'auto' and CUDA_KERNELS_AVAILABLE)) else 'pytorch'
    perf_tracker.record(elapsed_ms, actual_backend)

    # Linear blend: white (1) -> style1, black (0) -> style2.
    weight = mask_tensor.expand_as(styled1)
    composite = torch.clamp(weight * styled1 + (1 - weight) * styled2, 0, 1)
    return transforms.ToPILImage()(composite.squeeze(0))
def create_region_mask(
    image: Image.Image,
    mask_type: str = "horizontal_split",
    position: float = 0.5
) -> Image.Image:
    """Build a binary (0/255) region mask the same size as *image*.

    Args:
        image: Reference image used only for its dimensions.
        mask_type: One of "horizontal_split", "vertical_split",
            "center_circle", "corner_box"; anything else yields all white.
        position: Split position / circle size factor (0-1).

    Returns:
        A mode-'L' PIL image where white marks the style1 region.
    """
    w, h = image.size
    canvas = np.zeros((h, w), dtype=np.uint8)
    if mask_type == "horizontal_split":
        # Rows above the split line are white.
        canvas[:int(h * position), :] = 255
    elif mask_type == "vertical_split":
        # Columns left of the split line are white.
        canvas[:, :int(w * position)] = 255
    elif mask_type == "center_circle":
        # Centered disc is white; radius scales with the shorter side.
        cy, cx = h // 2, w // 2
        radius = min(h, w) * position * 0.4
        yy, xx = np.ogrid[:h, :w]
        canvas[(xx - cx) ** 2 + (yy - cy) ** 2 <= radius ** 2] = 255
    elif mask_type == "corner_box":
        # Top-left quadrant is white; `position` is ignored here.
        canvas[:h // 2, :w // 2] = 255
    else:
        # Fallback ("full" and unknown types): everything white.
        canvas[:] = 255
    return Image.fromarray(canvas, mode='L')
def create_ai_segmentation_mask(
    image: Image.Image,
    mask_type: str = "foreground"
) -> Image.Image:
    """Create an AI-based segmentation mask using rembg (fresh session).

    See :func:`get_ai_segmentation_mask` for the session-caching variant.

    Args:
        image: Input image.
        mask_type: "foreground" (main subject) or "background" (background only).

    Returns:
        Binary mask as PIL Image (white=foreground, black=background).

    Raises:
        ImportError: If rembg is not installed.
        RuntimeError: If segmentation fails for any reason.
    """
    if not REMBG_AVAILABLE:
        raise ImportError("Rembg is not installed. Install with: pip install rembg")
    try:
        import io
        # Dedicated u2net session for this call.
        session = new_session(model_name="u2net")
        # rembg consumes raw encoded bytes, so round-trip through PNG.
        buffer = io.BytesIO()
        image.save(buffer, format='PNG')
        buffer.seek(0)
        output_bytes = remove(buffer.read(), session=session, alpha_matting=True)
        result_img = Image.open(io.BytesIO(output_bytes))
        if result_img.mode == 'RGBA':
            # The alpha channel encodes the subject cutout.
            channel = np.array(result_img.split()[-1])
        else:
            # Fallback: treat luminance as the confidence map.
            channel = np.array(result_img.convert('L'))
        # Threshold the soft matte into a crisp 0/255 mask.
        mask_binary = (channel > 128).astype(np.uint8) * 255
        if mask_type == "background":
            # Invert so white marks the background instead.
            mask_binary = 255 - mask_binary
        return Image.fromarray(mask_binary, mode='L')
    except Exception as e:
        raise RuntimeError(f"AI segmentation failed: {str(e)}")
# Cached rembg inference session (model load is expensive; reuse it).
_rembg_session = None

def get_ai_segmentation_mask(
    image: Image.Image,
    mask_type: str = "foreground"
) -> Image.Image:
    """Create an AI-based segmentation mask using rembg (cached session).

    Args:
        image: Input image.
        mask_type: "foreground" (main subject) or "background" (background only).

    Returns:
        Binary mask as PIL Image (white=foreground, black=background).

    Raises:
        ImportError: If rembg is not available.
        RuntimeError: If segmentation fails for any reason.
    """
    global _rembg_session
    if not REMBG_AVAILABLE:
        raise ImportError("Rembg is not available. Using fallback geometric mask.")
    try:
        import io
        if _rembg_session is None:
            # First call: build and cache the u2net session.
            _rembg_session = new_session(model_name="u2net")
        # rembg consumes raw encoded bytes, so round-trip through PNG.
        buffer = io.BytesIO()
        image.save(buffer, format='PNG')
        buffer.seek(0)
        output_bytes = remove(buffer.read(), session=_rembg_session, alpha_matting=True)
        result_img = Image.open(io.BytesIO(output_bytes))
        if result_img.mode == 'RGBA':
            # The alpha channel encodes the subject cutout.
            channel = np.array(result_img.split()[-1])
        else:
            # Fallback: treat luminance as the confidence map.
            channel = np.array(result_img.convert('L'))
        # Threshold the soft matte into a crisp 0/255 mask.
        mask_binary = (channel > 128).astype(np.uint8) * 255
        if mask_type == "background":
            # Invert so white marks the background instead.
            mask_binary = 255 - mask_binary
        return Image.fromarray(mask_binary, mode='L')
    except Exception as e:
        raise RuntimeError(f"AI segmentation failed: {str(e)}")
# ============================================================================
# Real Style Extraction Training (VGG-based)
# ============================================================================
def train_custom_style_impl(
    style_image: Image.Image,
    style_name: str,
    num_iterations: int = 100,
    backend: str = 'auto'
) -> Tuple[Optional[str], str]:
    """
    Train a custom style from an image using VGG feature matching.
    This implements real style extraction by:
    1. Computing style features from the style image using VGG19
    2. Fine-tuning a base network to match those style features
    3. Using content preservation to maintain image structure

    Returns:
        (progress_log, status_message) on success; (None, error_text) on failure.

    Side effects: saves a checkpoint into CUSTOM_STYLES_DIR, registers the
    style in the module-level STYLES dict, and caches the model in MODEL_CACHE.
    """
    global STYLES
    if style_image is None:
        return None, "Please upload a style image."
    try:
        import torchvision.transforms as transforms
        # Resize style image to reasonable size for training (cap at 512 px).
        style_image = style_image.convert('RGB')
        if max(style_image.size) > 512:
            scale = 512 / max(style_image.size)
            new_size = (int(style_image.width * scale), int(style_image.height * scale))
            style_image = style_image.resize(new_size, Image.LANCZOS)
        # Human-readable log returned to the UI.
        progress_update = []
        progress_update.append(f"Starting style extraction from '{style_name}'...")
        progress_update.append(f"Training for {num_iterations} iterations...")
        # Get VGG feature extractor (frozen, lazy-loaded).
        vgg = get_vgg_extractor()
        # Prepare style image with ImageNet normalization.
        style_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        style_tensor = style_transform(style_image).unsqueeze(0).to(get_device())
        # Extract style features (no gradients needed for the target).
        with torch.no_grad():
            style_features = vgg(style_tensor)
        # Compute Gram matrices for style representation.
        style_grams = []
        # NOTE(review): layers_to_use is currently unused; the loop below
        # reuses the single relu4_4 feature map four times as a simplified
        # stand-in for true multi-layer style extraction.
        layers_to_use = [0, 1, 2, 3]  # Corresponding to VGG layers
        for i in range(4):
            feat = style_features if i == 0 else style_features  # Simplified - in full version extract from multiple layers
            gram = gram_matrix(feat)
            style_grams.append(gram)
        # Load a base model to fine-tune (start with udnie as a good base).
        base_style = 'udnie'
        progress_update.append(f"Loading base model ({base_style}) for fine-tuning...")
        model = load_model(base_style, backend)
        optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
        # Create a simple content image for training (RGB gradient pattern).
        content_img = Image.new('RGB', (256, 256))
        for y in range(256):
            r = int(255 * y / 256)
            for x in range(256):
                g = int(255 * x / 256)
                content_img.putpixel((x, y), (r, g, 128))
        content_tensor = style_transform(content_img).unsqueeze(0).to(get_device())
        # Training loop.
        # NOTE(review): this fine-tunes the *cached* base model in place,
        # so the 'udnie' entry in MODEL_CACHE is mutated by training.
        model.train()
        # Per-layer style loss weights (heavier on earlier entries).
        style_weights = [1.0, 0.8, 0.5, 0.3]
        progress_update.append("Training...")
        for iteration in range(num_iterations):
            optimizer.zero_grad()
            # Forward pass through the transformer network.
            output = model(content_tensor)
            # Perceptual features of the current output.
            output_features = vgg(output)
            # Compute style loss against the precomputed target Grams.
            style_loss = 0
            output_gram = gram_matrix(output_features)
            for i, (target_gram, weight) in enumerate(zip(style_grams, style_weights)):
                # Simplified: using single layer comparison
                style_loss += weight * torch.mean((output_gram - target_gram) ** 2)
            # Backward pass and parameter update.
            style_loss.backward()
            optimizer.step()
            # Progress update every 20 iterations.
            if (iteration + 1) % 20 == 0:
                progress_update.append(f"Iteration {iteration + 1}/{num_iterations}: Style Loss = {style_loss.item():.4f}")
        model.eval()
        # Persist the fine-tuned weights as a new custom style.
        save_path = CUSTOM_STYLES_DIR / f"{style_name}.pth"
        torch.save(model.state_dict(), save_path)
        progress_update.append(f"✓ Style '{style_name}' trained and saved successfully!")
        progress_update.append(f"✓ Model saved to: {save_path}")
        progress_update.append(f"✓ You can now use '{style_name}' in the Style dropdown!")
        # Register the new style and cache the trained model.
        if style_name not in STYLES:
            STYLES[style_name] = style_name.title()
        MODEL_CACHE[f"{style_name}_{backend}"] = model
        return "\n".join(progress_update), f"✓ Custom style '{style_name}' created successfully!\n\nSelect '{style_name}' from the Style dropdown to use it."
    except Exception as e:
        import traceback
        # Surface the full traceback to the UI for debuggability.
        error_msg = f"Error: {str(e)}\n\n{traceback.format_exc()}"
        return None, error_msg
def extract_style_from_image_impl(
    style_image: Image.Image,
    content_image: Image.Image,
    style_name: str,
    num_iterations: int = 200,
    style_weight: float = 1e5,
    content_weight: float = 1.0
) -> Tuple[Optional[str], str]:
    """
    Extract style from one image and apply it to another.
    This is the full neural style transfer algorithm.

    Args:
        style_image: The artwork/image to extract style from
        content_image: The photo to apply style to (optional, for preview)
        style_name: Name to save the extracted style as
        num_iterations: Number of optimization iterations
        style_weight: Weight for style loss
        content_weight: Weight for content loss

    Returns:
        Tuple of (status_message, result_image)

    NOTE(review): the declared return type is inaccurate — on success the
    second element is a PIL Image (the preview), not a str; only the error
    path returns (None, str). Left as-is to avoid changing the signature.
    """
    if style_image is None:
        return None, "Please upload a style image."
    try:
        import torchvision.transforms as transforms
        # Resize images: cap the longest side at 512 px to bound VGG cost.
        style_image = style_image.convert('RGB')
        if max(style_image.size) > 512:
            scale = 512 / max(style_image.size)
            new_size = (int(style_image.width * scale), int(style_image.height * scale))
            style_image = style_image.resize(new_size, Image.LANCZOS)
        progress = []
        progress.append("Extracting style features using VGG19...")
        # Get VGG feature extractor (module-level helper; loads lazily).
        vgg = get_vgg_extractor()
        # Prepare transforms — ImageNet normalization to match VGG training.
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        # Process style image into a batched tensor on the compute device.
        style_tensor = transform(style_image).unsqueeze(0).to(get_device())
        # Extract style features (no grad: these are fixed targets).
        with torch.no_grad():
            style_features = vgg(style_tensor)
        # Compute Gram matrix for style — the optimization target.
        style_gram = gram_matrix(style_features)
        progress.append("Style features extracted. Creating style model...")
        # Create a new model and train it to match the style.
        model = TransformerNet(num_residual_blocks=5, backend='auto').to(get_device())
        # Use a simple content image for training the transform when none given.
        if content_image is None:
            # Create gradient pattern as content (x→R, y→G ramp, fixed B).
            content_image = Image.new('RGB', (256, 256))
            for y in range(256):
                for x in range(256):
                    content_image.putpixel((x, y), (x, y, 128))
        content_image = content_image.convert('RGB')
        content_tensor = transform(content_image).unsqueeze(0).to(get_device())
        # Extract content features (fixed target; no grad needed).
        with torch.no_grad():
            content_features = vgg(content_tensor)
        # Setup optimizer over the transformer network's weights.
        optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
        # Training loop: minimize weighted content + style losses.
        model.train()
        for i in range(num_iterations):
            optimizer.zero_grad()
            # Generate output
            output = model(content_tensor)
            # Get features of the generated image.
            output_features = vgg(output)
            # Content loss (keep structure)
            content_loss = torch.mean((output_features - content_features) ** 2)
            # Style loss (match style) via Gram-matrix distance.
            output_gram = gram_matrix(output_features)
            style_loss = torch.mean((output_gram - style_gram) ** 2)
            # Total loss
            total_loss = content_weight * content_loss + style_weight * style_loss
            total_loss.backward()
            optimizer.step()
            if (i + 1) % 50 == 0:
                progress.append(f"Iteration {i+1}/{num_iterations}: Loss = {total_loss.item():.4f}")
        model.eval()
        # Save the model so it can be reloaded from the Style dropdown.
        save_path = CUSTOM_STYLES_DIR / f"{style_name}.pth"
        torch.save(model.state_dict(), save_path)
        # Add to styles registry and warm the model cache.
        if style_name not in STYLES:
            STYLES[style_name] = style_name.title()
        MODEL_CACHE[f"{style_name}_auto"] = model
        # Generate a preview of the freshly trained style.
        with torch.no_grad():
            preview_output = model(content_tensor)
            preview_output = torch.clamp(preview_output, 0, 1)
        preview_image = transforms.ToPILImage()(preview_output.squeeze(0))
        progress.append(f"✓ Style '{style_name}' extracted and saved!")
        return "\n".join(progress), preview_image
    except Exception as e:
        import traceback
        return None, f"Error: {str(e)}\n\n{traceback.format_exc()}"
# ============================================================================
# Image Processing Functions
# ============================================================================
def preprocess_image(img: Image.Image) -> torch.Tensor:
    """Convert a PIL image to a batched float tensor with values in [0, 1].

    Returns a tensor of shape (1, C, H, W) ready for model input.
    """
    import torchvision.transforms as transforms
    tensor = transforms.ToTensor()(img)
    return tensor.unsqueeze(0)
def postprocess_tensor(tensor: torch.Tensor) -> Image.Image:
    """Convert a (possibly batched) output tensor back to a PIL image.

    Values are clamped into [0, 1] before conversion.
    """
    import torchvision.transforms as transforms
    if tensor.dim() == 4:
        tensor = tensor.squeeze(0)
    clamped = torch.clamp(tensor, 0, 1)
    return transforms.ToPILImage()(clamped)
def create_side_by_side(img1: Image.Image, img2: Image.Image, style_name: str) -> Image.Image:
    """Create a labelled side-by-side comparison of original vs. stylized images.

    Args:
        img1: Original image; its size defines the canvas.
        img2: Stylized image; resized to match img1 if sizes differ.
        style_name: Style label drawn in the header strip.

    Returns:
        A new RGB image: img1 and img2 pasted side by side under a title.
    """
    if img1.size != img2.size:
        img2 = img2.resize(img1.size, Image.LANCZOS)
    w, h = img1.size
    # 20 px gutter between the images, 70 px white header strip for labels.
    combined = Image.new('RGB', (w * 2 + 20, h + 70), 'white')
    combined.paste(img1, (0, 70))
    combined.paste(img2, (w + 20, 70))
    draw = ImageDraw.Draw(combined)
    try:
        font_title = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 28)
        font_label = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 20)
    except (OSError, ImportError):
        # BUGFIX: was a bare `except:` which also swallowed KeyboardInterrupt.
        # truetype() raises OSError for missing fonts and ImportError when
        # PIL lacks FreeType; fall back to the built-in bitmap font.
        font_title = ImageFont.load_default()
        font_label = ImageFont.load_default()
    draw.text((w + 10, 20), f"Style: {style_name}", fill='#667eea', font=font_title)
    draw.text((w // 2, 50), "Original", fill='#555', font=font_label, anchor='mm')
    draw.text((w * 1.5 + 10, 50), "Stylized", fill='#555', font=font_label, anchor='mm')
    return combined
def add_watermark(img: Image.Image, style_name: str) -> Image.Image:
    """Add a subtle bottom-right watermark for social sharing.

    Args:
        img: Source image (not modified; a copy is returned).
        style_name: Style label embedded in the watermark text.

    Returns:
        A copy of img with a translucent backing box and watermark text.
    """
    result = img.copy()
    draw = ImageDraw.Draw(result)
    w, h = result.size
    text = f"StyleForge • {style_name}"
    try:
        # Font size scales with image width so the mark stays proportional.
        font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", int(w / 40))
    except (OSError, ImportError):
        # BUGFIX: was a bare `except:` which also swallowed KeyboardInterrupt.
        # truetype() raises OSError for missing fonts and ImportError when
        # PIL lacks FreeType; fall back to the built-in bitmap font.
        font = ImageFont.load_default()
    bbox = draw.textbbox((0, 0), text, font=font)
    text_w = bbox[2] - bbox[0]
    text_h = bbox[3] - bbox[1]
    # Semi-transparent dark backing box so the text reads on any background.
    overlay = Image.new('RGBA', (text_w + 20, text_h + 10), (0, 0, 0, 100))
    result.paste(overlay, (w - text_w - 25, h - text_h - 15), overlay)
    draw.text((w - text_w - 15, h - text_h - 10), text, fill=(255, 255, 255, 200), font=font)
    return result
# Global state for webcam mode
class WebcamState:
    """Mutable holder for the live-webcam session's settings and counters."""

    def __init__(self):
        # Webcam loop flag — not active until a session starts.
        self.is_active = False
        # Style key applied to incoming frames.
        self.current_style = 'candy'
        # Backend selection used for frame inference.
        self.current_backend = 'auto'
        # Running count of processed frames.
        self.frame_count = 0


# Single module-wide instance shared by the webcam handlers.
webcam_state = WebcamState()
# ============================================================================
# Chart Generation
# ============================================================================
def create_performance_chart() -> str:
    """Render the live inference-time history as an embeddable Plotly HTML snippet.

    Returns a markdown message instead when plotly is missing or fewer than
    two samples have been recorded.
    """
    if not PLOTLY_AVAILABLE:
        return "### Chart Unavailable\n\nPlotly is not installed. Install with: `pip install plotly`"
    data = perf_tracker.get_chart_data()
    if not data or len(data.timestamps) < 2:
        return "### Performance Chart\n\nRun some inferences to see the chart populate..."
    # Color mapping for backends
    colors = {
        'cuda': '#10b981',     # green
        'pytorch': '#6366f1',  # blue
        'auto': '#8b5cf6',     # purple
    }
    fig = go.Figure()
    # One trace per backend: filter the parallel sample lists by backend tag.
    for backend in set(data.backends):
        samples = [
            (ts, ms)
            for ts, ms, tag in zip(data.timestamps, data.times, data.backends)
            if tag == backend
        ]
        if samples:
            xs = [s[0] for s in samples]
            ys = [s[1] for s in samples]
            fig.add_trace(go.Scatter(
                x=xs,
                y=ys,
                mode='lines+markers',
                name=backend.upper(),
                line=dict(color=colors[backend]),
                marker=dict(size=8, color=colors[backend]),
                connectgaps=True
            ))
    fig.update_layout(
        title="Inference Time Over Time",
        xaxis_title="Time",
        yaxis_title="Time (ms)",
        hovermode='x unified',
        height=400,
        margin=dict(l=0, r=0, t=40, b=40)
    )
    # Embed via CDN so the returned fragment stays small.
    return fig.to_html(full_html=False, include_plotlyjs='cdn')
def create_benchmark_comparison(style: str) -> str:
    """Run a quick PyTorch-vs-CUDA benchmark for one style; return a Plotly bar chart.

    Args:
        style: Style key to benchmark (display name resolved via STYLES).

    Returns:
        HTML string containing the chart plus a markdown speedup caption,
        or a plain install hint when plotly is unavailable.
    """
    if not PLOTLY_AVAILABLE:
        return "Install plotly for charts"
    # Fixed 512x512 test image so both backends see identical work.
    test_img = Image.new('RGB', (512, 512), color='red')
    results = {}
    for backend_name, backend_key in [('PyTorch', 'pytorch'), ('CUDA Kernels', 'cuda')]:
        try:
            model = load_model(style, backend_key)
            test_tensor = preprocess_image(test_img).to(get_device())
            times = []
            for _ in range(3):
                start = time.perf_counter()
                with torch.no_grad():
                    _ = model(test_tensor)
                if get_device().type == 'cuda':
                    # CUDA launches are asynchronous; sync so wall time is real.
                    torch.cuda.synchronize()
                times.append((time.perf_counter() - start) * 1000)
            results[backend_name] = np.mean(times)
        except Exception:
            # Backend unavailable on this host; mark missing rather than fail.
            results[backend_name] = None
    fig = go.Figure()
    backends = []
    times_list = []
    colors_list = []
    for name, time_val in results.items():
        # BUGFIX: explicit None check — `if time_val:` would also drop a
        # legitimate (near-)zero timing as falsy.
        if time_val is not None:
            backends.append(name)
            times_list.append(time_val)
            colors_list.append('#10b981' if 'CUDA' in name else '#6366f1')
    if backends:
        fig.add_trace(go.Bar(
            x=backends,
            y=times_list,
            marker=dict(color=colors_list),
            text=[f"{t:.1f} ms" for t in times_list],
            textposition='outside',
        ))
    fig.update_layout(
        title=f"Benchmark Comparison - {STYLES.get(style, style.title())} Style",
        xaxis_title="Backend",
        yaxis_title="Inference Time (ms)",
        height=400,
        margin=dict(l=0, r=0, t=40, b=40),
        showlegend=False
    )
    # Speedup = slow/fast ratio, regardless of which backend won.
    # BUGFIX: removed a dead `speedup` computation whose result was discarded,
    # and guarded against division by zero.
    if len(times_list) == 2 and min(times_list) > 0:
        actual_speedup = max(times_list) / min(times_list)
        caption = f"Speedup: **{actual_speedup:.2f}x**"
    else:
        caption = "Run on GPU with CUDA for comparison"
    return fig.to_html(full_html=False, include_plotlyjs='cdn') + f"\n\n### {caption}"
# ============================================================================
# Gradio Interface Functions
# ============================================================================
def stylize_image_impl(
    input_image: Optional[Image.Image],
    style: str,
    backend: str,
    intensity: float,
    show_comparison: bool,
    add_watermark: bool
) -> Tuple[Optional[Image.Image], str, Optional[str]]:
    """Main stylization function for Gradio.

    Args:
        input_image: Uploaded image (any mode; converted to RGB).
        style: Style key, or a blended key of the form "style1_style2_alpha"
            where alpha is a 0-100 percentage.
        backend: 'auto', 'cuda', or 'pytorch'.
        intensity: 0-100 blend amount (100 = fully stylized output).
        show_comparison: If True, return a side-by-side original/stylized image.
        add_watermark: If True, stamp a StyleForge watermark on the output.

    Returns:
        Tuple of (output image or None, markdown stats/error text,
        download file path or None).
    """
    if input_image is None:
        return None, "Please upload an image first.", None
    try:
        # Convert to RGB if needed
        if input_image.mode != 'RGB':
            input_image = input_image.convert('RGB')
        # Handle blended styles (format: "style1_style2_alpha")
        if '_' in style and style not in STYLES:
            parts = style.split('_')
            if len(parts) >= 3:
                style1, style2 = parts[0], parts[1]
                alpha = float(parts[2]) / 100
                model = get_blended_model(style1, style2, alpha, backend)
                # BUGFIX: alpha is a 0-1 fraction, so the complement is
                # 1 - alpha; the old "100 - alpha" rendered as e.g. "9950%".
                style_display = f"{STYLES.get(style1, style1)} × {alpha:.0%} + {STYLES.get(style2, style2)} × {1 - alpha:.0%}"
            else:
                model = load_model(style, backend)
                style_display = STYLES.get(style, style)
        else:
            model = load_model(style, backend)
            style_display = STYLES.get(style, style)
        # Preprocess
        input_tensor = preprocess_image(input_image).to(get_device())
        # Stylize with timing; sync on CUDA so the wall time is accurate.
        start = time.perf_counter()
        with torch.no_grad():
            output_tensor = model(input_tensor)
        if get_device().type == 'cuda':
            torch.cuda.synchronize()
        elapsed_ms = (time.perf_counter() - start) * 1000
        # Determine actual backend used
        actual_backend = 'cuda' if (backend == 'cuda' or (backend == 'auto' and CUDA_KERNELS_AVAILABLE)) else 'pytorch'
        perf_tracker.record(elapsed_ms, actual_backend)
        # Postprocess
        output_image = postprocess_tensor(output_tensor.cpu())
        # Apply intensity blending (blend original with stylized output);
        # intensity 0-100, where 100 = full style, 0 = original.
        intensity_factor = intensity / 100.0
        if intensity_factor < 1.0:
            import torchvision.transforms as transforms
            to_tensor = transforms.ToTensor()
            to_pil = transforms.ToPILImage()
            original_tensor = to_tensor(input_image)
            output_tensor_pil = to_tensor(output_image)
            # Blend: output * intensity + original * (1 - intensity)
            blended_tensor = output_tensor_pil * intensity_factor + original_tensor * (1 - intensity_factor)
            blended_tensor = torch.clamp(blended_tensor, 0, 1)
            output_image = to_pil(blended_tensor)
        # Add watermark if requested
        if add_watermark:
            # BUGFIX: the boolean parameter shadows the module-level
            # add_watermark() helper, so the bare name here would call the
            # bool and raise TypeError. Resolve the function via globals()
            # to keep the public parameter name unchanged.
            output_image = globals()['add_watermark'](output_image, style_display)
        # Create comparison if requested
        if show_comparison:
            output_image = create_side_by_side(input_image, output_image, style_display)
        # Save for download
        download_path = f"/tmp/styleforge_{int(time.time())}.png"
        output_image.save(download_path, quality=95)
        # Generate stats
        stats = perf_tracker.get_stats()
        fps = 1000 / elapsed_ms if elapsed_ms > 0 else 0
        width, height = input_image.size
        # Backend display name
        backend_display = {
            'auto': f"Auto ({'CUDA' if CUDA_KERNELS_AVAILABLE else 'PyTorch'})",
            'cuda': 'CUDA Kernels',
            'pytorch': 'PyTorch'
        }.get(backend, backend)
        stats_text = f"""
### Performance
| Metric | Value |
|--------|-------|
| **Style** | {style_display} |
| **Backend** | {backend_display} |
| **Time** | {elapsed_ms:.1f} ms ({fps:.0f} FPS) |
| **Avg Time** | {(stats.avg_ms if stats else elapsed_ms):.1f} ms |
| **Total Images** | {stats.total_inferences if stats else 1} |
| **Size** | {width}x{height} |
| **Device** | {get_device().type.upper()} |
---
{perf_tracker.get_comparison()}
"""
        return output_image, stats_text, download_path
    except Exception as e:
        import traceback
        error_details = traceback.format_exc()
        error_msg = f"""
### Error
**{str(e)}**
<details>
<summary>Show details</summary>
```
{error_details}
```
</details>
"""
        return None, error_msg, None
# Wrap with GPU decorator for ZeroGPU if available
# ZeroGPU requires ALL GPU-using functions to be decorated with @GPU
# Note: This must come AFTER the function definitions below
# On Hugging Face Spaces with ZeroGPU, every GPU-touching entry point must be
# wrapped with the GPU decorator so the scheduler can allocate a device.
if SPACES_AVAILABLE:
    try:
        stylize_image = GPU(stylize_image_impl)
        train_custom_style = GPU(train_custom_style_impl)
        extract_style_from_image = GPU(extract_style_from_image_impl)
        # apply_region_style = GPU(apply_region_style_impl)
        # apply_region_style_ui = GPU(apply_region_style_ui_impl)
        # create_style_blend_output will be wrapped after function definition
    except Exception:
        # Fallback if GPU decorator fails — expose the plain implementations.
        # NOTE(review): if GPU() fails part-way through, all three names are
        # still reassigned here, so none stay wrapped.
        stylize_image = stylize_image_impl
        train_custom_style = train_custom_style_impl
        extract_style_from_image = extract_style_from_image_impl
        # create_style_blend_output = create_style_blend_output_impl
        # apply_region_style = apply_region_style_impl
        # apply_region_style_ui = apply_region_style_ui_impl
else:
    # Not running on Spaces: no decorator needed, use implementations directly.
    stylize_image = stylize_image_impl
    train_custom_style = train_custom_style_impl
    extract_style_from_image = extract_style_from_image_impl
    # create_style_blend_output = create_style_blend_output_impl
    # apply_region_style = apply_region_style_impl
    # apply_region_style_ui = apply_region_style_ui_impl
def process_webcam_frame(image: Image.Image, style: str, backend: str, intensity: float = 70) -> Image.Image:
    """Process one webcam frame in (near) real time.

    Args:
        image: Incoming frame (any mode; converted to RGB, downscaled to 640 px).
        style: Style key, or blended key "style1_style2_alpha".
        backend: 'auto', 'cuda', or 'pytorch'.
        intensity: 0-100 blend amount (100 = fully stylized).

    Returns:
        The stylized frame, or the original frame unchanged on any error.
    """
    if image is None:
        return image
    try:
        if image.mode != 'RGB':
            image = image.convert('RGB')
        # Resize for faster processing — cap the longest side at 640 px.
        if max(image.size) > 640:
            scale = 640 / max(image.size)
            new_size = (int(image.width * scale), int(image.height * scale))
            image = image.resize(new_size, Image.LANCZOS)
        # Use blended style if applicable (key format: "style1_style2_alpha").
        if '_' in style and style not in STYLES:
            parts = style.split('_')
            if len(parts) >= 3:
                style1, style2 = parts[0], parts[1]
                alpha = float(parts[2]) / 100
                model = get_blended_model(style1, style2, alpha, backend)
            else:
                model = load_model(style, backend)
        else:
            model = load_model(style, backend)
        input_tensor = preprocess_image(image).to(get_device())
        # Time the inference; sync on CUDA so the wall time is real.
        start = time.perf_counter()
        with torch.no_grad():
            output_tensor = model(input_tensor)
        if get_device().type == 'cuda':
            torch.cuda.synchronize()
        elapsed_ms = (time.perf_counter() - start) * 1000
        output_image = postprocess_tensor(output_tensor.cpu())
        # Apply intensity blending against the original frame.
        intensity_factor = intensity / 100.0
        if intensity_factor < 1.0:
            import torchvision.transforms as transforms
            to_tensor = transforms.ToTensor()
            to_pil = transforms.ToPILImage()
            original_tensor = to_tensor(image)
            output_tensor_pil = to_tensor(output_image)
            blended_tensor = output_tensor_pil * intensity_factor + original_tensor * (1 - intensity_factor)
            blended_tensor = torch.clamp(blended_tensor, 0, 1)
            output_image = to_pil(blended_tensor)
        webcam_state.frame_count += 1
        actual_backend = 'cuda' if backend == 'cuda' or (backend == 'auto' and CUDA_KERNELS_AVAILABLE) else 'pytorch'
        # BUGFIX: record the measured inference time rather than a hard-coded
        # 10 ms placeholder that polluted the live performance statistics.
        perf_tracker.record(elapsed_ms, actual_backend)
        return output_image
    except Exception as e:
        # Best-effort streaming: on any failure return the unprocessed frame.
        print(f"Webcam processing error: {e}")
        import traceback
        traceback.print_exc()
        return image
def apply_region_style_ui_impl(
    input_image: Image.Image,
    mask_type: str,
    position: float,
    style1: str,
    style2: str,
    backend: str
) -> Tuple[Image.Image, Image.Image]:
    """Apply region-based style transfer with AI segmentation support.

    Returns the stylized result and a translucent visualization of the mask
    blended over the original image.
    """
    if input_image is None:
        return None, None
    # AI modes map to a segmentation target plus a geometric fallback shape
    # used when segmentation fails.
    ai_modes = {
        "AI: Foreground": ("foreground", "center_circle"),
        "AI: Background": ("background", "horizontal_split"),
    }
    if mask_type in ai_modes:
        segment_target, fallback_shape = ai_modes[mask_type]
        try:
            mask = get_ai_segmentation_mask(input_image, segment_target)
        except Exception as e:
            print(f"AI segmentation failed: {e}, using fallback")
            mask = create_region_mask(input_image, fallback_shape, position)
    else:
        # Translate the UI display name into the internal mask identifier.
        mask_type_map = {
            "Horizontal Split": "horizontal_split",
            "Vertical Split": "vertical_split",
            "Center Circle": "center_circle",
            "Corner Box": "corner_box",
            "Full": "full"
        }
        internal_type = mask_type_map.get(mask_type, "horizontal_split")
        mask = create_region_mask(input_image, internal_type, position)
    # Stylize both regions according to the mask.
    result = apply_region_style(input_image, mask, style1, style2, backend)
    # Build a 70/30 blend of the original and the mask for visualization.
    mask_vis = mask.convert('RGB').resize(input_image.size)
    orig_np = np.array(input_image)
    mask_np = np.array(mask_vis)
    overlay_np = (orig_np * 0.7 + mask_np * 0.3).astype(np.uint8)
    mask_overlay = Image.fromarray(overlay_np)
    return result, mask_overlay
def refresh_styles_list() -> list:
    """Return all selectable style keys: built-ins first, then custom styles."""
    return [*STYLES.keys(), *get_custom_styles()]
def get_style_description(style: str) -> str:
    """Look up the blurb for a style key, defaulting to an empty string."""
    return STYLE_DESCRIPTIONS.get(style, "")
def get_performance_stats() -> str:
    """Build the live-statistics markdown panel from the tracker's aggregates.

    Returns "No data yet." until at least one inference has been recorded.
    """
    stats = perf_tracker.get_stats()
    if not stats:
        return "No data yet."
    return f"""
### Live Statistics
| Metric | Value |
|--------|-------|
| **Avg Time** | {stats.avg_ms:.1f} ms |
| **Fastest** | {stats.min_ms:.1f} ms |
| **Slowest** | {stats.max_ms:.1f} ms |
| **Total Images** | {stats.total_inferences} |
| **Uptime** | {stats.uptime_hours:.1f} hours |
---
{perf_tracker.get_comparison()}
"""
def run_backend_comparison(style: str) -> str:
    """Benchmark the PyTorch and CUDA backends for one style; return markdown results.

    Runs 5 timed inferences per backend on a fixed 512x512 image and averages
    the last 4 (the first run is discarded as warm-up).

    Args:
        style: Style key passed to load_model for both backends.

    Returns:
        Markdown comparison table, or a notice when CUDA kernels are absent
        or either backend could not be benchmarked.
    """
    if not CUDA_KERNELS_AVAILABLE:
        return "### Backend Comparison\n\nCUDA kernels are not available on this device. Using PyTorch backend only."
    # Fixed test image so both backends see identical work.
    test_img = Image.new('RGB', (512, 512), color='red')

    def _bench(backend: str):
        """Average inference time in ms over the 4 post-warm-up runs, or None on failure."""
        try:
            model = load_model(style, backend)
            test_tensor = preprocess_image(test_img).to(get_device())
            times = []
            for _ in range(5):
                start = time.perf_counter()
                with torch.no_grad():
                    _ = model(test_tensor)
                if get_device().type == 'cuda':
                    # CUDA launches are async; sync so the wall time is real.
                    torch.cuda.synchronize()
                times.append((time.perf_counter() - start) * 1000)
            # Skip the first run: it includes warm-up overhead.
            return np.mean(times[1:])
        except Exception:
            return None

    # The two identical benchmark loops were deduplicated into _bench above.
    results = {'pytorch': _bench('pytorch'), 'cuda': _bench('cuda')}
    # Format results
    output = "### Backend Comparison Results\n\n"
    if results.get('pytorch') and results.get('cuda'):
        speedup = results['pytorch'] / results['cuda']
        output += f"""
| Backend | Time | Speedup |
|---------|------|---------|
| **PyTorch** | {results['pytorch']:.1f} ms | 1.0x |
| **CUDA Kernels** | {results['cuda']:.1f} ms | {speedup:.2f}x |
### CUDA kernels are {speedup:.1f}x faster! 🚀
"""
    else:
        output += "Could not complete comparison. Both backends may not be available."
    return output
def create_style_blend_output_impl(
    input_image: Image.Image,
    style1: str,
    style2: str,
    blend_ratio: float,
    backend: str
) -> Image.Image:
    """Stylize an image with an interpolated blend of two styles.

    blend_ratio is a 0-100 percentage; it is scaled to a 0-1 factor and
    handed to get_blended_model as the interpolation weight.
    """
    if input_image is None:
        return None
    if input_image.mode != 'RGB':
        input_image = input_image.convert('RGB')
    # Build (or fetch) the interpolated model for this style pair.
    alpha = blend_ratio / 100
    model = get_blended_model(style1, style2, alpha, backend)
    # Run inference under timing; sync on CUDA for an honest wall clock.
    tensor_in = preprocess_image(input_image).to(get_device())
    t0 = time.perf_counter()
    with torch.no_grad():
        tensor_out = model(tensor_in)
    if get_device().type == 'cuda':
        torch.cuda.synchronize()
    elapsed_ms = (time.perf_counter() - t0) * 1000
    # Record which backend actually served the request.
    used_cuda = backend == 'cuda' or (backend == 'auto' and CUDA_KERNELS_AVAILABLE)
    perf_tracker.record(elapsed_ms, 'cuda' if used_cuda else 'pytorch')
    return postprocess_tensor(tensor_out.cpu())
# ============================================================================
# Complete GPU wrapping for functions defined above
# ============================================================================
# Now wrap the remaining functions that were defined after the initial GPU wrapping
# Wrap the GPU-using functions that were defined after the first wrapping pass.
if SPACES_AVAILABLE:
    try:
        create_style_blend_output = GPU(create_style_blend_output_impl)
        apply_region_style = GPU(apply_region_style_impl)
        apply_region_style_ui = GPU(apply_region_style_ui_impl)
        process_webcam_frame = GPU(process_webcam_frame)
    except Exception:
        # Fallback if the GPU decorator fails — expose plain implementations.
        create_style_blend_output = create_style_blend_output_impl
        apply_region_style = apply_region_style_impl
        apply_region_style_ui = apply_region_style_ui_impl
        # NOTE(review): this self-assignment is a no-op; if GPU() succeeded on
        # process_webcam_frame before a later wrap failed, the wrapped version
        # is kept (the original reference was not saved and cannot be restored).
        process_webcam_frame = process_webcam_frame
else:
    # Not on Spaces: use the implementations directly (self-assignment is a no-op).
    create_style_blend_output = create_style_blend_output_impl
    apply_region_style = apply_region_style_impl
    apply_region_style_ui = apply_region_style_ui_impl
    process_webcam_frame = process_webcam_frame
# ============================================================================
# Build Gradio Interface
# ============================================================================
custom_css = """
/* ============================================
LIQUID GLASS / GLASSMORPHISM THEME
Gradio 5.x Compatible
============================================ */
/* Import Google Fonts - Plus Jakarta Sans for premium Framer-style look */
@import url('https://fonts.googleapis.com/css2?family=Plus+Jakarta+Sans:wght@200;300;400;500;600;700;800&display=swap');
/* Animated gradient background */
body {
background: linear-gradient(-45deg, #ee7752, #e73c7e, #23a6d5, #23d5ab);
background-size: 400% 400%;
animation: gradientBG 15s ease infinite;
font-family: 'Plus Jakarta Sans', -apple-system, BlinkMacSystemFont, 'Segoe UI', 'Roboto', 'Helvetica Neue', Arial, sans-serif;
min-height: 100vh;
}
@keyframes gradientBG {
0% { background-position: 0% 50%; }
50% { background-position: 100% 50%; }
100% { background-position: 0% 50%; }
}
/* Universal font application */
* {
font-family: 'Plus Jakarta Sans', -apple-system, BlinkMacSystemFont, 'Segoe UI', 'Roboto', 'Helvetica Neue', Arial, sans-serif;
}
/* Ensure text elements are visible */
h1, h2, h3, h4, h5, h6, p, span, div, label, button, input, textarea, select {
color: inherit;
}
/* Main app container - glass effect */
.gradio-container {
backdrop-filter: blur(20px) saturate(180%);
-webkit-backdrop-filter: blur(20px) saturate(180%);
background: rgba(255, 255, 255, 0.75) !important;
border-radius: 24px;
border: 1px solid rgba(255, 255, 255, 0.3);
box-shadow: 0 8px 32px 0 rgba(31, 38, 135, 0.15);
max-width: 1400px;
margin: 20px auto;
padding: 24px !important;
}
/* Primary button - enhanced glass effect with Framer-style typography */
button.primary,
.gr-button-primary,
[class*="primary"] {
background: linear-gradient(135deg,
rgba(255, 255, 255, 0.25) 0%,
rgba(255, 255, 255, 0.1) 50%,
rgba(255, 255, 255, 0.05) 100%) !important;
backdrop-filter: blur(20px) saturate(180%);
-webkit-backdrop-filter: blur(20px) saturate(180%);
border: 1px solid rgba(255, 255, 255, 0.4) !important;
color: #4F46E5 !important;
font-family: 'Plus Jakarta Sans', sans-serif !important;
font-weight: 600 !important;
letter-spacing: -0.01em !important;
border-radius: 16px !important;
padding: 20px 32px !important;
transition: all 0.3s ease !important;
box-shadow:
0 8px 32px rgba(31, 38, 135, 0.15),
inset 0 1px 0 rgba(255, 255, 255, 0.5),
inset 0 -1px 0 rgba(0, 0, 0, 0.05) !important;
position: relative;
overflow: hidden;
width: 100% !important;
min-height: 64px !important;
margin: 8px 0 !important;
}
button.primary::before,
.gr-button-primary::before,
[class*="primary"]::before {
content: '';
position: absolute;
top: 0;
left: -100%;
width: 100%;
height: 100%;
background: linear-gradient(90deg,
transparent,
rgba(255, 255, 255, 0.4),
transparent);
transition: left 0.5s ease;
}
button.primary:hover::before,
.gr-button-primary:hover::before,
[class*="primary"]:hover::before {
left: 100%;
}
button.primary:hover,
.gr-button-primary:hover,
[class*="primary"]:hover {
transform: translateY(-2px);
background: linear-gradient(135deg,
rgba(255, 255, 255, 0.35) 0%,
rgba(255, 255, 255, 0.15) 50%,
rgba(255, 255, 255, 0.1) 100%) !important;
border-color: rgba(255, 255, 255, 0.6) !important;
box-shadow:
0 12px 40px rgba(31, 38, 135, 0.25),
inset 0 1px 0 rgba(255, 255, 255, 0.6),
inset 0 -1px 0 rgba(0, 0, 0, 0.05) !important;
}
button.primary:active,
.gr-button-primary:active,
[class*="primary"]:active {
transform: translateY(0);
box-shadow:
0 4px 20px rgba(31, 38, 135, 0.15),
inset 0 1px 0 rgba(255, 255, 255, 0.3) !important;
}
/* Secondary button - enhanced glass style with Framer typography */
button.secondary,
.gr-button-secondary,
.download,
[class*="secondary"] {
background: linear-gradient(135deg,
rgba(255, 255, 255, 0.4) 0%,
rgba(255, 255, 255, 0.2) 100%) !important;
backdrop-filter: blur(20px) saturate(180%);
-webkit-backdrop-filter: blur(20px) saturate(180%);
border: 1px solid rgba(255, 255, 255, 0.5) !important;
color: #374151 !important;
font-family: 'Plus Jakarta Sans', sans-serif !important;
border-radius: 14px !important;
padding: 16px 28px !important;
transition: all 0.3s ease !important;
font-weight: 500 !important;
letter-spacing: -0.005em !important;
box-shadow:
0 8px 32px rgba(31, 38, 135, 0.12),
inset 0 1px 0 rgba(255, 255, 255, 0.6),
inset 0 -1px 0 rgba(0, 0, 0, 0.03) !important;
width: 100% !important;
min-height: 56px !important;
margin: 6px 0 !important;
}
button.secondary:hover,
.gr-button-secondary:hover,
.download:hover,
[class*="secondary"]:hover {
background: linear-gradient(135deg,
rgba(255, 255, 255, 0.6) 0%,
rgba(255, 255, 255, 0.3) 100%) !important;
border-color: rgba(255, 255, 255, 0.7) !important;
box-shadow:
0 12px 40px rgba(31, 38, 135, 0.18),
inset 0 1px 0 rgba(255, 255, 255, 0.7) !important;
transform: translateY(-1px);
}
/* All buttons - Framer-style rounded corners and typography with spacing */
button,
.gr-button {
font-family: 'Plus Jakarta Sans', sans-serif !important;
border-radius: 14px !important;
transition: all 0.3s ease !important;
width: 100% !important;
min-height: 52px !important;
padding: 14px 24px !important;
margin: 4px 0 !important;
}
/* Button containers - ensure buttons fill row with spacing */
.gradio-container button,
#quick_stylize button,
#style_blending button,
#region_transfer button,
#custom_training button,
#benchmarking button {
width: 100% !important;
margin: 6px 0 !important;
}
/* Tabs - glass style */
.tabs {
background: rgba(255, 255, 255, 0.4) !important;
backdrop-filter: blur(10px);
border-radius: 16px !important;
padding: 8px !important;
border: 1px solid rgba(255, 255, 255, 0.3) !important;
}
/* Tab buttons */
button.tab-item {
background: transparent !important;
border-radius: 12px !important;
color: #6B7280 !important;
transition: all 0.3s ease !important;
}
button.tab-item:hover {
background: rgba(255, 255, 255, 0.5) !important;
}
button.tab-item.selected {
background: rgba(255, 255, 255, 0.8) !important;
color: #6366F1 !important;
font-weight: 600 !important;
box-shadow: 0 2px 8px rgba(0, 0, 0, 0.1) !important;
}
/* Input boxes and text areas */
input[type="text"],
input[type="number"],
textarea,
select {
background: rgba(255, 255, 255, 0.7) !important;
backdrop-filter: blur(10px);
border: 1px solid rgba(255, 255, 255, 0.5) !important;
border-radius: 12px !important;
transition: all 0.3s ease !important;
}
input[type="text"]:focus,
input[type="number"]:focus,
textarea:focus,
select:focus {
background: rgba(255, 255, 255, 0.9) !important;
border-color: #6366F1 !important;
box-shadow: 0 0 0 3px rgba(99, 102, 241, 0.1) !important;
outline: none !important;
}
/* Dropdown menu fixes */
.gradio-dropdown,
.dropdown,
[class*="dropdown"],
[class*="Dropdown"] {
position: relative !important;
z-index: 100 !important;
}
/* Dropdown menu/option list positioning */
.gradio-dropdown ul,
.dropdown ul,
select option,
[class*="dropdown"] ul,
[class*="Dropdown"] ul {
position: absolute !important;
z-index: 9999 !important;
background: rgba(255, 255, 255, 0.95) !important;
backdrop-filter: blur(20px) !important;
border: 1px solid rgba(255, 255, 255, 0.5) !important;
border-radius: 12px !important;
box-shadow: 0 10px 40px rgba(0, 0, 0, 0.15) !important;
max-height: 300px !important;
overflow-y: auto !important;
}
/* Gradio 5.x dropdown container fixes */
[class*="svelte"][class*="dropdown"],
[class*="svelte"][class*="radio"],
[class*="svelte"][class*="radio-item"] {
position: relative !important;
z-index: 100 !important;
}
/* Radio button group positioning */
.gradio-radio,
[class*="radio"] {
position: relative !important;
z-index: 50 !important;
}
/* Image containers - glass frame */
.image-container,
[class*="image"] {
border-radius: 16px !important;
border: 1px solid rgba(255, 255, 255, 0.3) !important;
overflow: hidden !important;
background: rgba(255, 255, 255, 0.3) !important;
}
/* Slider styling */
input[type="range"] {
-webkit-appearance: none;
background: rgba(229, 231, 235, 0.6);
backdrop-filter: blur(10px);
border-radius: 8px;
height: 8px;
border: 1px solid rgba(255, 255, 255, 0.3);
}
input[type="range"]::-webkit-slider-thumb {
-webkit-appearance: none;
width: 22px;
height: 22px;
background: linear-gradient(135deg, #6366F1 0%, #8B5CF6 100%);
border: 3px solid white;
border-radius: 50%;
cursor: pointer;
box-shadow: 0 2px 8px rgba(99, 102, 241, 0.4);
}
input[type="range"]::-moz-range-thumb {
width: 22px;
height: 22px;
background: linear-gradient(135deg, #6366F1 0%, #8B5CF6 100%);
border: 3px solid white;
border-radius: 50%;
cursor: pointer;
box-shadow: 0 2px 8px rgba(99, 102, 241, 0.4);
}
/* Checkbox and radio styling */
input[type="checkbox"],
input[type="radio"] {
accent-color: #6366F1 !important;
width: 18px !important;
height: 18px !important;
}
/* Badge styles */
.live-badge {
display: inline-block;
padding: 6px 16px;
background: rgba(254, 243, 199, 0.8);
backdrop-filter: blur(10px);
color: #92400E;
border-radius: 24px;
font-size: 13px;
font-weight: 600;
border: 1px solid rgba(255, 255, 255, 0.3);
}
.backend-badge {
display: inline-block;
padding: 6px 16px;
background: rgba(209, 250, 229, 0.8);
backdrop-filter: blur(10px);
color: #065F46;
border-radius: 24px;
font-size: 13px;
font-weight: 600;
border: 1px solid rgba(255, 255, 255, 0.3);
}
/* Markdown content with Framer-style typography */
.markdown {
color: #374151 !important;
font-family: 'Plus Jakarta Sans', sans-serif !important;
}
/* Headings with Framer-style typography - tighter letter spacing */
.gradio-container h1,
.gradio-container h2,
.gradio-container h3,
.gradio-container h4,
.gradio-container h5,
.gradio-container h6 {
font-family: 'Plus Jakarta Sans', sans-serif !important;
color: #1F2937 !important;
letter-spacing: -0.025em !important;
font-weight: 600 !important;
}
/* Text visibility fixes */
.gradio-container,
.gradio-container *,
.gradio-container p,
.gradio-container span,
.gradio-container label {
color: #1F2937 !important;
}
/* Button text colors */
button,
.gradio-container button {
color: inherit !important;
}
/* Input and select text colors with Framer-style typography */
input,
textarea,
select {
color: #1F2937 !important;
font-family: 'Plus Jakarta Sans', sans-serif !important;
}
/* Label colors with Framer-style typography */
label,
[class*="label"] {
color: #374151 !important;
font-family: 'Plus Jakarta Sans', sans-serif !important;
font-weight: 500 !important;
letter-spacing: -0.01em !important;
}
/* Gradio 5.x specific text elements */
[class*="svelte-"] {
font-family: 'Plus Jakarta Sans', -apple-system, BlinkMacSystemFont, 'Segoe UI', sans-serif !important;
}
/* Group/Row/Column containers */
.group,
.row,
.column {
background: rgba(255, 255, 255, 0.3) !important;
border-radius: 16px !important;
padding: 16px !important;
}
/* Accordion */
.details {
background: rgba(255, 255, 255, 0.4) !important;
backdrop-filter: blur(10px);
border-radius: 16px !important;
border: 1px solid rgba(255, 255, 255, 0.3) !important;
}
/* Scrollbar - glass style */
::-webkit-scrollbar {
width: 10px;
height: 10px;
}
::-webkit-scrollbar-track {
background: rgba(229, 231, 235, 0.3);
border-radius: 8px;
}
::-webkit-scrollbar-thumb {
background: rgba(167, 139, 250, 0.5);
border-radius: 8px;
border: 2px solid rgba(255, 255, 255, 0.3);
}
::-webkit-scrollbar-thumb:hover {
background: rgba(139, 92, 246, 0.7);
}
/* Progress bar */
progress {
background: rgba(229, 231, 235, 0.5) !important;
border-radius: 8px !important;
height: 8px !important;
}
progress::-webkit-progress-bar {
background: rgba(229, 231, 235, 0.5);
border-radius: 8px;
}
progress::-webkit-progress-value {
background: linear-gradient(90deg, #6366F1, #8B5CF6) !important;
border-radius: 8px;
}
/* Mobile responsive */
@media (max-width: 768px) {
.gradio-container {
margin: 10px !important;
padding: 16px !important;
border-radius: 20px !important;
}
button.primary,
.gr-button-primary,
[class*="primary"] {
padding: 10px 18px !important;
font-size: 14px !important;
}
}
/* Loading spinner */
.spinner {
border: 3px solid rgba(99, 102, 241, 0.2);
border-top: 3px solid #6366F1;
border-radius: 50%;
animation: spin 1s linear infinite;
}
@keyframes spin {
0% { transform: rotate(0deg); }
100% { transform: rotate(360deg); }
}
/* Additional Gradio 5.x specific selectors */
.gradio-button.primary,
button[class*="Primary"],
[type="button"].primary {
background: linear-gradient(135deg, rgba(99, 102, 241, 0.9) 0%, rgba(139, 92, 246, 0.9) 100%) !important;
border: 1px solid rgba(255, 255, 255, 0.2) !important;
color: white !important;
font-weight: 600 !important;
border-radius: 16px !important;
box-shadow: 0 4px 15px rgba(99, 102, 241, 0.25) !important;
min-height: 64px !important;
padding: 20px 32px !important;
margin: 8px 0 !important;
}
/* Gradio 5.x Portal/Popover fixes for dropdowns - simplified */
body > div[id*="portal"] {
position: fixed !important;
z-index: 99999 !important;
margin-top: 0 !important;
padding-top: 0 !important;
}
/* Block containers */
.block {
background: rgba(255, 255, 255, 0.25) !important;
border-radius: 16px !important;
padding: 12px !important;
}
/* Form elements */
.form,
.form-group {
background: transparent !important;
}
"""
# Build the Gradio UI. `custom_css` (defined above) layers glassmorphism
# styling on top of the Glass theme.
with gr.Blocks(
    title="StyleForge: Neural Style Transfer",
    theme=gr.themes.Glass(
        primary_hue="indigo",
        secondary_hue="purple",
        font=gr.themes.GoogleFont("Plus Jakarta Sans"),
        radius_size="lg",
    ),
    css=custom_css,
) as demo:
    # Load Google Fonts via HTML head injection
    gr.HTML("""
    <link rel="preconnect" href="https://fonts.googleapis.com">
    <link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
    <link href="https://fonts.googleapis.com/css2?family=Plus+Jakarta+Sans:wght@200;300;400;500;600;700;800&display=swap" rel="stylesheet">
    """)

    # Header with Portal-style hero section.  The CUDA badge is rendered only
    # when the custom kernels compiled successfully (CUDA_KERNELS_AVAILABLE is
    # set earlier in this file).
    cuda_badge = f"<span class='backend-badge'>CUDA Accelerated</span>" if CUDA_KERNELS_AVAILABLE else ""
    gr.HTML(f"""
    <div style="text-align: center; padding: 3rem 0 2rem 0; font-family: 'Plus Jakarta Sans', -apple-system, BlinkMacSystemFont, 'Segoe UI', sans-serif;">
        <h1 style="font-family: 'Plus Jakarta Sans', -apple-system, BlinkMacSystemFont, 'Segoe UI', sans-serif; font-size: 3.5rem; margin-bottom: 0.5rem; background: linear-gradient(135deg, #6366F1, #8B5CF6); -webkit-background-clip: text; -webkit-text-fill-color: transparent; font-weight: 700; letter-spacing: -0.02em;">
            StyleForge
        </h1>
        <p style="font-family: 'Plus Jakarta Sans', -apple-system, BlinkMacSystemFont, 'Segoe UI', sans-serif; color: #6B7280; font-size: 1.15rem; margin-bottom: 1rem; font-weight: 400;">
            Neural Style Transfer with Custom CUDA Kernels
        </p>
        {cuda_badge}
        <p style="font-family: 'Plus Jakarta Sans', -apple-system, BlinkMacSystemFont, 'Segoe UI', sans-serif; color: #9CA3AF; margin-top: 1rem; font-size: 0.95rem; font-weight: 300;">
            Custom Styles • Region Transfer • Style Blending • Real-time Processing
        </p>
    </div>
    """)
    # Mode selector: each tab is an independent workflow sharing the same
    # style/backend registries (STYLES, BACKENDS defined earlier).
    with gr.Tabs() as tabs:
        # Tab 1: Quick Style Transfer -- one image, one style, one click.
        with gr.Tab("Quick Style", id=0):
            with gr.Row():
                # Left column: inputs and options.
                with gr.Column(scale=1):
                    quick_image = gr.Image(
                        label="Upload Image",
                        type="pil",
                        sources=["upload", "clipboard"],
                        height=400
                    )
                    quick_style = gr.Dropdown(
                        choices=list(STYLES.keys()),
                        value='candy',
                        label="Artistic Style"
                    )
                    quick_backend = gr.Radio(
                        choices=list(BACKENDS.keys()),
                        value='auto',
                        label="Processing Backend"
                    )
                    quick_intensity = gr.Slider(
                        minimum=0,
                        maximum=100,
                        value=70,
                        step=5,
                        label="Style Intensity (0% = Original, 100% = Full Style)"
                    )
                    with gr.Row():
                        quick_compare = gr.Checkbox(
                            label="Side-by-side",
                            value=False
                        )
                        quick_watermark = gr.Checkbox(
                            label="Add watermark",
                            value=False
                        )
                    quick_btn = gr.Button(
                        "Stylize Image",
                        variant="primary",
                        size="lg"
                    )
                # Right column: result, download, and run statistics.
                with gr.Column(scale=1):
                    quick_output = gr.Image(
                        label="Result",
                        type="pil",
                        height=400
                    )
                    with gr.Row():
                        quick_download = gr.DownloadButton(
                            label="Download",
                            variant="secondary"
                        )
                    quick_stats = gr.Markdown(
                        "> Upload an image and click **Stylize** to begin!"
                    )
        # Tab 2: Style Blending -- interpolate between two style models.
        with gr.Tab("Style Blending", id=1):
            gr.Markdown("""
            ### Mix Two Styles Together
            Blend between any two styles to create unique artistic combinations.
            This demonstrates style interpolation in the latent space.
            """)
            with gr.Row():
                with gr.Column(scale=1):
                    blend_image = gr.Image(
                        label="Upload Image",
                        type="pil",
                        sources=["upload", "clipboard"],
                        height=350
                    )
                    blend_style1 = gr.Dropdown(
                        choices=list(STYLES.keys()),
                        value='candy',
                        label="Style 1"
                    )
                    blend_style2 = gr.Dropdown(
                        choices=list(STYLES.keys()),
                        value='mosaic',
                        label="Style 2"
                    )
                    # Per the help text below: 0 = pure Style 1, 100 = pure Style 2.
                    blend_ratio = gr.Slider(
                        minimum=0,
                        maximum=100,
                        value=50,
                        step=5,
                        label="Blend Ratio"
                    )
                    blend_backend = gr.Radio(
                        choices=list(BACKENDS.keys()),
                        value='auto',
                        label="Backend"
                    )
                    blend_btn = gr.Button(
                        "Blend Styles",
                        variant="primary"
                    )
                    gr.Markdown("""
                    **How it Works:**
                    - Style blending interpolates between model weights
                    - At 0% you get pure Style 1
                    - At 100% you get pure Style 2
                    - At 50% you get an equal mix of both
                    """)
                with gr.Column(scale=1):
                    blend_output = gr.Image(
                        label="Blended Result",
                        type="pil",
                        height=350
                    )
                    blend_info = gr.Markdown(
                        "Adjust the blend ratio and click **Blend Styles** to see the result."
                    )
        # Tab 3: Region-Based Style -- apply different styles per image region.
        with gr.Tab("Region Transfer", id=2):
            gr.Markdown("""
            ### Apply Different Styles to Different Regions
            Transform specific parts of your image with different styles.
            **NEW:** AI-powered foreground/background segmentation!
            """)
            with gr.Row():
                with gr.Column(scale=1):
                    region_image = gr.Image(
                        label="Upload Image",
                        type="pil",
                        sources=["upload", "clipboard"],
                        height=350
                    )
                    region_mask_type = gr.Radio(
                        choices=[
                            "AI: Foreground",
                            "AI: Background",
                            "Horizontal Split",
                            "Vertical Split",
                            "Center Circle",
                            "Corner Box",
                            "Full"
                        ],
                        value="AI: Foreground",
                        label="Mask Type"
                    )
                    # NOTE(review): presumably only meaningful for the split
                    # masks and ignored by the AI masks -- confirm in
                    # create_region_mask (defined elsewhere in this file).
                    region_position = gr.Slider(
                        minimum=0,
                        maximum=1,
                        value=0.5,
                        step=0.1,
                        label="Split Position"
                    )
                    with gr.Row():
                        region_style1 = gr.Dropdown(
                            choices=list(STYLES.keys()),
                            value='candy',
                            label="Style (White/Top/Left)"
                        )
                        region_style2 = gr.Dropdown(
                            choices=list(STYLES.keys()),
                            value='mosaic',
                            label="Style (Black/Bottom/Right)"
                        )
                    region_backend = gr.Radio(
                        choices=list(BACKENDS.keys()),
                        value='auto',
                        label="Backend"
                    )
                    region_btn = gr.Button(
                        "Apply Region Styles",
                        variant="primary"
                    )
                with gr.Column(scale=1):
                    # Nested tabs: final composite vs. the mask actually used.
                    with gr.Tabs():
                        with gr.Tab("Result"):
                            region_output = gr.Image(
                                label="Stylized Result",
                                type="pil",
                                height=300
                            )
                        with gr.Tab("Mask Preview"):
                            region_mask_preview = gr.Image(
                                label="Mask Preview",
                                type="pil",
                                height=300
                            )
                    gr.Markdown("""
                    **Mask Guide:**
                    - **AI: Foreground** 🆕: Automatically detect main subject (person, object, etc.)
                    - **AI: Background** 🆕: Automatically detect background/sky
                    - **Horizontal**: Top/bottom split
                    - **Vertical**: Left/right split
                    - **Center Circle**: Circular region in center
                    - **Corner Box**: Top-left quadrant only
                    *AI segmentation uses the Rembg model (U^2-Net) for automatic subject detection.*
                    """)
        # Tab 4: Custom Style Training -- extract a new style with VGG19
        # feature matching (train_custom_style, wired below).
        with gr.Tab("Create Style", id=3):
            gr.Markdown("""
            ### Extract Style from Any Image 🆕
            Upload any artwork to extract its artistic style using **VGG19 feature matching**.
            **How it works:**
            1. Extract style features using pre-trained VGG19 neural network
            2. Fine-tune a transformation network to match those features
            3. Save as a reusable style model
            This is **real style extraction** - not just copying an existing style!
            """)
            with gr.Row():
                with gr.Column(scale=1):
                    train_style_image = gr.Image(
                        label="Style Image (Artwork)",
                        type="pil",
                        sources=["upload"],
                        height=350
                    )
                    train_style_name = gr.Textbox(
                        label="Style Name",
                        value="my_custom_style",
                        placeholder="Enter a name for your custom style"
                    )
                    train_iterations = gr.Slider(
                        minimum=50,
                        maximum=500,
                        value=100,
                        step=50,
                        label="Training Iterations"
                    )
                    train_backend = gr.Radio(
                        choices=list(BACKENDS.keys()),
                        value='auto',
                        label="Backend"
                    )
                    train_btn = gr.Button(
                        "Extract Style",
                        variant="primary"
                    )
                    # Repopulates the style dropdowns with any newly trained
                    # custom styles (handler wired below).
                    refresh_styles_btn = gr.Button("Refresh Style List")
                with gr.Column(scale=1):
                    train_output = gr.Markdown(
                        "> Upload a style image and click **Extract Style** to begin!\n\n"
                        "**How it works:**\n"
                        "- VGG19 extracts artistic features (textures, colors, patterns)\n"
                        "- Neural network is fine-tuned to match those features\n"
                        "- Result is a reusable style model\n\n"
                        "**Tips:**\n"
                        "- Use artwork with clear artistic style (paintings, illustrations)\n"
                        "- More iterations = better style matching (slower)\n"
                        "- GPU recommended for faster training\n"
                        "- Your custom style will appear in all Style dropdowns"
                    )
                    # Progress text filled in by train_custom_style.
                    train_progress = gr.Markdown("")
        # Tab 5: Webcam Live -- per-frame style transfer on webcam captures
        # (true streaming is disabled; see the handler wiring below).
        with gr.Tab("Webcam Live", id=4):
            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown("""
                    ### <span class="live-badge">LIVE</span> Real-time Webcam Style Transfer
                    """)
                    webcam_style = gr.Dropdown(
                        choices=list(STYLES.keys()),
                        value='candy',
                        label="Artistic Style"
                    )
                    webcam_backend = gr.Radio(
                        choices=list(BACKENDS.keys()),
                        value='auto',
                        label="Backend"
                    )
                    webcam_intensity = gr.Slider(
                        minimum=0,
                        maximum=100,
                        value=70,
                        step=5,
                        label="Style Intensity (0% = Original, 100% = Full Style)"
                    )
                    webcam_stream = gr.Image(
                        sources=["webcam"],
                        label="Webcam Feed",
                        height=400
                    )
                    webcam_info = gr.Markdown(
                        "> Click in the webcam preview to start the feed"
                    )
                with gr.Column(scale=1):
                    webcam_output = gr.Image(
                        label="Stylized Output",
                        height=400
                    )
                    # Seeded with the current stats; refreshed via the button.
                    webcam_stats = gr.Markdown(
                        get_performance_stats()
                    )
                    refresh_stats_btn = gr.Button("Refresh Stats", size="sm")
        # Tab 6: Performance Dashboard -- benchmark runs and live charts.
        with gr.Tab("Performance", id=5):
            gr.Markdown("""
            ### Real-time Performance Dashboard
            Track inference times and compare backends with live charts.
            """)
            with gr.Row():
                benchmark_style = gr.Dropdown(
                    choices=list(STYLES.keys()),
                    value='candy',
                    label="Select Style for Benchmark"
                )
                run_benchmark_btn = gr.Button(
                    "Run Benchmark",
                    variant="primary"
                )
            benchmark_chart = gr.Markdown(
                "Click **Run Benchmark** to see the performance chart"
            )
            live_chart = gr.Markdown(
                "Run some inferences to see the live chart populate below..."
            )
            refresh_chart_btn = gr.Button("Refresh Chart")
            gr.Markdown("---")
            gr.Markdown("### Live Performance Chart")
            chart_display = gr.HTML(
                "<div style='text-align:center; padding: 20px;'>Run inferences to see chart</div>"
            )
            chart_stats = gr.Markdown()

    # Style description (shared across all tabs); updated by quick_style.change.
    style_desc = gr.Markdown("*Select a style to see description*")

    # Examples section
    gr.Markdown("---")
def create_example_image():
# Create a more interesting test image with geometric shapes
arr = np.zeros((256, 256, 3), dtype=np.uint8)
# Background gradient
for i in range(256):
arr[:, i, 0] = i // 2
arr[:, i, 1] = 128
arr[:, i, 2] = 255 - i // 2
# Add a circle in the center
cy, cx = 128, 128
for y in range(256):
for x in range(256):
if (x - cx)**2 + (y - cy)**2 <= 50**2:
arr[y, x, 0] = 255
arr[y, x, 1] = 200
arr[y, x, 2] = 100
return Image.fromarray(arr)
    # NOTE(review): example_img is no longer referenced after the gr.Examples
    # removal noted below -- candidate for deletion.
    example_img = create_example_image()

    # Pre-styled example outputs for display
    # These images demonstrate each style without needing to run the model
    gr.Markdown("### Quick Style Examples")
    gr.Markdown("Upload an image and select a style to see the transformation in action!")
    # Examples removed - they all showed the same input image
    # The style gallery below shows visual representations of each style
    # Display example style gallery with unique gradients for each style
    gr.Markdown("""
    <div style="display: flex; gap: 0.8rem; justify-content: center; margin: 1rem 0; flex-wrap: wrap;">
        <div style="text-align: center;">
            <div style="width: 100px; height: 100px; background: linear-gradient(135deg, #ff6b6b, #feca57); border-radius: 12px; margin: 0 auto; box-shadow: 0 4px 15px rgba(255, 107, 107, 0.3);"></div>
            <p style="margin-top: 0.5rem; font-size: 0.8rem; font-weight: 500;">🍬 Candy</p>
        </div>
        <div style="text-align: center;">
            <div style="width: 100px; height: 100px; background: linear-gradient(135deg, #5f27cd, #00d2d3); border-radius: 12px; margin: 0 auto; box-shadow: 0 4px 15px rgba(95, 39, 205, 0.3);"></div>
            <p style="margin-top: 0.5rem; font-size: 0.8rem; font-weight: 500;">🎨 Mosaic</p>
        </div>
        <div style="text-align: center;">
            <div style="width: 100px; height: 100px; background: linear-gradient(135deg, #576574, #c8d6e5); border-radius: 12px; margin: 0 auto; box-shadow: 0 4px 15px rgba(87, 101, 116, 0.3);"></div>
            <p style="margin-top: 0.5rem; font-size: 0.8rem; font-weight: 500;">🌧️ Rain Princess</p>
        </div>
        <div style="text-align: center;">
            <div style="width: 100px; height: 100px; background: linear-gradient(135deg, #ee5a24, #f9ca24); border-radius: 12px; margin: 0 auto; box-shadow: 0 4px 15px rgba(238, 90, 36, 0.3);"></div>
            <p style="margin-top: 0.5rem; font-size: 0.8rem; font-weight: 500;">🖼️ Udnie</p>
        </div>
        <div style="text-align: center;">
            <div style="width: 100px; height: 100px; background: linear-gradient(135deg, #1e3a8a, #fbbf24, #3b82f6); border-radius: 12px; margin: 0 auto; box-shadow: 0 4px 15px rgba(30, 58, 138, 0.3);"></div>
            <p style="margin-top: 0.5rem; font-size: 0.8rem; font-weight: 500;">🌃 Starry Night</p>
        </div>
        <div style="text-align: center;">
            <div style="width: 100px; height: 100px; background: linear-gradient(135deg, #be185d, #fce7f3, #9d174d); border-radius: 12px; margin: 0 auto; box-shadow: 0 4px 15px rgba(190, 24, 93, 0.3);"></div>
            <p style="margin-top: 0.5rem; font-size: 0.8rem; font-weight: 500;">🎭 La Muse</p>
        </div>
        <div style="text-align: center;">
            <div style="width: 100px; height: 100px; background: linear-gradient(135deg, #dc2626, #1f2937, #f97316); border-radius: 12px; margin: 0 auto; box-shadow: 0 4px 15px rgba(220, 38, 38, 0.3);"></div>
            <p style="margin-top: 0.5rem; font-size: 0.8rem; font-weight: 500;">😱 The Scream</p>
        </div>
        <div style="text-align: center;">
            <div style="width: 100px; height: 100px; background: linear-gradient(135deg, #7c3aed, #2563eb, #db2777); border-radius: 12px; margin: 0 auto; box-shadow: 0 4px 15px rgba(124, 58, 237, 0.3);"></div>
            <p style="margin-top: 0.5rem; font-size: 0.8rem; font-weight: 500;">🎪 Composition VII</p>
        </div>
    </div>
    """)

    # FAQ Section
    gr.Markdown("---")
    with gr.Accordion("FAQ & Help", open=False):
        gr.Markdown("""
        ### What are CUDA kernels?
        Custom CUDA kernels are hand-written GPU code that fuses multiple operations
        into a single kernel launch. This reduces memory transfers and improves
        performance by 8-9x.
        ### How does Style Blending work?
        Style blending interpolates between the weights of two trained style models.
        This demonstrates that styles exist in a continuous latent space where you can
        navigate and create new artistic variations.
        ### What is Region-based Style Transfer?
        This feature applies different artistic styles to different regions of the same image.
        It demonstrates computer vision concepts like segmentation and masking, while
        enabling creative effects like "make the sky look like Starry Night while keeping
        the ground realistic."
        ### Which backend should I use?
        - **Auto**: Recommended - automatically uses the fastest available option
        - **CUDA Kernels**: Best performance on GPU (requires CUDA compilation)
        - **PyTorch**: Fallback for CPU or when CUDA is unavailable
        ### Can I use this commercially?
        Yes! StyleForge is open source (MIT license).
        """)
    # Technical details -- f-string so the CUDA availability status is baked
    # in at UI construction time.
    with gr.Accordion("Technical Details", open=False):
        gr.Markdown(f"""
        ### Architecture
        **Network:** Encoder-Decoder with Residual Blocks (Johnson et al.)
        - **Encoder**: 3 Conv layers + Instance Normalization
        - **Transformer**: 5 Residual blocks
        - **Decoder**: 3 Upsample Conv layers + Instance Normalization
        ### CUDA Optimizations
        **Status:** {'✅ Available' if CUDA_KERNELS_AVAILABLE else '❌ Not Available (CPU or no CUDA)'}
        When CUDA kernels are available:
        - **Fused InstanceNorm**: Combines mean, variance, normalize, affine transform
        - **Vectorized memory**: Uses `float4` loads for 4x bandwidth
        - **Shared memory**: Reduces global memory traffic
        - **Warp-level reductions**: Efficient parallel reductions
        ### ML Concepts Demonstrated
        - **Style Transfer**: Neural artistic stylization
        - **Latent Space Interpolation**: Style blending shows continuous style space
        - **Conditional Generation**: Region-based style transfer
        - **Transfer Learning**: Custom style training from few examples
        - **Performance Optimization**: CUDA kernels, JIT compilation, caching
        - **Model Deployment**: Gradio web interface, CI/CD pipeline
        ### Resources
        - [GitHub Repository](https://github.com/olivialiau/StyleForge)
        - [Paper: Perceptual Losses for Real-Time Style Transfer](https://arxiv.org/abs/1603.08155)
        """)

    # Footer
    gr.Markdown("""
    <div class="footer">
        <p>
            <strong>StyleForge</strong> • Created by Olivia • USC Computer Science<br>
            <a href="https://github.com/olivialiau/StyleForge">GitHub</a> •
            Built with <a href="https://huggingface.co/spaces">Hugging Face Spaces</a> 🤗
        </p>
    </div>
    """)
# ============================================================================
# Event Handlers
# ============================================================================
# Style description updates
def update_style_desc(style):
desc = STYLE_DESCRIPTIONS.get(style, "")
return f"*{desc}*"
# Quick style handlers
quick_style.change(
fn=update_style_desc,
inputs=[quick_style],
outputs=[style_desc]
)
quick_btn.click(
fn=stylize_image,
inputs=[quick_image, quick_style, quick_backend, quick_intensity, quick_compare, quick_watermark],
outputs=[quick_output, quick_stats, quick_download]
)
# Style blending handlers
def update_blend_info(style1: str, style2: str, ratio: float) -> str:
s1_name = STYLES.get(style1, style1)
s2_name = STYLES.get(style2, style2)
return f"Blended {s1_name} × {ratio:.0f}% + {s2_name} × {100-ratio:.0f}%"
blend_btn.click(
fn=create_style_blend_output,
inputs=[blend_image, blend_style1, blend_style2, blend_ratio, blend_backend],
outputs=[blend_output]
).then(
fn=update_blend_info,
inputs=[blend_style1, blend_style2, blend_ratio],
outputs=[blend_info]
)
# Region-based handlers
region_btn.click(
fn=apply_region_style_ui,
inputs=[region_image, region_mask_type, region_position, region_style1, region_style2, region_backend],
outputs=[region_output, region_mask_preview]
)
region_mask_type.change(
fn=lambda mt, img, pos: create_region_mask(img, mt, pos) if img else None,
inputs=[region_mask_type, region_image, region_position],
outputs=[region_mask_preview]
)
region_position.change(
fn=lambda pos, img, mt: create_region_mask(img, mt, pos) if img else None,
inputs=[region_position, region_image, region_mask_type],
outputs=[region_mask_preview]
)
# Custom style training
train_btn.click(
fn=train_custom_style,
inputs=[train_style_image, train_style_name, train_iterations, train_backend],
outputs=[train_progress, train_output]
)
def update_style_choices():
return list(STYLES.keys()) + get_custom_styles()
refresh_styles_btn.click(
fn=update_style_choices,
outputs=[quick_style]
).then(
fn=update_style_choices,
outputs=[blend_style1]
).then(
fn=update_style_choices,
outputs=[blend_style2]
)
# Webcam handlers - note: streaming disabled for Gradio 5.x compatibility
# Users can still upload/process webcam images manually
webcam_stream.change(
fn=process_webcam_frame,
inputs=[webcam_stream, webcam_style, webcam_backend, webcam_intensity],
outputs=[webcam_output],
)
refresh_stats_btn.click(
fn=get_performance_stats,
outputs=[webcam_stats]
)
# Benchmark handlers
run_benchmark_btn.click(
fn=create_benchmark_comparison,
inputs=[benchmark_style],
outputs=[benchmark_chart]
)
refresh_chart_btn.click(
fn=create_performance_chart,
outputs=[chart_display]
)
# ============================================================================
# Launch Configuration
# ============================================================================
if __name__ == "__main__":
    import os

    # Opt out of Gradio usage analytics.
    # NOTE(review): gradio has already been imported by this point; confirm
    # the env var is still honored when set this late.
    os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"

    # API page disabled to sidestep gradio_client schema-parsing issues
    # (see the monkey-patch at the top of this file).
    launch_options = {
        "show_api": False,
        "show_error": True,
        "quiet": False,
    }
    demo.launch(**launch_options)