dannyroxas committed on
Commit
bebc024
·
verified ·
1 Parent(s): 5a02e97

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +64 -2092
app.py CHANGED
@@ -1,426 +1,4 @@
1
- #!/usr/bin/env python3
2
- """
3
- STYLE TRANSFER APP - Streamlit Version with Regional Transformations
4
- All existing features preserved + new local painting capabilities + Unsplash integration
5
- """
6
-
7
import os

# Redirect config/cache directories to writable /tmp locations
# (needed on read-only deployments such as Hugging Face Spaces).
os.environ['MPLCONFIGDIR'] = '/tmp/matplotlib'
os.environ['TORCH_HOME'] = '/tmp/torch_cache'
os.environ['HF_HOME'] = '/tmp/hf_cache'
for _cache_dir in ('/tmp/torch_cache', '/tmp/hf_cache'):
    os.makedirs(_cache_dir, exist_ok=True)
13
-
14
- import streamlit as st
15
- from streamlit_drawable_canvas import st_canvas
16
- import torch
17
- import torch.nn as nn
18
- import torch.nn.functional as F
19
- import torchvision.transforms as transforms
20
- import torchvision.models as models
21
- from torch.utils.data import Dataset, DataLoader
22
- from PIL import Image, ImageDraw, ImageFont
23
- import numpy as np
24
- import glob
25
- import datetime
26
- import traceback
27
- import uuid
28
- import warnings
29
- import zipfile
30
- import io
31
- import json
32
- import time
33
- import shutil
34
- import requests
35
- try:
36
- import cv2
37
- VIDEO_PROCESSING_AVAILABLE = True
38
- except ImportError:
39
- VIDEO_PROCESSING_AVAILABLE = False
40
- print("OpenCV not available - video processing disabled")
41
- import tempfile
42
- from pathlib import Path
43
- import colorsys
44
- warnings.filterwarnings("ignore")
45
-
46
# Configure the Streamlit page (must run before any other st.* call).
st.set_page_config(
    page_title="Style Transfer Studio",
    page_icon="🎨",
    layout="wide",
    initial_sidebar_state="expanded",
)
53
-
54
# Custom CSS for better UI: widens tab spacing/height, trims top padding,
# centers the drawable canvas, and adds hover styling to the Unsplash grid.
st.markdown("""
<style>
    .stTabs [data-baseweb="tab-list"] {
        gap: 24px;
    }
    .stTabs [data-baseweb="tab"] {
        height: 50px;
        padding-left: 20px;
        padding-right: 20px;
    }
    .main > div {
        padding-top: 2rem;
    }
    .st-emotion-cache-1y4p8pa {
        max-width: 100%;
    }
    /* Fix canvas container */
    .stDrawableCanvas {
        margin: 0 auto;
    }
    /* Unsplash grid styling */
    .unsplash-grid img {
        border-radius: 8px;
        cursor: pointer;
        transition: transform 0.2s;
    }
    .unsplash-grid img:hover {
        transform: scale(1.05);
    }
</style>
""", unsafe_allow_html=True)
86
-
87
# Select the compute device, preferring the first CUDA GPU when present.
if torch.cuda.is_available():
    torch.cuda.set_device(0)
    print("CUDA device set")

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
if device.type == 'cuda':
    # Report which GPU was picked and how much memory it has.
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
98
-
99
- # ===========================
100
- # UNSPLASH API INTEGRATION
101
- # ===========================
102
-
103
- class UnsplashAPI:
104
- """Simple Unsplash API integration"""
105
-
106
- def __init__(self, access_key=None):
107
- # Try to get from provided key, Streamlit secrets, or environment
108
- if access_key:
109
- self.access_key = access_key
110
- else:
111
- # Try secrets first, but handle the case where secrets don't exist
112
- try:
113
- self.access_key = st.secrets.get("UNSPLASH_ACCESS_KEY")
114
- except (FileNotFoundError, KeyError, AttributeError):
115
- # Fall back to environment variable
116
- self.access_key = os.environ.get("UNSPLASH_ACCESS_KEY")
117
-
118
- self.base_url = "https://api.unsplash.com"
119
-
120
- def search_photos(self, query, per_page=20, page=1, orientation=None):
121
- """Search photos on Unsplash"""
122
- if not self.access_key:
123
- return None, "No Unsplash API key configured"
124
-
125
- headers = {"Authorization": f"Client-ID {self.access_key}"}
126
- params = {
127
- "query": query,
128
- "per_page": per_page,
129
- "page": page
130
- }
131
-
132
- if orientation:
133
- params["orientation"] = orientation # "landscape", "portrait", "squarish"
134
-
135
- try:
136
- response = requests.get(
137
- f"{self.base_url}/search/photos",
138
- headers=headers,
139
- params=params,
140
- timeout=10
141
- )
142
- response.raise_for_status()
143
- return response.json(), None
144
- except requests.exceptions.RequestException as e:
145
- return None, f"Error searching Unsplash: {str(e)}"
146
-
147
- def get_random_photos(self, count=12, collections=None, query=None):
148
- """Get random photos from Unsplash"""
149
- if not self.access_key:
150
- return None, "No Unsplash API key configured"
151
-
152
- headers = {"Authorization": f"Client-ID {self.access_key}"}
153
- params = {"count": count}
154
-
155
- if collections:
156
- params["collections"] = collections
157
- if query:
158
- params["query"] = query
159
-
160
- try:
161
- response = requests.get(
162
- f"{self.base_url}/photos/random",
163
- headers=headers,
164
- params=params,
165
- timeout=10
166
- )
167
- response.raise_for_status()
168
- return response.json(), None
169
- except requests.exceptions.RequestException as e:
170
- return None, f"Error getting random photos: {str(e)}"
171
-
172
- def download_photo(self, photo_url, size="regular"):
173
- """Download photo from URL"""
174
- try:
175
- # Add fm=jpg&q=80 for consistent format and quality
176
- if "?" in photo_url:
177
- photo_url += "&fm=jpg&q=80"
178
- else:
179
- photo_url += "?fm=jpg&q=80"
180
-
181
- response = requests.get(photo_url, timeout=30)
182
- response.raise_for_status()
183
- return Image.open(io.BytesIO(response.content)).convert('RGB')
184
- except Exception as e:
185
- st.error(f"Error downloading image: {str(e)}")
186
- return None
187
-
188
- def trigger_download(self, download_location):
189
- """Trigger download event (required by Unsplash API)"""
190
- if not self.access_key or not download_location:
191
- return
192
-
193
- headers = {"Authorization": f"Client-ID {self.access_key}"}
194
- try:
195
- requests.get(download_location, headers=headers, timeout=5)
196
- except:
197
- pass # Don't fail if tracking fails
198
-
199
- # ===========================
200
- # MODEL ARCHITECTURES
201
- # ===========================
202
-
203
- class LightweightResidualBlock(nn.Module):
204
- """Lightweight residual block with depthwise separable convolutions"""
205
- def __init__(self, channels):
206
- super(LightweightResidualBlock, self).__init__()
207
- self.depthwise = nn.Sequential(
208
- nn.ReflectionPad2d(1),
209
- nn.Conv2d(channels, channels, 3, groups=channels),
210
- nn.InstanceNorm2d(channels, affine=True),
211
- nn.ReLU(inplace=True)
212
- )
213
- self.pointwise = nn.Sequential(
214
- nn.Conv2d(channels, channels, 1),
215
- nn.InstanceNorm2d(channels, affine=True)
216
- )
217
-
218
- def forward(self, x):
219
- return x + self.pointwise(self.depthwise(x))
220
-
221
- class ResidualBlock(nn.Module):
222
- """Standard residual block for CycleGAN"""
223
- def __init__(self, in_features):
224
- super(ResidualBlock, self).__init__()
225
- self.block = nn.Sequential(
226
- nn.ReflectionPad2d(1),
227
- nn.Conv2d(in_features, in_features, 3),
228
- nn.InstanceNorm2d(in_features, affine=True),
229
- nn.ReLU(inplace=True),
230
- nn.ReflectionPad2d(1),
231
- nn.Conv2d(in_features, in_features, 3),
232
- nn.InstanceNorm2d(in_features, affine=True)
233
- )
234
-
235
- def forward(self, x):
236
- return x + self.block(x)
237
-
238
class Generator(nn.Module):
    """CycleGAN generator: 7x7 stem, two downsamples, residual bottleneck,
    two upsamples, and a tanh output head mapping to [-1, 1]."""

    def __init__(self, input_nc=3, output_nc=3, n_residual_blocks=9):
        super(Generator, self).__init__()

        # Initial convolution block (c7s1-64).
        layers = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_nc, 64, 7),
            nn.InstanceNorm2d(64, affine=True),
            nn.ReLU(inplace=True),
        ]

        # Two stride-2 downsampling stages: 64 -> 128 -> 256 channels.
        channels = 64
        for _ in range(2):
            layers += [
                nn.Conv2d(channels, channels * 2, 3, stride=2, padding=1),
                nn.InstanceNorm2d(channels * 2, affine=True),
                nn.ReLU(inplace=True),
            ]
            channels *= 2

        # Residual bottleneck at 256 channels.
        layers += [ResidualBlock(channels) for _ in range(n_residual_blocks)]

        # Two transposed-conv upsampling stages: 256 -> 128 -> 64 channels.
        for _ in range(2):
            layers += [
                nn.ConvTranspose2d(channels, channels // 2, 3, stride=2,
                                   padding=1, output_padding=1),
                nn.InstanceNorm2d(channels // 2, affine=True),
                nn.ReLU(inplace=True),
            ]
            channels //= 2

        # Output layer projects back to image space; tanh bounds to [-1, 1].
        layers += [
            nn.ReflectionPad2d(3),
            nn.Conv2d(64, output_nc, 7),
            nn.Tanh(),
        ]

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)
288
-
289
class LightweightStyleNet(nn.Module):
    """Small encoder / residual-bottleneck / decoder network for fast style transfer."""

    def __init__(self, n_residual_blocks=5):
        super(LightweightStyleNet, self).__init__()

        # Encoder: 9x9 stem followed by two stride-2 stages (3 -> 32 -> 64 -> 128).
        self.encoder = nn.Sequential(
            nn.ReflectionPad2d(3),
            nn.Conv2d(3, 32, 9, stride=1),
            nn.InstanceNorm2d(32, affine=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 64, 3, stride=2, padding=1),
            nn.InstanceNorm2d(64, affine=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 128, 3, stride=2, padding=1),
            nn.InstanceNorm2d(128, affine=True),
            nn.ReLU(inplace=True),
        )

        # Bottleneck of depthwise-separable residual blocks at 128 channels.
        self.res_blocks = nn.Sequential(
            *[LightweightResidualBlock(128) for _ in range(n_residual_blocks)]
        )

        # Decoder mirrors the encoder; tanh bounds the output to [-1, 1].
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(128, 64, 3, stride=2, padding=1, output_padding=1),
            nn.InstanceNorm2d(64, affine=True),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, 32, 3, stride=2, padding=1, output_padding=1),
            nn.InstanceNorm2d(32, affine=True),
            nn.ReLU(inplace=True),
            nn.ReflectionPad2d(3),
            nn.Conv2d(32, 3, 9, stride=1),
            nn.Tanh(),
        )

    def forward(self, x):
        return self.decoder(self.res_blocks(self.encoder(x)))
332
-
333
class SimpleVGGFeatures(nn.Module):
    """Frozen VGG19 feature extractor for perceptual-loss computation.

    Keeps the first 21 layers of torchvision's VGG19 `features` stack
    (through relu4_1) and freezes all parameters.
    """

    def __init__(self):
        super(SimpleVGGFeatures, self).__init__()
        try:
            # torchvision >= 0.13 weights API.
            vgg = models.vgg19(weights=models.VGG19_Weights.DEFAULT).features
        except (AttributeError, TypeError):
            # Older torchvision releases only support `pretrained=True`.
            # (Narrowed from a bare `except` that masked unrelated errors.)
            vgg = models.vgg19(pretrained=True).features

        self.features = nn.Sequential(*list(vgg.children())[:21])

        # The loss network is fixed; never train it.
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, x):
        return self.features(x)
349
-
350
- # ===========================
351
- # DATASET AND LOSS FUNCTIONS
352
- # ===========================
353
-
354
- class StyleTransferDataset(Dataset):
355
- """Dataset for training style transfer models with augmentation support"""
356
- def __init__(self, content_dir, transform=None, augment_factor=1):
357
- self.content_dir = Path(content_dir)
358
- self.transform = transform
359
- self.augment_factor = augment_factor
360
-
361
- extensions = ['*.jpg', '*.jpeg', '*.png', '*.bmp']
362
- self.images = []
363
- for ext in extensions:
364
- self.images.extend(list(self.content_dir.glob(ext)))
365
- self.images.extend(list(self.content_dir.glob(ext.upper())))
366
-
367
- print(f"Found {len(self.images)} content images")
368
-
369
- self.augmented_images = self.images * self.augment_factor
370
- if self.augment_factor > 1:
371
- print(f"Dataset augmented {self.augment_factor}x to {len(self.augmented_images)} samples")
372
-
373
- def __len__(self):
374
- return len(self.augmented_images)
375
-
376
- def __getitem__(self, idx):
377
- img_path = self.augmented_images[idx % len(self.images)]
378
- image = Image.open(img_path).convert('RGB')
379
-
380
- if self.transform:
381
- image = self.transform(image)
382
-
383
- return image
384
-
385
class PerceptualLoss(nn.Module):
    """Combined content + style (Gram-matrix) loss over VGG feature maps."""

    def __init__(self, vgg_features):
        super(PerceptualLoss, self).__init__()
        self.vgg = vgg_features
        self.mse = nn.MSELoss()

    def gram_matrix(self, features):
        """Channel-correlation (Gram) matrix, normalized by c * h * w."""
        b, c, h, w = features.size()
        flat = features.view(b, c, h * w)
        gram = torch.bmm(flat, flat.transpose(1, 2))
        return gram / (c * h * w)

    def forward(self, generated, content, style, content_weight=1.0, style_weight=1e5):
        gen_feat = self.vgg(generated)
        content_feat = self.vgg(content)
        style_feat = self.vgg(style)

        # Content term: feature-map MSE against the content image.
        content_loss = self.mse(gen_feat, content_feat)

        # Style term: MSE between Gram matrices (texture statistics).
        style_loss = self.mse(self.gram_matrix(gen_feat),
                              self.gram_matrix(style_feat))

        total_loss = content_weight * content_loss + style_weight * style_loss
        return total_loss, content_loss, style_loss
412
-
413
- # ===========================
414
- # VIDEO PROCESSING
415
- # ===========================
416
-
417
- class VideoProcessor:
418
- """Process videos frame by frame with style transfer"""
419
-
420
- def __init__(self, system):
421
- self.system = system
422
-
423
- def process_video(self, video_path, style_configs, blend_mode, progress_callback=None):
424
  """Process a video file with style transfer"""
425
  if not VIDEO_PROCESSING_AVAILABLE:
426
  print("Video processing requires OpenCV (cv2) - please install it")
@@ -438,31 +16,22 @@ class VideoProcessor:
438
  height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
439
  total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
440
 
441
- # Create temporary output file
442
  temp_output = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False)
443
  temp_output.close() # Close so OpenCV can write
444
 
445
- # Try different codecs
446
- codecs_to_try = ['H264', 'h264', 'avc1', 'mp4v', 'XVID']
447
- out = None
448
- codec_used = None
449
 
450
- for codec in codecs_to_try:
451
- try:
452
- fourcc = cv2.VideoWriter_fourcc(*codec)
453
- out = cv2.VideoWriter(temp_output.name, fourcc, fps, (width, height))
454
- if out.isOpened():
455
- codec_used = codec
456
- print(f"Using video codec: {codec}")
457
- break
458
- except:
459
- continue
460
 
461
- if out is None or not out.isOpened():
462
- # Fallback to AVI with MJPEG
463
- print("Falling back to AVI format with MJPEG codec")
464
- temp_output = tempfile.NamedTemporaryFile(suffix='.avi', delete=False)
465
- fourcc = cv2.VideoWriter_fourcc(*'MJPG')
466
  out = cv2.VideoWriter(temp_output.name, fourcc, fps, (width, height))
467
 
468
  if not out.isOpened():
@@ -495,1661 +64,64 @@ class VideoProcessor:
495
  cap.release()
496
  out.release()
497
 
498
- # If we used a non-standard codec, try to convert to H264
499
- if codec_used not in ['H264', 'h264', 'avc1'] and temp_output.name.endswith('.mp4'):
500
- print("Converting to H264 for better compatibility...")
501
- try:
502
- converted_output = tempfile.NamedTemporaryFile(suffix='_h264.mp4', delete=False)
503
- converted_output.close()
504
-
505
- # Re-encode with H264
506
- cap = cv2.VideoCapture(temp_output.name)
507
- fourcc = cv2.VideoWriter_fourcc(*'H264')
508
- out = cv2.VideoWriter(converted_output.name, fourcc, fps, (width, height))
509
-
510
- while True:
511
- ret, frame = cap.read()
512
- if not ret:
513
- break
514
- out.write(frame)
515
-
516
- cap.release()
517
- out.release()
518
-
519
- # Replace with converted version
520
- os.unlink(temp_output.name)
521
- return converted_output.name
522
- except:
523
- print("H264 conversion failed, using original")
524
-
525
- return temp_output.name
526
-
527
- except Exception as e:
528
- print(f"Error processing video: {e}")
529
- traceback.print_exc()
530
- return None
531
-
532
- # ===========================
533
- # MAIN STYLE TRANSFER SYSTEM
534
- # ===========================
535
-
536
- class StyleTransferSystem:
537
- def __init__(self):
538
- self.device = device
539
- self.cyclegan_models = {}
540
- self.loaded_generators = {}
541
- self.lightweight_models = {}
542
-
543
- self.transform = transforms.Compose([
544
- transforms.ToTensor(),
545
- transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
546
- ])
547
-
548
- self.inverse_transform = transforms.Compose([
549
- transforms.Normalize((-1, -1, -1), (2, 2, 2)),
550
- transforms.ToPILImage()
551
- ])
552
-
553
- self.vgg_transform = transforms.Compose([
554
- transforms.ToTensor(),
555
- transforms.Normalize(mean=[0.485, 0.456, 0.406],
556
- std=[0.229, 0.224, 0.225])
557
- ])
558
-
559
- self.discover_cyclegan_models()
560
- self.models_dir = '/tmp/trained_models'
561
- os.makedirs(self.models_dir, exist_ok=True)
562
-
563
- if VIDEO_PROCESSING_AVAILABLE:
564
- self.video_processor = VideoProcessor(self)
565
-
566
    def discover_cyclegan_models(self):
        """Find all available CycleGAN models including both AB and BA directions"""
        print("\nDiscovering CycleGAN models...")

        # Updated patterns to match your directory structure
        patterns = [
            './models/*_best_*/*generator_*.pth',
            './models/*_best_*/*.pth',
            './models/*/*generator*.pth',
            './models/*/*.pth'
        ]

        # Set so the overlapping patterns above don't produce duplicates.
        all_files = set()
        for pattern in patterns:
            files = glob.glob(pattern)
            if files:
                print(f"Found in {pattern}: {len(files)} items")
                all_files.update(files)

        # Also check if models directory exists and list contents
        if os.path.exists('./models'):
            print(f"\nModels directory contents:")
            for folder in os.listdir('./models'):
                folder_path = os.path.join('./models', folder)
                if os.path.isdir(folder_path):
                    print(f" {folder}/")
                    for file in os.listdir(folder_path):
                        print(f" - {file}")
                        if file.endswith('.pth'):
                            all_files.add(os.path.join(folder_path, file))

        # Group files by base model name; each entry maps direction -> path.
        model_files = {}
        for path in all_files:
            # Skip normal models
            if 'normal' in path.lower():
                continue

            filename = os.path.basename(path)
            folder_name = os.path.basename(os.path.dirname(path))

            # Extract base name from folder name (e.g. "photo_monet_best_x" -> "photo_monet")
            if '_best_' in folder_name:
                base_name = folder_name.split('_best_')[0]
            else:
                base_name = folder_name

            if base_name not in model_files:
                model_files[base_name] = {'AB': None, 'BA': None}

            # Check filename for direction
            if 'generator_AB' in filename or 'g_AB' in filename or 'G_AB' in filename:
                model_files[base_name]['AB'] = path
            elif 'generator_BA' in filename or 'g_BA' in filename or 'G_BA' in filename:
                model_files[base_name]['BA'] = path
            elif 'generator' in filename.lower() and not any(x in filename for x in ['AB', 'BA']):
                # If no direction specified, assume it's AB
                if model_files[base_name]['AB'] is None:
                    model_files[base_name]['AB'] = path

        # Create display names for models: base name -> (style A label, style B label).
        model_display_map = {
            'photo_bokeh': ('Bokeh', 'Sharp'),
            'photo_golden': ('Golden Hour', 'Normal Light'),
            'photo_monet': ('Monet Style', 'Photo'),
            'photo_seurat': ('Seurat Style', 'Photo'),
            'day_night': ('Night', 'Day'),
            'summer_winter': ('Winter', 'Summer'),
            'foggy_clear': ('Clear', 'Foggy')
        }

        # Register available models (only those with a known display mapping).
        for base_name, files in model_files.items():
            clean_name = base_name.lower().replace('-', '_')

            if clean_name in model_display_map:
                style_from, style_to = model_display_map[clean_name]

                # Register AB direction if available
                if files['AB']:
                    display_name = f"{style_to} to {style_from}"
                    model_key = f"{clean_name}_AB"

                    self.cyclegan_models[model_key] = {
                        'path': files['AB'],
                        'name': display_name,
                        'base_name': base_name,
                        'direction': 'AB'
                    }
                    print(f"Registered: {display_name} ({model_key}) -> {files['AB']}")

                # Register BA direction if available
                if files['BA']:
                    display_name = f"{style_from} to {style_to}"
                    model_key = f"{clean_name}_BA"

                    self.cyclegan_models[model_key] = {
                        'path': files['BA'],
                        'name': display_name,
                        'base_name': base_name,
                        'direction': 'BA'
                    }
                    print(f"Registered: {display_name} ({model_key}) -> {files['BA']}")

        if not self.cyclegan_models:
            print("No CycleGAN models found!")
            print("Make sure your model files are in the ./models directory")
        else:
            print(f"\nFound {len(self.cyclegan_models)} CycleGAN models\n")
675
-
676
- def detect_architecture(self, state_dict):
677
- """Detect the number of residual blocks in CycleGAN model"""
678
- residual_keys = [k for k in state_dict.keys() if 'model.' in k and '.block.' in k]
679
-
680
- if not residual_keys:
681
- return 9
682
-
683
- block_indices = set()
684
- for key in residual_keys:
685
- parts = key.split('.')
686
- for i in range(len(parts) - 1):
687
- if parts[i] == 'model' and parts[i+1].isdigit():
688
- block_indices.add(int(parts[i+1]))
689
- break
690
-
691
- n_blocks = len(block_indices)
692
- return n_blocks if n_blocks > 0 else 9
693
-
694
- def load_cyclegan_model(self, model_key):
695
- """Load a CycleGAN model"""
696
- if model_key in self.loaded_generators:
697
- return self.loaded_generators[model_key]
698
-
699
- if model_key not in self.cyclegan_models:
700
- print(f"Model {model_key} not found!")
701
- return None
702
-
703
- model_info = self.cyclegan_models[model_key]
704
-
705
- try:
706
- print(f"Loading {model_info['name']} from {model_info['path']}...")
707
-
708
- state_dict = torch.load(model_info['path'], map_location=self.device)
709
- if 'generator' in state_dict:
710
- state_dict = state_dict['generator']
711
-
712
- n_blocks = self.detect_architecture(state_dict)
713
- print(f"Detected {n_blocks} residual blocks")
714
-
715
- generator = Generator(n_residual_blocks=n_blocks)
716
-
717
- try:
718
- generator.load_state_dict(state_dict, strict=True)
719
- print(f"Loaded with strict=True")
720
- except:
721
- generator.load_state_dict(state_dict, strict=False)
722
- print(f"Loaded with strict=False")
723
-
724
- generator.to(self.device)
725
- generator.eval()
726
-
727
- if self.device.type == 'cuda':
728
- try:
729
- generator = generator.half()
730
- print("Using half precision (fp16)")
731
- except:
732
- print("Using full precision (fp32)")
733
-
734
- self.loaded_generators[model_key] = generator
735
- print(f"Successfully loaded {model_info['name']}")
736
- return generator
737
-
738
- except Exception as e:
739
- print(f"Failed to load {model_info['name']}: {e}")
740
- traceback.print_exc()
741
- return None
742
-
743
    def apply_cyclegan_style(self, image, model_key, intensity=1.0):
        """Apply a CycleGAN style to an image"""
        # Bail out early on missing input or an unknown model key.
        if image is None or model_key not in self.cyclegan_models:
            return None

        model_info = self.cyclegan_models[model_key]
        generator = self.load_cyclegan_model(model_key)

        if generator is None:
            print(f"Could not load model for {model_info['name']}")
            return None

        try:
            original_size = image.size

            # Round dimensions up to multiples of 32 so the generator's
            # stride-2 down/upsampling stages divide evenly.
            w, h = image.size
            new_w = ((w + 31) // 32) * 32
            new_h = ((h + 31) // 32) * 32

            # Cap the working resolution to bound memory use
            # (larger budget when running on GPU).
            max_size = 1024 if self.device.type == 'cuda' else 512
            if new_w > max_size or new_h > max_size:
                ratio = min(max_size / new_w, max_size / new_h)
                new_w = int(new_w * ratio)
                new_h = int(new_h * ratio)
                # Re-round to multiples of 32 after scaling.
                new_w = ((new_w + 31) // 32) * 32
                new_h = ((new_h + 31) // 32) * 32

            image_resized = image.resize((new_w, new_h), Image.LANCZOS)
            img_tensor = self.transform(image_resized).unsqueeze(0).to(self.device)

            with torch.no_grad():
                # Match the input dtype to the (possibly fp16) generator.
                is_half = next(generator.parameters()).dtype == torch.float16

                if self.device.type == 'cuda' and is_half:
                    img_tensor = img_tensor.half()

                if self.device.type == 'cuda':
                    torch.cuda.empty_cache()

                output = generator(img_tensor)

                # Convert back to fp32 before the PIL round-trip.
                if output.dtype == torch.float16:
                    output = output.float()

                # Undo the [-1, 1] normalization and restore the caller's
                # original resolution.
                output_img = self.inverse_transform(output.squeeze(0).cpu())
                output_img = output_img.resize(original_size, Image.LANCZOS)

            if self.device.type == 'cuda':
                torch.cuda.empty_cache()

            # Linearly blend with the original when intensity < 1.
            if intensity < 1.0:
                output_array = np.array(output_img, dtype=np.float32)
                original_array = np.array(image, dtype=np.float32)
                blended = original_array * (1 - intensity) + output_array * intensity
                output_img = Image.fromarray(blended.astype(np.uint8))

            return output_img

        except Exception as e:
            print(f"Error applying style {model_info['name']}: {e}")
            traceback.print_exc()
            return None
805
-
806
    def train_lightweight_model(self, style_image, content_dir, model_name,
                                epochs=30, batch_size=4, lr=1e-3,
                                save_interval=5, style_weight=1e5, content_weight=1.0,
                                n_residual_blocks=5, progress_callback=None):
        """Train a lightweight style transfer model"""

        model = LightweightStyleNet(n_residual_blocks=n_residual_blocks).to(self.device)
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)

        print(f"Model architecture: {n_residual_blocks} residual blocks")

        # Calculate augmentation factor: heavily oversample tiny datasets.
        num_content_images = len(list(Path(content_dir).glob('*')))
        if num_content_images < 5:
            augment_factor = 20
        elif num_content_images < 10:
            augment_factor = 10
        elif num_content_images < 20:
            augment_factor = 5
        else:
            augment_factor = 1

        # Create dataset with augmentation: heavy random augmentation for
        # small folders, light augmentation otherwise.
        if num_content_images < 10:
            transform = transforms.Compose([
                transforms.RandomResizedCrop(256, scale=(0.7, 1.2)),
                transforms.RandomHorizontalFlip(),
                transforms.RandomRotation(15),
                transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
            ])
            print(f"Using heavy augmentation due to limited images ({num_content_images} provided)")
        else:
            transform = transforms.Compose([
                transforms.Resize(286),
                transforms.RandomCrop(256),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
            ])

        dataset = StyleTransferDataset(content_dir, transform=transform, augment_factor=augment_factor)

        print(f"Training configuration:")
        print(f" - Original images: {num_content_images}")
        print(f" - Augmentation factor: {augment_factor}x")
        print(f" - Total training samples: {len(dataset)}")
        print(f" - Residual blocks: {n_residual_blocks}")
        print(f" - Batch size: {int(batch_size)}")
        print(f" - Epochs: {epochs}")

        # Adjust batch size for small datasets to keep memory in check.
        if num_content_images == 1:
            if n_residual_blocks >= 9 and int(batch_size) > 1:
                actual_batch_size = 1
                print(f"Reduced batch size to 1 for single image + {n_residual_blocks} blocks")
            elif int(batch_size) > 2:
                actual_batch_size = 2
                print(f"Reduced batch size to 2 for single image training")
            else:
                actual_batch_size = min(int(batch_size), len(dataset))
        else:
            actual_batch_size = min(int(batch_size), len(dataset))

        # No worker processes for tiny datasets (spawn overhead dominates).
        dataloader = DataLoader(dataset, batch_size=actual_batch_size, shuffle=True,
                                num_workers=0 if num_content_images < 10 else 2)

        # Prepare style image: 256x256 center crop, ImageNet-normalized for VGG.
        style_transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(256)
        ])
        style_pil = style_transform(style_image)
        style_tensor = self.vgg_transform(style_pil).unsqueeze(0).to(self.device)

        # Create VGG features extractor for loss (frozen).
        vgg_features = SimpleVGGFeatures().to(self.device).eval()

        # Extract style features once.
        # NOTE(review): style_features is never read afterwards — the loss
        # below recomputes style features from style_vgg each step.
        with torch.no_grad():
            style_features = vgg_features(style_tensor)

        # Loss function combining content and Gram-matrix style terms.
        perceptual_loss = PerceptualLoss(vgg_features)

        # Training loop
        model.train()
        total_steps = 0

        for epoch in range(epochs):
            epoch_loss = 0

            for batch_idx, content_batch in enumerate(dataloader):
                content_batch = content_batch.to(self.device)

                # Forward pass
                output = model(content_batch)

                # Ensure all tensors have the same size before VGG.
                target_size = (256, 256)

                # Convert each image for VGG: denormalize from [-1, 1],
                # resize if needed, then apply ImageNet normalization.
                output_vgg = []
                content_vgg = []

                for i in range(output.size(0)):
                    # Denormalize from [-1, 1] to [0, 1]
                    out_img = output[i] * 0.5 + 0.5
                    cont_img = content_batch[i] * 0.5 + 0.5

                    # Ensure exact size match
                    if out_img.shape[1:] != (target_size[0], target_size[1]):
                        out_img = F.interpolate(out_img.unsqueeze(0), size=target_size, mode='bilinear', align_corners=False).squeeze(0)
                    if cont_img.shape[1:] != (target_size[0], target_size[1]):
                        cont_img = F.interpolate(cont_img.unsqueeze(0), size=target_size, mode='bilinear', align_corners=False).squeeze(0)

                    # Normalize for VGG
                    out_norm = transforms.Normalize(
                        mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225]
                    )(out_img)
                    cont_norm = transforms.Normalize(
                        mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225]
                    )(cont_img)

                    output_vgg.append(out_norm)
                    content_vgg.append(cont_norm)

                output_vgg = torch.stack(output_vgg)
                content_vgg = torch.stack(content_vgg)

                # Ensure style tensor matches batch size and dimensions.
                style_vgg = style_tensor.expand(output_vgg.size(0), -1, -1, -1)
                if style_vgg.shape[2:] != output_vgg.shape[2:]:
                    style_vgg = F.interpolate(style_vgg, size=output_vgg.shape[2:], mode='bilinear', align_corners=False)

                # Calculate combined content + style loss.
                loss, content_loss, style_loss = perceptual_loss(
                    output_vgg, content_vgg, style_vgg,
                    content_weight=content_weight, style_weight=style_weight
                )

                # Backward pass
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                epoch_loss += loss.item()
                total_steps += 1

                # Progress callback (every 10 optimizer steps).
                if progress_callback and total_steps % 10 == 0:
                    progress = (epoch + (batch_idx + 1) / len(dataloader)) / epochs
                    aug_info = f" (aug {num_content_images}→{len(dataset)})" if num_content_images < 20 else ""
                    blocks_info = f", {n_residual_blocks} blocks"
                    progress_callback(progress, f"Epoch {epoch+1}/{epochs}{aug_info}{blocks_info}, Loss: {loss.item():.4f}")

            # Save checkpoint every `save_interval` epochs.
            if (epoch + 1) % int(save_interval) == 0:
                checkpoint_path = f'{self.models_dir}/{model_name}_epoch_{epoch+1}.pth'
                torch.save({
                    'epoch': epoch + 1,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': epoch_loss / len(dataloader),
                    'n_residual_blocks': n_residual_blocks
                }, checkpoint_path)
                print(f"Saved checkpoint: {checkpoint_path}")

        # Save final model (weights + architecture metadata for reloading).
        final_path = f'{self.models_dir}/{model_name}_final.pth'
        torch.save({
            'model_state_dict': model.state_dict(),
            'n_residual_blocks': n_residual_blocks
        }, final_path)
        print(f"Training complete! Model saved to: {final_path}")

        # Add to lightweight models registry.
        self.lightweight_models[model_name] = model

        return model
989
-
990
- def load_lightweight_model(self, model_path):
991
- """Load a trained lightweight model"""
992
- try:
993
- state_dict = torch.load(model_path, map_location=self.device)
994
-
995
- # Check if n_residual_blocks is saved
996
- if isinstance(state_dict, dict) and 'n_residual_blocks' in state_dict:
997
- n_blocks = state_dict['n_residual_blocks']
998
- print(f"Found saved architecture: {n_blocks} residual blocks")
999
- else:
1000
- # Try to detect from state dict
1001
- if 'model_state_dict' in state_dict:
1002
- model_state = state_dict['model_state_dict']
1003
- else:
1004
- model_state = state_dict
1005
-
1006
- res_block_keys = [k for k in model_state.keys() if 'res_blocks' in k and 'weight' in k]
1007
- n_blocks = len(set([k.split('.')[1] for k in res_block_keys if k.startswith('res_blocks')])) or 5
1008
- print(f"Detected {n_blocks} residual blocks from model structure")
1009
-
1010
- # Create model with detected architecture
1011
- model = LightweightStyleNet(n_residual_blocks=n_blocks).to(self.device)
1012
-
1013
- # Load the weights
1014
- if 'model_state_dict' in state_dict:
1015
- model.load_state_dict(state_dict['model_state_dict'])
1016
- else:
1017
- model.load_state_dict(state_dict)
1018
-
1019
- model.eval()
1020
- return model
1021
-
1022
- except Exception as e:
1023
- print(f"Error loading lightweight model: {e}")
1024
- # Try with default 5 blocks
1025
  try:
1026
- print("Attempting to load with default 5 residual blocks...")
1027
- model = LightweightStyleNet(n_residual_blocks=5).to(self.device)
1028
-
1029
- if model_path.endswith('.pth'):
1030
- state_dict = torch.load(model_path, map_location=self.device)
1031
- if 'model_state_dict' in state_dict:
1032
- model.load_state_dict(state_dict['model_state_dict'])
1033
- else:
1034
- model.load_state_dict(state_dict)
1035
-
1036
- model.eval()
1037
- return model
1038
- except:
1039
- return None
1040
-
1041
- def apply_lightweight_style(self, image, model, intensity=1.0):
1042
- """Apply style using a lightweight model"""
1043
- if image is None or model is None:
1044
- return None
1045
-
1046
- try:
1047
- original_size = image.size
1048
-
1049
- transform = transforms.Compose([
1050
- transforms.Resize(256),
1051
- transforms.CenterCrop(256),
1052
- transforms.ToTensor(),
1053
- transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
1054
- ])
1055
-
1056
- img_tensor = transform(image).unsqueeze(0).to(self.device)
1057
-
1058
- with torch.no_grad():
1059
- output = model(img_tensor)
1060
- output_img = self.inverse_transform(output.squeeze(0).cpu())
1061
- output_img = output_img.resize(original_size, Image.LANCZOS)
1062
-
1063
- if intensity < 1.0:
1064
- output_array = np.array(output_img, dtype=np.float32)
1065
- original_array = np.array(image, dtype=np.float32)
1066
- blended = original_array * (1 - intensity) + output_array * intensity
1067
- output_img = Image.fromarray(blended.astype(np.uint8))
1068
-
1069
- return output_img
1070
-
1071
- except Exception as e:
1072
- print(f"Error applying lightweight style: {e}")
1073
- return None
1074
-
1075
- def blend_styles(self, image, style_configs, blend_mode="additive"):
1076
- """Apply multiple styles with different blending modes"""
1077
- if not image or not style_configs:
1078
- return image
1079
-
1080
- original = np.array(image, dtype=np.float32)
1081
- styled_images = []
1082
- weights = []
1083
-
1084
- for style_type, model_key, intensity in style_configs:
1085
- if intensity <= 0:
1086
- continue
1087
-
1088
- if style_type == 'cyclegan':
1089
- styled = self.apply_cyclegan_style(image, model_key, 1.0)
1090
- elif style_type == 'lightweight' and model_key in self.lightweight_models:
1091
- styled = self.apply_lightweight_style(image, self.lightweight_models[model_key], 1.0)
1092
- else:
1093
- continue
1094
-
1095
- if styled:
1096
- styled_images.append(np.array(styled, dtype=np.float32))
1097
- weights.append(intensity)
1098
-
1099
- if not styled_images:
1100
- return image
1101
-
1102
- # Apply blending
1103
- if blend_mode == "average":
1104
- result = np.zeros_like(original)
1105
- total_weight = sum(weights)
1106
- for img, weight in zip(styled_images, weights):
1107
- result += img * (weight / total_weight)
1108
-
1109
- elif blend_mode == "additive":
1110
- result = original.copy()
1111
- for img, weight in zip(styled_images, weights):
1112
- transformation = img - original
1113
- result = result + transformation * weight
1114
-
1115
- elif blend_mode == "maximum":
1116
- result = original.copy()
1117
- for img, weight in zip(styled_images, weights):
1118
- transformation = (img - original) * weight
1119
- current_diff = result - original
1120
- mask = np.abs(transformation) > np.abs(current_diff)
1121
- result[mask] = original[mask] + transformation[mask]
1122
-
1123
- elif blend_mode == "overlay":
1124
- result = original.copy()
1125
- for img, weight in zip(styled_images, weights):
1126
- overlay = np.zeros_like(result)
1127
- mask = result < 128
1128
- overlay[mask] = 2 * img[mask] * result[mask] / 255.0
1129
- overlay[~mask] = 255 - 2 * (255 - img[~mask]) * (255 - result[~mask]) / 255.0
1130
- result = result * (1 - weight) + overlay * weight
1131
-
1132
- else: # "screen" mode
1133
- result = original.copy()
1134
- for img, weight in zip(styled_images, weights):
1135
- screened = 255 - ((255 - result) * (255 - img) / 255.0)
1136
- if weight > 1.0:
1137
- diff = screened - result
1138
- result = result + diff * weight
1139
- else:
1140
- result = result * (1 - weight) + screened * weight
1141
-
1142
- return Image.fromarray(np.clip(result, 0, 255).astype(np.uint8))
1143
-
1144
- def apply_regional_styles(self, image, combined_mask, regions, base_style_configs=None, blend_mode="additive"):
1145
- """Apply different styles to painted regions using a combined mask"""
1146
- if not regions:
1147
- if base_style_configs:
1148
- return self.blend_styles(image, base_style_configs, blend_mode)
1149
- return image
1150
-
1151
- original_size = image.size
1152
- result = np.array(image, dtype=np.float32)
1153
-
1154
- # Apply base style if provided
1155
- if base_style_configs:
1156
- base_styled = self.blend_styles(image, base_style_configs, blend_mode)
1157
- result = np.array(base_styled, dtype=np.float32)
1158
-
1159
- # Resize mask to match original image if needed
1160
- if combined_mask is not None and combined_mask.shape[:2] != (original_size[1], original_size[0]):
1161
- # Resize the combined mask to match the original image
1162
- combined_mask_pil = Image.fromarray(combined_mask.astype(np.uint8))
1163
- combined_mask_resized = combined_mask_pil.resize(original_size, Image.NEAREST)
1164
- combined_mask = np.array(combined_mask_resized)
1165
-
1166
- # Apply each region
1167
- for i, region in enumerate(regions):
1168
- if region['style'] is None:
1169
- continue
1170
-
1171
- # Get model key for this region's style
1172
- model_key = None
1173
- for key, info in self.cyclegan_models.items():
1174
- if info['name'] == region['style']:
1175
- model_key = key
1176
- break
1177
-
1178
- if not model_key:
1179
- continue
1180
-
1181
- # Apply style to whole image
1182
- style_configs = [('cyclegan', model_key, region['intensity'])]
1183
- styled = self.blend_styles(image, style_configs, blend_mode)
1184
- styled_array = np.array(styled, dtype=np.float32)
1185
-
1186
- # Create mask for this region from combined mask
1187
- if combined_mask is not None:
1188
- # Region masks are identified by their color index
1189
- region_mask = (combined_mask == (i + 1)).astype(np.float32)
1190
- # Ensure mask has same shape as image
1191
- if len(region_mask.shape) == 2:
1192
- region_mask_3ch = np.stack([region_mask] * 3, axis=2)
1193
- else:
1194
- region_mask_3ch = region_mask
1195
-
1196
- # Blend using mask
1197
- result = result * (1 - region_mask_3ch) + styled_array * region_mask_3ch
1198
-
1199
- return Image.fromarray(np.clip(result, 0, 255).astype(np.uint8))
1200
-
1201
- # ===========================
1202
- # HELPER FUNCTIONS
1203
- # ===========================
1204
-
1205
- def resize_image_for_display(image, max_width=800, max_height=600):
1206
- """Resize image for display while maintaining aspect ratio"""
1207
- width, height = image.size
1208
-
1209
- # Calculate scaling factor
1210
- width_scale = max_width / width
1211
- height_scale = max_height / height
1212
- scale = min(width_scale, height_scale)
1213
-
1214
- # Only scale down, not up
1215
- if scale < 1:
1216
- new_width = int(width * scale)
1217
- new_height = int(height * scale)
1218
- return image.resize((new_width, new_height), Image.LANCZOS)
1219
-
1220
- return image
1221
-
1222
- def combine_region_masks(canvas_results, canvas_size):
1223
- """Combine multiple region masks into a single mask with different values for each region"""
1224
- combined_mask = np.zeros(canvas_size[:2], dtype=np.uint8)
1225
-
1226
- for i, canvas_data in enumerate(canvas_results):
1227
- if canvas_data is not None and hasattr(canvas_data, 'image_data') and canvas_data.image_data is not None:
1228
- # Extract alpha channel as mask
1229
- mask = canvas_data.image_data[:, :, 3] > 0
1230
- # Assign region index (1-based) to mask
1231
- combined_mask[mask] = i + 1
1232
-
1233
- return combined_mask
1234
-
1235
- # ===========================
1236
- # INITIALIZE SYSTEM AND API
1237
- # ===========================
1238
-
1239
- @st.cache_resource
1240
- def load_system():
1241
- return StyleTransferSystem()
1242
-
1243
- @st.cache_resource
1244
- def get_unsplash_api():
1245
- return UnsplashAPI()
1246
-
1247
- system = load_system()
1248
- unsplash = get_unsplash_api()
1249
-
1250
- # Get style choices
1251
- style_choices = sorted([info['name'] for info in system.cyclegan_models.values()])
1252
-
1253
- # ===========================
1254
- # STREAMLIT APP
1255
- # ===========================
1256
-
1257
- # Main app
1258
- st.title("🎨 Style Transfer Studio")
1259
- st.markdown("Professional image and video style transfer with CycleGAN and custom training capabilities")
1260
-
1261
- # Sidebar for global settings
1262
- with st.sidebar:
1263
- st.header("⚙️ Settings")
1264
-
1265
- # GPU status
1266
- if torch.cuda.is_available():
1267
- gpu_info = torch.cuda.get_device_properties(0)
1268
- st.success(f"🚀 GPU: {gpu_info.name}")
1269
- st.metric("GPU Memory", f"{gpu_info.total_memory / 1e9:.2f} GB")
1270
- else:
1271
- st.warning("💻 Running on CPU")
1272
-
1273
- st.markdown("---")
1274
- st.markdown("### 📚 Quick Guide")
1275
- st.markdown("""
1276
- - **Style Transfer**: Apply artistic styles to images
1277
- - **Regional Transform**: Paint areas for local effects
1278
- - **Video Processing**: Apply styles to videos
1279
- - **Train Custom**: Create your own style models
1280
- - **Batch Process**: Process multiple images
1281
- """)
1282
-
1283
- # Unsplash API status
1284
- st.markdown("---")
1285
- if unsplash.access_key:
1286
- st.success("🔗 Unsplash API Connected")
1287
- else:
1288
- st.info("💡 Add Unsplash API key for image search")
1289
-
1290
- # Main tabs
1291
- tab1, tab2, tab3, tab4, tab5, tab6 = st.tabs([
1292
- "🎨 Style Transfer",
1293
- "🖌️ Regional Transform",
1294
- "🎬 Video Processing",
1295
- "🔧 Train Custom Style",
1296
- "📦 Batch Processing",
1297
- "📖 Documentation"
1298
- ])
1299
-
1300
- # TAB 1: Style Transfer (with Unsplash integration)
1301
- with tab1:
1302
- # Unsplash Search Section
1303
- with st.expander("🔍 Search Unsplash for Images", expanded=False):
1304
- if not unsplash.access_key:
1305
- st.info("""
1306
- To enable Unsplash search:
1307
- 1. Get a free API key from [Unsplash Developers](https://unsplash.com/developers)
1308
- 2. Add it to your HuggingFace Space secrets as `UNSPLASH_ACCESS_KEY`
1309
- """)
1310
- else:
1311
- search_col1, search_col2, search_col3 = st.columns([3, 1, 1])
1312
- with search_col1:
1313
- search_query = st.text_input("Search for images", placeholder="e.g., landscape, portrait, abstract art")
1314
- with search_col2:
1315
- orientation = st.selectbox("Orientation", ["all", "landscape", "portrait", "squarish"])
1316
- with search_col3:
1317
- search_button = st.button("🔍 Search", use_container_width=True)
1318
-
1319
- # Random photos button
1320
- if st.button("🎲 Get Random Photos"):
1321
- with st.spinner("Loading random photos..."):
1322
- results, error = unsplash.get_random_photos(count=12)
1323
-
1324
- if error:
1325
- st.error(f"Error: {error}")
1326
- elif results:
1327
- # Handle both single photo and array of photos
1328
- photos = results if isinstance(results, list) else [results]
1329
- st.session_state['unsplash_results'] = photos
1330
- st.success(f"Loaded {len(photos)} random photos")
1331
-
1332
- # Search functionality
1333
- if search_button and search_query:
1334
- with st.spinner(f"Searching for '{search_query}'..."):
1335
- orientation_param = None if orientation == "all" else orientation
1336
- results, error = unsplash.search_photos(search_query, per_page=12, orientation=orientation_param)
1337
-
1338
- if error:
1339
- st.error(f"Error: {error}")
1340
- elif results and results.get('results'):
1341
- st.session_state['unsplash_results'] = results['results']
1342
- st.success(f"Found {results['total']} images")
1343
- else:
1344
- st.info("No images found. Try a different search term.")
1345
-
1346
- # Display results
1347
- if 'unsplash_results' in st.session_state and st.session_state['unsplash_results']:
1348
- st.markdown("### Search Results")
1349
-
1350
- # Display in a 4-column grid
1351
- cols = st.columns(4)
1352
- for idx, photo in enumerate(st.session_state['unsplash_results'][:12]):
1353
- with cols[idx % 4]:
1354
- # Show thumbnail
1355
- st.image(photo['urls']['thumb'], use_column_width=True)
1356
-
1357
- # Photo info
1358
- st.caption(f"By {photo['user']['name']}")
1359
-
1360
- # Use button
1361
- if st.button("Use This", key=f"use_unsplash_{photo['id']}"):
1362
- with st.spinner("Loading image..."):
1363
- # Download regular size
1364
- img = unsplash.download_photo(photo['urls']['regular'])
1365
- if img:
1366
- # Store in session state
1367
- st.session_state['current_image'] = img
1368
- st.session_state['image_source'] = f"Unsplash: {photo['user']['name']}"
1369
- st.session_state['unsplash_photo'] = photo
1370
-
1371
- # Trigger download tracking (required by Unsplash)
1372
- if 'links' in photo and 'download_location' in photo['links']:
1373
- unsplash.trigger_download(photo['links']['download_location'])
1374
-
1375
- st.success("Image loaded!")
1376
- st.rerun()
1377
-
1378
- col1, col2 = st.columns(2)
1379
-
1380
- with col1:
1381
- st.header("Input")
1382
-
1383
- # Image source selection
1384
- image_source = st.radio("Image Source", ["Upload", "Unsplash"], horizontal=True)
1385
-
1386
- # Initialize input_image to None
1387
- input_image = None
1388
-
1389
- if image_source == "Upload":
1390
- uploaded_file = st.file_uploader("Choose an image", type=['png', 'jpg', 'jpeg'])
1391
- if uploaded_file:
1392
- input_image = Image.open(uploaded_file).convert('RGB')
1393
- st.session_state['current_image'] = input_image
1394
- st.session_state['image_source'] = "Uploaded"
1395
- else:
1396
- # Handle Unsplash selection
1397
- if 'current_image' in st.session_state and st.session_state.get('image_source', '').startswith('Unsplash'):
1398
- input_image = st.session_state['current_image']
1399
- else:
1400
- st.info("Search for an image above")
1401
-
1402
- if input_image:
1403
- # Display the image
1404
- display_img = resize_image_for_display(input_image, max_width=600, max_height=400)
1405
- st.image(display_img, caption=st.session_state.get('image_source', 'Image'), use_column_width=True)
1406
-
1407
- # Attribution for Unsplash images
1408
- if 'unsplash_photo' in st.session_state and st.session_state.get('image_source', '').startswith('Unsplash'):
1409
- photo = st.session_state['unsplash_photo']
1410
- st.markdown(f"Photo by [{photo['user']['name']}]({photo['user']['links']['html']}) on [Unsplash]({photo['links']['html']})")
1411
-
1412
- st.subheader("Style Configuration")
1413
-
1414
- # Up to 3 styles
1415
- num_styles = st.number_input("Number of styles to apply", 1, 3, 1)
1416
-
1417
- style_configs = []
1418
- for i in range(num_styles):
1419
- with st.expander(f"Style {i+1}", expanded=(i==0)):
1420
- style = st.selectbox(f"Select style", style_choices, key=f"style_{i}")
1421
- intensity = st.slider(f"Intensity", 0.0, 2.0, 1.0, 0.1, key=f"intensity_{i}")
1422
- if style and intensity > 0:
1423
- model_key = None
1424
- for key, info in system.cyclegan_models.items():
1425
- if info['name'] == style:
1426
- model_key = key
1427
- break
1428
- if model_key:
1429
- style_configs.append(('cyclegan', model_key, intensity))
1430
-
1431
- blend_mode = st.selectbox("Blend Mode",
1432
- ["additive", "average", "maximum", "overlay", "screen"],
1433
- index=0)
1434
-
1435
- if st.button("Apply Styles", type="primary", use_container_width=True):
1436
- if style_configs:
1437
- with st.spinner("Applying styles..."):
1438
- progress_bar = st.progress(0)
1439
- status_text = st.empty()
1440
-
1441
- # Process with progress updates
1442
- for i, (_, key, intensity) in enumerate(style_configs):
1443
- model_name = system.cyclegan_models[key]['name']
1444
- progress = (i + 1) / len(style_configs)
1445
- progress_bar.progress(progress)
1446
- status_text.text(f"Applying {model_name}...")
1447
-
1448
- result = system.blend_styles(input_image, style_configs, blend_mode)
1449
-
1450
- st.session_state['last_result'] = result
1451
- st.session_state['last_style_configs'] = style_configs
1452
- progress_bar.empty()
1453
- status_text.empty()
1454
-
1455
- with col2:
1456
- st.header("Result")
1457
- if 'last_result' in st.session_state:
1458
- st.image(st.session_state['last_result'], caption="Styled Image", use_column_width=True)
1459
-
1460
- # Download button
1461
- buf = io.BytesIO()
1462
- st.session_state['last_result'].save(buf, format='PNG')
1463
- st.download_button(
1464
- label="Download Result",
1465
- data=buf.getvalue(),
1466
- file_name=f"styled_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.png",
1467
- mime="image/png"
1468
- )
1469
-
1470
- # TAB 2: Regional Transform
1471
- with tab2:
1472
- st.header("🖌️ Regional Style Transform")
1473
- st.markdown("Paint different regions to apply different styles locally")
1474
-
1475
- # Initialize session state
1476
- if 'regions' not in st.session_state:
1477
- st.session_state.regions = []
1478
- if 'canvas_results' not in st.session_state:
1479
- st.session_state.canvas_results = {}
1480
- if 'regional_image_original' not in st.session_state:
1481
- st.session_state.regional_image_original = None
1482
- if 'canvas_ready' not in st.session_state:
1483
- st.session_state.canvas_ready = True
1484
- if 'last_applied_regions' not in st.session_state:
1485
- st.session_state.last_applied_regions = None
1486
- if 'canvas_key_base' not in st.session_state:
1487
- st.session_state.canvas_key_base = 0
1488
-
1489
- col1, col2 = st.columns([2, 3])
1490
-
1491
- # Define variables at the top level of tab2
1492
- use_base = False
1493
- base_style = None
1494
- base_intensity = 1.0
1495
- regional_blend_mode = "additive"
1496
-
1497
- with col1:
1498
- # Image source selection
1499
- regional_image_source = st.radio("Image Source", ["Upload", "Unsplash"], horizontal=True, key="regional_image_source")
1500
-
1501
- if regional_image_source == "Upload":
1502
- uploaded_regional = st.file_uploader("Choose an image", type=['png', 'jpg', 'jpeg'], key="regional_upload")
1503
-
1504
- if uploaded_regional:
1505
- # Load and store original image
1506
- regional_image_original = Image.open(uploaded_regional).convert('RGB')
1507
- st.session_state.regional_image_original = regional_image_original
1508
- else:
1509
- # Use Unsplash image if available
1510
- if 'current_image' in st.session_state and st.session_state.get('image_source', '').startswith('Unsplash'):
1511
- st.session_state.regional_image_original = st.session_state['current_image']
1512
- st.success("Using Unsplash image")
1513
- else:
1514
- st.info("Please search and select an image from the Style Transfer tab first")
1515
-
1516
- if st.session_state.regional_image_original:
1517
- # Display the original image
1518
- display_img = resize_image_for_display(st.session_state.regional_image_original, max_width=400, max_height=300)
1519
- st.image(display_img, caption="Original Image", use_column_width=True)
1520
-
1521
- st.subheader("Define Regions")
1522
-
1523
- # Base style (optional)
1524
- with st.expander("Base Style (Optional)", expanded=False):
1525
- use_base = st.checkbox("Apply base style to entire image")
1526
- if use_base:
1527
- base_style = st.selectbox("Base style", style_choices, key="base_style")
1528
- base_intensity = st.slider("Base intensity", 0.0, 2.0, 1.0, key="base_intensity")
1529
-
1530
- # Region management
1531
- col_btn1, col_btn2, col_btn3 = st.columns(3)
1532
- with col_btn1:
1533
- if st.button("➕ Add Region", use_container_width=True):
1534
- new_region = {
1535
- 'id': len(st.session_state.regions),
1536
- 'style': style_choices[0] if style_choices else None,
1537
- 'intensity': 1.0,
1538
- 'color': f"hsla({len(st.session_state.regions) * 60}, 70%, 50%, 0.5)"
1539
- }
1540
- st.session_state.regions.append(new_region)
1541
- st.session_state.canvas_ready = True
1542
- st.rerun()
1543
-
1544
- with col_btn2:
1545
- if st.button("🗑️ Clear All", use_container_width=True):
1546
- st.session_state.regions = []
1547
- st.session_state.canvas_results = {}
1548
- if 'regional_result' in st.session_state:
1549
- del st.session_state['regional_result']
1550
- st.session_state.canvas_ready = True
1551
- st.session_state.canvas_key_base = 0
1552
- st.rerun()
1553
-
1554
- with col_btn3:
1555
- if st.button("🔄 Reset Result", use_container_width=True):
1556
- if 'regional_result' in st.session_state:
1557
- del st.session_state['regional_result']
1558
- st.session_state.canvas_ready = True
1559
- st.rerun()
1560
-
1561
- # Configure each region
1562
- for i, region in enumerate(st.session_state.regions):
1563
- with st.expander(f"Region {i+1} - {region.get('style', 'None')}", expanded=(i == len(st.session_state.regions) - 1)):
1564
- col_a, col_b = st.columns(2)
1565
- with col_a:
1566
- new_style = st.selectbox(
1567
- "Style",
1568
- style_choices,
1569
- key=f"region_style_{i}",
1570
- index=style_choices.index(region['style']) if region['style'] in style_choices else 0
1571
- )
1572
- region['style'] = new_style
1573
- with col_b:
1574
- region['intensity'] = st.slider(
1575
- "Intensity",
1576
- 0.0, 2.0,
1577
- region.get('intensity', 1.0),
1578
- key=f"region_intensity_{i}"
1579
- )
1580
-
1581
- if st.button(f"🗑️ Remove Region {i+1}", key=f"remove_region_{i}"):
1582
- # Remove the region
1583
- st.session_state.regions.pop(i)
1584
-
1585
- # Rebuild canvas results with proper indices
1586
- old_canvas_results = st.session_state.canvas_results.copy()
1587
- st.session_state.canvas_results = {}
1588
-
1589
- for old_idx, result in old_canvas_results.items():
1590
- if old_idx < i:
1591
- # Keep results before removed index
1592
- st.session_state.canvas_results[old_idx] = result
1593
- elif old_idx > i:
1594
- # Shift results after removed index down by 1
1595
- st.session_state.canvas_results[old_idx - 1] = result
1596
-
1597
- st.session_state.canvas_ready = True
1598
- st.session_state.canvas_key_base += 1
1599
- st.rerun()
1600
-
1601
- # Blend mode
1602
- regional_blend_mode = st.selectbox("Blend Mode",
1603
- ["additive", "average", "maximum", "overlay", "screen"],
1604
- index=0, key="regional_blend")
1605
-
1606
- with col2:
1607
- if st.session_state.regions and st.session_state.regional_image_original:
1608
- st.subheader("Paint Regions")
1609
-
1610
- # Show workflow status
1611
- if 'regional_result' in st.session_state:
1612
- if st.session_state.canvas_ready:
1613
- st.success("✏️ **Edit Mode** - Paint your regions and click 'Apply Regional Styles' when ready")
1614
- else:
1615
- st.info("👁️ **Preview Mode** - Click 'Continue Editing' to modify regions")
1616
- else:
1617
- st.info("✏️ Paint on the canvas below to define regions for each style")
1618
-
1619
- # Check if we're in edit mode
1620
- if not st.session_state.canvas_ready:
1621
- # Show a preview of the painted regions
1622
- if 'regional_result' in st.session_state:
1623
- st.subheader("Current Result")
1624
- result_display = resize_image_for_display(st.session_state['regional_result'], max_width=600, max_height=400)
1625
- st.image(result_display, caption="Applied Styles", use_column_width=True)
1626
-
1627
- # Create display image
1628
- display_image = resize_image_for_display(st.session_state.regional_image_original, max_width=600, max_height=400)
1629
- display_width, display_height = display_image.size
1630
-
1631
- # Info message
1632
- st.info(f"💡 Image resized to {display_width}x{display_height} for display. Original resolution will be used for processing.")
1633
-
1634
- # Get current region
1635
- current_region_idx = st.selectbox(
1636
- "Select region to paint",
1637
- range(len(st.session_state.regions)),
1638
- format_func=lambda x: f"Region {x+1}: {st.session_state.regions[x].get('style', 'None')}"
1639
- )
1640
-
1641
- current_region = st.session_state.regions[current_region_idx]
1642
-
1643
- # THIS IS THE FIX: The following line was added.
1644
- col_draw1, col_draw2, col_draw3 = st.columns(3)
1645
-
1646
- with col_draw1:
1647
- brush_size = st.slider("Brush Size", 1, 50, 15)
1648
- with col_draw2:
1649
- drawing_mode = st.selectbox("Tool", ["freedraw", "line", "rect", "circle"])
1650
- with col_draw3:
1651
- if st.button("Clear This Region"):
1652
- if current_region_idx in st.session_state.canvas_results:
1653
- del st.session_state.canvas_results[current_region_idx]
1654
- st.session_state.canvas_ready = True
1655
- st.rerun()
1656
-
1657
- # Create combined background with all previous regions
1658
- background_with_regions = display_image.copy()
1659
- draw = ImageDraw.Draw(background_with_regions, 'RGBA')
1660
-
1661
- # Draw all regions on the background
1662
- for i, region in enumerate(st.session_state.regions):
1663
- if i in st.session_state.canvas_results:
1664
- canvas_data = st.session_state.canvas_results[i]
1665
- if canvas_data is not None and hasattr(canvas_data, 'image_data') and canvas_data.image_data is not None:
1666
- # Extract mask from canvas data
1667
- mask = canvas_data.image_data[:, :, 3] > 0
1668
-
1669
- # Create colored overlay for this region
1670
- # Parse HSLA color more carefully
1671
- color_str = region['color'].replace('hsla(', '').replace(')', '')
1672
- color_parts = color_str.split(',')
1673
- hue = int(color_parts[0])
1674
- # Convert HSL to RGB (simplified - assumes 70% saturation, 50% lightness)
1675
- r, g, b = colorsys.hls_to_rgb(hue/360, 0.5, 0.7)
1676
- color = (int(r*255), int(g*255), int(b*255))
1677
- opacity = 128 if i != current_region_idx else 200
1678
-
1679
- # Draw mask on background
1680
- for y in range(mask.shape[0]):
1681
- for x in range(mask.shape[1]):
1682
- if mask[y, x]:
1683
- draw.point((x, y), fill=color + (opacity,))
1684
-
1685
- # Canvas for current region
1686
- stroke_color = current_region['color'].replace('0.5)', '0.8)')
1687
-
1688
- # Get initial drawing for current region
1689
- initial_drawing = None
1690
- if current_region_idx in st.session_state.canvas_results:
1691
- canvas_data = st.session_state.canvas_results[current_region_idx]
1692
- if canvas_data is not None and hasattr(canvas_data, 'json_data'):
1693
- initial_drawing = canvas_data.json_data
1694
-
1695
- canvas_result = st_canvas(
1696
- fill_color=stroke_color,
1697
- stroke_width=brush_size,
1698
- stroke_color=stroke_color,
1699
- background_image=background_with_regions,
1700
- update_streamlit=True,
1701
- height=display_height,
1702
- width=display_width,
1703
- drawing_mode=drawing_mode,
1704
- display_toolbar=True,
1705
- initial_drawing=initial_drawing,
1706
- key=f"regional_canvas_{current_region_idx}_{brush_size}_{drawing_mode}"
1707
- )
1708
-
1709
- # Save canvas result
1710
- if canvas_result:
1711
- st.session_state.canvas_results[current_region_idx] = canvas_result
1712
-
1713
- # Apply button
1714
- if st.button("Apply Regional Styles", type="primary", use_container_width=True):
1715
- with st.spinner("Applying regional styles..."):
1716
- # Create combined mask from all canvas results
1717
- combined_mask = combine_region_masks(
1718
- [st.session_state.canvas_results.get(i) for i in range(len(st.session_state.regions))],
1719
- (display_height, display_width)
1720
- )
1721
-
1722
- # Prepare base style configs if enabled
1723
- base_configs = None
1724
- if use_base and base_style:
1725
- base_key = None
1726
- for key, info in system.cyclegan_models.items():
1727
- if info['name'] == base_style:
1728
- base_key = key
1729
- break
1730
- if base_key:
1731
- base_configs = [('cyclegan', base_key, base_intensity)]
1732
-
1733
- # Apply regional styles using original image
1734
- result = system.apply_regional_styles(
1735
- st.session_state.regional_image_original, # Use original resolution
1736
- combined_mask,
1737
- st.session_state.regions,
1738
- base_configs,
1739
- regional_blend_mode
1740
- )
1741
-
1742
- st.session_state['regional_result'] = result
1743
-
1744
- # Show result
1745
- if 'regional_result' in st.session_state:
1746
- st.subheader("Result")
1747
- st.image(st.session_state['regional_result'], caption="Regional Styled Image", use_column_width=True)
1748
-
1749
- # Download button
1750
- buf = io.BytesIO()
1751
- st.session_state['regional_result'].save(buf, format='PNG')
1752
- st.download_button(
1753
- label="Download Result",
1754
- data=buf.getvalue(),
1755
- file_name=f"regional_styled_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.png",
1756
- mime="image/png"
1757
- )
1758
-
1759
- # TAB 3: Video Processing
1760
- with tab3:
1761
- st.header("🎬 Video Processing")
1762
-
1763
- if not VIDEO_PROCESSING_AVAILABLE:
1764
- st.warning("""
1765
- ⚠️ Video processing requires OpenCV to be installed.
1766
-
1767
- To enable video processing, add `opencv-python` to your requirements.txt
1768
- """)
1769
- else:
1770
- col1, col2 = st.columns(2)
1771
-
1772
- with col1:
1773
- video_file = st.file_uploader("Upload Video", type=['mp4', 'avi', 'mov'])
1774
-
1775
- if video_file:
1776
- st.video(video_file)
1777
-
1778
- st.subheader("Style Configuration")
1779
 
1780
- # Style selection (up to 2 for videos)
1781
- video_styles = []
1782
- for i in range(2):
1783
- with st.expander(f"Style {i+1}", expanded=(i==0)):
1784
- style = st.selectbox(f"Select style", style_choices, key=f"video_style_{i}")
1785
- intensity = st.slider(f"Intensity", 0.0, 2.0, 1.0, 0.1, key=f"video_intensity_{i}")
1786
- if style and intensity > 0:
1787
- model_key = None
1788
- for key, info in system.cyclegan_models.items():
1789
- if info['name'] == style:
1790
- model_key = key
1791
- break
1792
- if model_key:
1793
- video_styles.append(('cyclegan', model_key, intensity))
1794
 
1795
- video_blend_mode = st.selectbox("Blend Mode",
1796
- ["additive", "average", "maximum", "overlay", "screen"],
1797
- index=0, key="video_blend")
 
 
 
1798
 
1799
- if st.button("Process Video", type="primary", use_container_width=True):
1800
- if video_styles:
1801
- with st.spinner("Processing video..."):
1802
- progress_bar = st.progress(0)
1803
- status_text = st.empty()
1804
-
1805
- def progress_callback(p, msg):
1806
- progress_bar.progress(p)
1807
- status_text.text(msg)
1808
-
1809
- # Save uploaded file temporarily
1810
- temp_input = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4')
1811
- temp_input.write(video_file.read())
1812
- temp_input.close()
1813
-
1814
- # Process video
1815
- output_path = system.video_processor.process_video(
1816
- temp_input.name, video_styles, video_blend_mode, progress_callback
1817
- )
1818
-
1819
- if output_path:
1820
- st.session_state['video_result'] = output_path
1821
-
1822
- # Cleanup
1823
- os.unlink(temp_input.name)
1824
- progress_bar.empty()
1825
- status_text.empty()
1826
-
1827
- with col2:
1828
- st.header("Result")
1829
- if 'video_result' in st.session_state and os.path.exists(st.session_state['video_result']):
1830
- # Try to display video
1831
- try:
1832
- st.video(st.session_state['video_result'])
1833
- except:
1834
- st.warning("Cannot display video in browser. Use download button below.")
1835
 
1836
- # Download button
1837
- with open(st.session_state['video_result'], 'rb') as f:
1838
- st.download_button(
1839
- label="Download Processed Video",
1840
- data=f.read(),
1841
- file_name=f"styled_video_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.mp4",
1842
- mime="video/mp4"
1843
- )
1844
-
1845
- # TAB 4: Training
1846
- with tab4:
1847
- st.header("🔧 Train Custom Style")
1848
- st.markdown("Train your own lightweight style transfer model")
1849
-
1850
- col1, col2 = st.columns(2)
1851
-
1852
- with col1:
1853
- style_img = st.file_uploader("Style Image", type=['png', 'jpg', 'jpeg'], key="train_style")
1854
- content_imgs = st.file_uploader("Content Images (1-50)", type=['png', 'jpg', 'jpeg'],
1855
- accept_multiple_files=True, key="train_content")
1856
-
1857
- if style_img:
1858
- st.image(Image.open(style_img), caption="Style Image", use_column_width=True)
1859
-
1860
- model_name = st.text_input("Model Name", value=f"custom_style_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}")
1861
-
1862
- col_a, col_b = st.columns(2)
1863
- with col_a:
1864
- epochs = st.slider("Training Epochs", 10, 100, 30, 5)
1865
- batch_size = st.slider("Batch Size", 1, 8, 4)
1866
- with col_b:
1867
- learning_rate = st.number_input("Learning Rate", 0.0001, 0.01, 0.001, format="%.4f")
1868
- save_interval = st.slider("Save Checkpoint Every N Epochs", 5, 20, 5, 5)
1869
-
1870
- with st.expander("Advanced Settings"):
1871
- style_weight = st.number_input("Style Weight", 1e3, 1e6, 1e5, step=1e3, format="%.0f")
1872
- content_weight = st.number_input("Content Weight", 0.1, 10.0, 1.0, 0.1)
1873
- res_blocks = st.slider("Residual Blocks", 3, 12, 5)
1874
-
1875
- if st.button("Start Training", type="primary", use_container_width=True):
1876
- if style_img and content_imgs:
1877
- with st.spinner("Training..."):
1878
- progress_bar = st.progress(0)
1879
- status_text = st.empty()
1880
-
1881
- def progress_callback(p, msg):
1882
- progress_bar.progress(p)
1883
- status_text.text(msg)
1884
-
1885
- # Create temp directory for content images
1886
- temp_content_dir = f'/tmp/content_images_{uuid.uuid4().hex}'
1887
- os.makedirs(temp_content_dir, exist_ok=True)
1888
-
1889
- # Save content images
1890
- for idx, img_file in enumerate(content_imgs):
1891
- img = Image.open(img_file).convert('RGB')
1892
- img.save(os.path.join(temp_content_dir, f'content_{idx}.jpg'))
1893
-
1894
- # Train model
1895
- style_image = Image.open(style_img).convert('RGB')
1896
- model = system.train_lightweight_model(
1897
- style_image, temp_content_dir, model_name,
1898
- epochs=epochs, lr=learning_rate, batch_size=batch_size,
1899
- save_interval=save_interval, style_weight=style_weight,
1900
- content_weight=content_weight, n_residual_blocks=res_blocks,
1901
- progress_callback=progress_callback
1902
- )
1903
-
1904
- # Cleanup
1905
- shutil.rmtree(temp_content_dir)
1906
-
1907
- if model:
1908
- st.session_state['trained_model'] = model
1909
- st.session_state['model_path'] = f'/tmp/trained_models/{model_name}_final.pth'
1910
- st.success("Training complete!")
1911
-
1912
- progress_bar.empty()
1913
- status_text.empty()
1914
-
1915
- with col2:
1916
- if 'trained_model' in st.session_state:
1917
- st.header("Test Your Model")
1918
- test_img = st.file_uploader("Test Image", type=['png', 'jpg', 'jpeg'], key="test_trained")
1919
-
1920
- if test_img:
1921
- test_image = Image.open(test_img).convert('RGB')
1922
- col_before, col_after = st.columns(2)
1923
 
1924
- with col_before:
1925
- st.image(test_image, caption="Original", use_column_width=True)
1926
 
1927
- with col_after:
1928
- result = system.apply_lightweight_style(test_image, st.session_state['trained_model'])
1929
- if result:
1930
- st.image(result, caption="Styled", use_column_width=True)
1931
-
1932
- # Download model
1933
- if 'model_path' in st.session_state and os.path.exists(st.session_state['model_path']):
1934
- with open(st.session_state['model_path'], 'rb') as f:
1935
- st.download_button(
1936
- label="Download Trained Model",
1937
- data=f.read(),
1938
- file_name=f"{model_name}_final.pth",
1939
- mime="application/octet-stream"
1940
- )
1941
-
1942
- # TAB 5: Batch Processing
1943
- with tab5:
1944
- st.header("📦 Batch Processing")
1945
-
1946
- col1, col2 = st.columns(2)
1947
-
1948
- with col1:
1949
- # Image source selection for batch
1950
- batch_source = st.radio("Image Source", ["Upload Multiple", "Use Current Unsplash Image"], horizontal=True, key="batch_source")
1951
-
1952
- batch_files = []
1953
- if batch_source == "Upload Multiple":
1954
- batch_files = st.file_uploader("Upload Images", type=['png', 'jpg', 'jpeg'],
1955
- accept_multiple_files=True, key="batch_upload")
1956
- else:
1957
- # Use current Unsplash image if available
1958
- if 'current_image' in st.session_state and st.session_state.get('image_source', '').startswith('Unsplash'):
1959
- batch_files = [st.session_state['current_image']]
1960
- st.success("Using current Unsplash image for batch processing")
1961
- else:
1962
- st.info("Please search and select an image from the Style Transfer tab first")
1963
-
1964
- processing_type = st.radio("Processing Type", ["CycleGAN", "Custom Trained Model"])
1965
-
1966
- if processing_type == "CycleGAN":
1967
- # Style configuration
1968
- batch_styles = []
1969
- for i in range(3):
1970
- with st.expander(f"Style {i+1}", expanded=(i==0)):
1971
- style = st.selectbox(f"Select style", style_choices, key=f"batch_style_{i}")
1972
- intensity = st.slider(f"Intensity", 0.0, 2.0, 1.0, 0.1, key=f"batch_intensity_{i}")
1973
- if style and intensity > 0:
1974
- model_key = None
1975
- for key, info in system.cyclegan_models.items():
1976
- if info['name'] == style:
1977
- model_key = key
1978
- break
1979
- if model_key:
1980
- batch_styles.append(('cyclegan', model_key, intensity))
1981
 
1982
- batch_blend_mode = st.selectbox("Blend Mode",
1983
- ["additive", "average", "maximum", "overlay", "screen"],
1984
- index=0, key="batch_blend")
1985
- else:
1986
- # Custom model upload
1987
- custom_model_file = st.file_uploader("Upload Trained Model (.pth)", type=['pth'])
1988
-
1989
- if st.button("Process Batch", type="primary", use_container_width=True):
1990
- if batch_files:
1991
- with st.spinner("Processing batch..."):
1992
- progress_bar = st.progress(0)
1993
- processed_images = []
1994
-
1995
- if processing_type == "CycleGAN" and batch_styles:
1996
- for idx, file in enumerate(batch_files):
1997
- progress_bar.progress((idx + 1) / len(batch_files))
1998
- # Handle both file uploads and PIL images
1999
- if isinstance(file, Image.Image):
2000
- image = file
2001
- else:
2002
- image = Image.open(file).convert('RGB')
2003
- result = system.blend_styles(image, batch_styles, batch_blend_mode)
2004
- processed_images.append(result)
2005
-
2006
- elif processing_type == "Custom Trained Model" and custom_model_file:
2007
- # Load custom model
2008
- temp_model = tempfile.NamedTemporaryFile(delete=False, suffix='.pth')
2009
- temp_model.write(custom_model_file.read())
2010
- temp_model.close()
2011
-
2012
- model = system.load_lightweight_model(temp_model.name)
2013
-
2014
- if model:
2015
- for idx, file in enumerate(batch_files):
2016
- progress_bar.progress((idx + 1) / len(batch_files))
2017
- # Handle both file uploads and PIL images
2018
- if isinstance(file, Image.Image):
2019
- image = file
2020
- else:
2021
- image = Image.open(file).convert('RGB')
2022
- result = system.apply_lightweight_style(image, model)
2023
- if result:
2024
- processed_images.append(result)
2025
-
2026
- os.unlink(temp_model.name)
2027
-
2028
- if processed_images:
2029
- # Create zip
2030
- zip_buffer = io.BytesIO()
2031
- with zipfile.ZipFile(zip_buffer, 'w') as zf:
2032
- for idx, img in enumerate(processed_images):
2033
- img_buffer = io.BytesIO()
2034
- img.save(img_buffer, format='PNG')
2035
- zf.writestr(f"styled_{idx+1:03d}.png", img_buffer.getvalue())
2036
-
2037
- st.session_state['batch_results'] = processed_images
2038
- st.session_state['batch_zip'] = zip_buffer.getvalue()
2039
-
2040
- progress_bar.empty()
2041
-
2042
- with col2:
2043
- if 'batch_results' in st.session_state:
2044
- st.header("Results")
2045
-
2046
- # Show gallery
2047
- cols = st.columns(4)
2048
- for idx, img in enumerate(st.session_state['batch_results'][:8]):
2049
- cols[idx % 4].image(img, use_column_width=True)
2050
-
2051
- if len(st.session_state['batch_results']) > 8:
2052
- st.info(f"Showing 8 of {len(st.session_state['batch_results'])} processed images")
2053
-
2054
- # Download zip
2055
- st.download_button(
2056
- label="Download All (ZIP)",
2057
- data=st.session_state['batch_zip'],
2058
- file_name=f"batch_styled_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.zip",
2059
- mime="application/zip"
2060
- )
2061
-
2062
- # TAB 6: Documentation
2063
- with tab6:
2064
- st.markdown(f"""
2065
- ## Style Transfer System Documentation
2066
-
2067
- ### Available CycleGAN Models
2068
-
2069
- This system includes pre-trained bidirectional CycleGAN models:
2070
- {chr(10).join([f'- **{info["name"]}**' for key, info in sorted(system.cyclegan_models.items(), key=lambda item: item[1]["name"])])}
2071
-
2072
- ### Features
2073
-
2074
- #### 🎨 Style Transfer
2075
- - Apply multiple styles simultaneously
2076
- - Adjustable intensity for each style
2077
- - Multiple blending modes for creative effects
2078
- - **NEW**: Search and use images from Unsplash
2079
-
2080
- #### 🖌️ Regional Transform
2081
- - Paint specific regions to apply different styles
2082
- - Support for multiple regions with different styles
2083
- - Adjustable brush size and drawing tools
2084
- - Base style + regional overlays
2085
- - Persistent brush strokes across regions
2086
- - Optimized display for large images
2087
-
2088
- #### 🎬 Video Processing
2089
- - Frame-by-frame style transfer
2090
- - Maintains temporal consistency
2091
- - Supports all style combinations and blend modes
2092
- - Enhanced codec compatibility
2093
-
2094
- #### 🔧 Custom Training
2095
- - Train on any artistic style with minimal data (1-50 images)
2096
- - Automatic data augmentation for small datasets
2097
- - Adjustable model complexity (3-12 residual blocks)
2098
-
2099
- ### Model Architecture
2100
-
2101
- - **CycleGAN models**: 9-12 residual blocks for high-quality transformations
2102
- - **Lightweight models**: 3-12 residual blocks (customizable during training)
2103
- - **Training approach**: Unpaired image-to-image translation
2104
-
2105
- ### Technical Details
2106
-
2107
- - **Framework**: PyTorch
2108
- - **GPU Support**: CUDA acceleration when available
2109
- - **Image Formats**: JPG, PNG, BMP
2110
- - **Video Formats**: MP4, AVI, MOV
2111
- - **Model Size**: ~45MB (CycleGAN), 5-15MB (Lightweight)
2112
-
2113
- ### Unsplash Integration
2114
-
2115
- To use Unsplash image search:
2116
- 1. Get a free API key from [Unsplash Developers](https://unsplash.com/developers)
2117
- 2. Add it to your HuggingFace Space secrets as `UNSPLASH_ACCESS_KEY`
2118
- 3. Search for images directly in the app
2119
- 4. Automatic attribution for photographers
2120
-
2121
- ### Usage Tips
2122
-
2123
- 1. **For best results**: Use high-quality input images
2124
- 2. **Style intensity**: Start with 1.0, adjust to taste
2125
- 3. **Blending modes**:
2126
- - 'Additive' for bold effects
2127
- - 'Average' for subtle blends
2128
- - 'Overlay' for dramatic contrasts
2129
- 4. **Regional painting**:
2130
- - Use larger brush for smooth transitions
2131
- - Multiple thin layers work better than one thick layer
2132
- - Previous regions remain visible as you paint new ones
2133
- 5. **Custom training**: More diverse content images = better generalization
2134
- 6. **Video processing**: Keep videos under 30 seconds for faster processing
2135
-
2136
- ### Regional Transform Guide
2137
-
2138
- The regional transform feature allows you to:
2139
- 1. Define multiple regions by painting on the canvas
2140
- 2. Assign different styles to each region
2141
- 3. Control intensity per region
2142
- 4. Apply an optional base style to the entire image
2143
- 5. Blend regions using various modes
2144
-
2145
- **Tips for Regional Transform:**
2146
- - Start with a base style for overall coherence
2147
- - Use semi-transparent brushes for smoother transitions
2148
- - Overlap regions for interesting blend effects
2149
- - Experiment with different blend modes per region
2150
- - All regions are visible while painting for better control
2151
- """)
2152
-
2153
- # Footer
2154
- st.markdown("---")
2155
- st.markdown("Professional style transfer system with state-of-the-art CycleGAN models and regional painting capabilities.")
 
1
+ def process_video(self, video_path, style_configs, blend_mode, progress_callback=None):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  """Process a video file with style transfer"""
3
  if not VIDEO_PROCESSING_AVAILABLE:
4
  print("Video processing requires OpenCV (cv2) - please install it")
 
16
  height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
17
  total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
18
 
19
+ # Create temporary output file - always use mp4 for web compatibility
20
  temp_output = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False)
21
  temp_output.close() # Close so OpenCV can write
22
 
23
+ # Use mp4v codec which has better compatibility
24
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v')
25
+ out = cv2.VideoWriter(temp_output.name, fourcc, fps, (width, height))
 
26
 
27
+ if not out.isOpened():
28
+ # Try H264 as fallback
29
+ fourcc = cv2.VideoWriter_fourcc(*'H264')
30
+ out = cv2.VideoWriter(temp_output.name, fourcc, fps, (width, height))
 
 
 
 
 
 
31
 
32
+ if not out.isOpened():
33
+ # Last resort - use system default
34
+ fourcc = cv2.VideoWriter_fourcc(*'XVID')
 
 
35
  out = cv2.VideoWriter(temp_output.name, fourcc, fps, (width, height))
36
 
37
  if not out.isOpened():
 
64
  cap.release()
65
  out.release()
66
 
67
+ # Always ensure the output is a proper MP4 file
68
+ # OpenCV sometimes creates files that aren't properly formatted for web
69
+ final_output = tempfile.NamedTemporaryFile(suffix='_final.mp4', delete=False)
70
+ final_output.close()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
 
72
+ # Use ffmpeg-python if available, or try OpenCV re-encoding
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
  try:
74
+ # Try to re-encode with OpenCV for better compatibility
75
+ cap = cv2.VideoCapture(temp_output.name)
76
+
77
+ # Get the best codec for web compatibility
78
+ # Try codecs in order of compatibility
79
+ web_codecs = [
80
+ cv2.VideoWriter_fourcc(*'avc1'), # H.264 variant
81
+ cv2.VideoWriter_fourcc(*'H264'), # H.264
82
+ cv2.VideoWriter_fourcc(*'mp4v'), # MPEG-4
83
+ ]
84
+
85
+ out = None
86
+ for codec in web_codecs:
87
+ try:
88
+ test_out = cv2.VideoWriter(final_output.name, codec, fps, (width, height))
89
+ if test_out.isOpened():
90
+ out = test_out
91
+ print(f"Using web-compatible codec: {codec}")
92
+ break
93
+ else:
94
+ test_out.release()
95
+ except:
96
+ continue
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
 
98
+ if out is None:
99
+ # If no web codec works, use the original file
100
+ cap.release()
101
+ os.unlink(final_output.name)
102
+ return temp_output.name
 
 
 
 
 
 
 
 
 
103
 
104
+ # Re-encode the video
105
+ while True:
106
+ ret, frame = cap.read()
107
+ if not ret:
108
+ break
109
+ out.write(frame)
110
 
111
+ cap.release()
112
+ out.release()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
113
 
114
+ # Clean up temp file
115
+ os.unlink(temp_output.name)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
116
 
117
+ return final_output.name
 
118
 
119
+ except Exception as e:
120
+ print(f"Re-encoding failed: {e}, using original file")
121
+ os.unlink(final_output.name)
122
+ return temp_output.name
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
123
 
124
+ except Exception as e:
125
+ print(f"Error processing video: {e}")
126
+ traceback.print_exc()
127
+ return None