dannyroxas commited on
Commit
4591dfb
·
verified ·
1 Parent(s): 578b277

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +445 -125
app.py CHANGED
@@ -46,7 +46,7 @@ warnings.filterwarnings("ignore")
46
  # Set page config
47
  st.set_page_config(
48
  page_title="Style Transfer Studio",
49
- # page_icon="🎨",
50
  layout="wide",
51
  initial_sidebar_state="expanded"
52
  )
@@ -1344,6 +1344,56 @@ class StyleTransferSystem:
1344
  return model
1345
  except:
1346
  return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1347
 
1348
  def apply_lightweight_style(self, image, model, intensity=1.0):
1349
  """Apply style using a lightweight model"""
@@ -1654,8 +1704,18 @@ class StyleTransferSystem:
1654
 
1655
  return model
1656
 
1657
- def apply_adain_style(self, content_image, style_image, model, alpha=1.0):
1658
- """Apply AdaIN-based style transfer with better quality preservation"""
 
 
 
 
 
 
 
 
 
 
1659
  if content_image is None or style_image is None or model is None:
1660
  return None
1661
 
@@ -1665,32 +1725,9 @@ class StyleTransferSystem:
1665
 
1666
  original_size = content_image.size
1667
 
1668
- # Determine processing size based on GPU memory
1669
- if self.device.type == 'cuda':
1670
- # Try to use higher resolution on GPU
1671
- max_dimension = 768 # Increased from 256
1672
-
1673
- # Calculate memory-efficient size
1674
- aspect_ratio = content_image.width / content_image.height
1675
- if content_image.width > content_image.height:
1676
- process_width = min(max_dimension, content_image.width)
1677
- process_height = int(process_width / aspect_ratio)
1678
- else:
1679
- process_height = min(max_dimension, content_image.height)
1680
- process_width = int(process_height * aspect_ratio)
1681
-
1682
- # Round to multiples of 32 for better GPU efficiency
1683
- process_width = ((process_width + 31) // 32) * 32
1684
- process_height = ((process_height + 31) // 32) * 32
1685
- else:
1686
- # Lower resolution for CPU
1687
- process_width, process_height = 512, 512
1688
-
1689
- print(f"Processing at {process_width}x{process_height} (original: {original_size})")
1690
-
1691
- # Transform without aggressive cropping
1692
  transform = transforms.Compose([
1693
- transforms.Resize((process_height, process_width)),
1694
  transforms.ToTensor(),
1695
  transforms.Normalize(mean=[0.485, 0.456, 0.406],
1696
  std=[0.229, 0.224, 0.225])
@@ -1700,71 +1737,117 @@ class StyleTransferSystem:
1700
  style_tensor = transform(style_image).unsqueeze(0).to(self.device)
1701
 
1702
  with torch.no_grad():
1703
- # Process in chunks if image is very large
1704
- if process_width * process_height > 512 * 512 and self.device.type == 'cuda':
1705
- # Clear cache before processing large image
1706
- torch.cuda.empty_cache()
1707
-
1708
  output = model(content_tensor, style_tensor, alpha=alpha)
1709
 
1710
- # Improved denormalization
1711
  output = output.squeeze(0).cpu()
1712
-
1713
- # More precise denormalization
1714
- denorm_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
1715
- denorm_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
1716
- output = output * denorm_std + denorm_mean
1717
-
1718
- # Ensure values are in valid range
1719
  output = torch.clamp(output, 0, 1)
1720
 
1721
- # Convert to PIL with better quality
1722
  output_img = transforms.ToPILImage()(output)
1723
-
1724
- # Use LANCZOS for high-quality resizing
1725
- if output_img.size != original_size:
1726
- output_img = output_img.resize(original_size, Image.LANCZOS)
1727
-
1728
- # Optional: Sharpen slightly to compensate for softness
1729
- if hasattr(Image, 'UnsharpMask'):
1730
- from PIL import ImageFilter
1731
- output_img = output_img.filter(ImageFilter.UnsharpMask(radius=1, percent=50, threshold=0))
1732
 
1733
  return output_img
1734
 
1735
- except RuntimeError as e:
1736
- if "out of memory" in str(e):
1737
- print("GPU out of memory, falling back to lower resolution...")
1738
- torch.cuda.empty_cache()
1739
-
1740
- # Fallback to lower resolution
1741
- transform = transforms.Compose([
1742
- transforms.Resize((384, 384)), # Still better than 256
1743
- transforms.ToTensor(),
1744
- transforms.Normalize(mean=[0.485, 0.456, 0.406],
1745
- std=[0.229, 0.224, 0.225])
1746
- ])
1747
-
1748
- content_tensor = transform(content_image).unsqueeze(0).to(self.device)
1749
- style_tensor = transform(style_image).unsqueeze(0).to(self.device)
1750
-
1751
- with torch.no_grad():
1752
- output = model(content_tensor, style_tensor, alpha=alpha)
1753
- output = output.squeeze(0).cpu()
1754
-
1755
- denorm_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
1756
- denorm_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
1757
- output = output * denorm_std + denorm_mean
1758
- output = torch.clamp(output, 0, 1)
1759
-
1760
- output_img = transforms.ToPILImage()(output)
1761
- output_img = output_img.resize(original_size, Image.LANCZOS)
1762
-
1763
- return output_img
1764
- else:
1765
- print(f"Error applying AdaIN style: {e}")
1766
- traceback.print_exc()
1767
- return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1768
 
1769
 
1770
  # ===========================
@@ -1801,6 +1884,64 @@ def combine_region_masks(canvas_results, canvas_size):
1801
 
1802
  return combined_mask
1803
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1804
  # ===========================
1805
  # INITIALIZE SYSTEM AND API
1806
  # ===========================
@@ -1866,7 +2007,7 @@ with st.sidebar:
1866
  st.caption("For faster processing, use a GPU-enabled environment")
1867
 
1868
  st.markdown("---")
1869
- st.markdown("### 📚 Quick Guide")
1870
  st.markdown("""
1871
  - **Style Transfer**: Apply artistic styles to images
1872
  - **Regional Transform**: Paint areas for local effects
@@ -2136,7 +2277,7 @@ with tab2:
2136
  # Region management
2137
  col_btn1, col_btn2, col_btn3 = st.columns(3)
2138
  with col_btn1:
2139
- if st.button("Add Region", use_container_width=True):
2140
  new_region = {
2141
  'id': len(st.session_state.regions),
2142
  'style': style_choices[0] if style_choices else None,
@@ -2184,7 +2325,7 @@ with tab2:
2184
  key=f"region_intensity_{i}"
2185
  )
2186
 
2187
- if st.button(f"🗑️ Remove Region {i+1}", key=f"remove_region_{i}"):
2188
  # Remove the region
2189
  st.session_state.regions.pop(i)
2190
 
@@ -2216,9 +2357,9 @@ with tab2:
2216
  # Show workflow status
2217
  if 'regional_result' in st.session_state:
2218
  if st.session_state.canvas_ready:
2219
- st.success("**Edit Mode** - Paint your regions and click 'Apply Regional Styles' when ready")
2220
  else:
2221
- st.info("**Preview Mode** - Click 'Continue Editing' to modify regions")
2222
  else:
2223
  st.info("Paint on the canvas below to define regions for each style")
2224
 
@@ -2246,7 +2387,6 @@ with tab2:
2246
 
2247
  current_region = st.session_state.regions[current_region_idx]
2248
 
2249
- # THIS IS THE FIX: The following line was added.
2250
  col_draw1, col_draw2, col_draw3 = st.columns(3)
2251
 
2252
  with col_draw1:
@@ -2347,10 +2487,23 @@ with tab2:
2347
 
2348
  st.session_state['regional_result'] = result
2349
 
2350
- # Show result
2351
  if 'regional_result' in st.session_state:
2352
  st.subheader("Result")
2353
- st.image(st.session_state['regional_result'], caption="Regional Styled Image", use_column_width=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
2354
 
2355
  # Download button
2356
  buf = io.BytesIO()
@@ -2361,10 +2514,7 @@ with tab2:
2361
  file_name=f"regional_styled_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.png",
2362
  mime="image/png"
2363
  )
2364
-
2365
- # TAB 3: Video Processing
2366
- # TAB 3: Video Processing
2367
- # TAB 3: Video Processing
2368
  # TAB 3: Video Processing
2369
  with tab3:
2370
  st.header("Video Processing")
@@ -2555,7 +2705,7 @@ with tab3:
2555
 
2556
 
2557
 
2558
- # TAB 4: Training with AdaIN
2559
  with tab4:
2560
  st.header("Train Custom Style with AdaIN")
2561
  st.markdown("Train your own style transfer model using Adaptive Instance Normalization")
@@ -2563,6 +2713,10 @@ with tab4:
2563
  # Initialize session state for content images
2564
  if 'content_images_list' not in st.session_state:
2565
  st.session_state.content_images_list = []
 
 
 
 
2566
 
2567
  col1, col2, col3 = st.columns([1, 1, 1])
2568
 
@@ -2603,7 +2757,7 @@ with tab4:
2603
  st.caption(f"... and {len(content_imgs) - 3} more")
2604
 
2605
  with col3:
2606
- st.subheader("⚙️ Training Settings")
2607
 
2608
  model_name = st.text_input("Model Name",
2609
  value=f"adain_style_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}")
@@ -2664,18 +2818,24 @@ with tab4:
2664
  st.session_state['trained_adain_model'] = model
2665
  st.session_state['trained_style_images'] = style_images
2666
  st.session_state['model_path'] = f'/tmp/trained_models/{model_name}_final.pth'
2667
- st.success("AdaIN training complete!")
2668
 
2669
  progress_bar.empty()
2670
  status_text.empty()
2671
  else:
2672
  st.error("Please upload both style and content images")
2673
 
2674
- # Testing section
2675
  if 'trained_adain_model' in st.session_state:
2676
  st.markdown("---")
2677
  st.header("Test Your AdaIN Model")
2678
 
 
 
 
 
 
 
2679
  test_col1, test_col2, test_col3 = st.columns([1, 1, 1])
2680
 
2681
  with test_col1:
@@ -2683,7 +2843,7 @@ with tab4:
2683
 
2684
  # Test image selection
2685
  test_source = st.radio("Test Image Source",
2686
- ["Use Content Image", "Upload New"],
2687
  horizontal=True)
2688
 
2689
  test_image = None
@@ -2693,6 +2853,13 @@ with tab4:
2693
  range(len(st.session_state.content_images_list)),
2694
  format_func=lambda x: f"Content Image {x+1}")
2695
  test_image = Image.open(st.session_state.content_images_list[content_idx]).convert('RGB')
 
 
 
 
 
 
 
2696
  else:
2697
  # Upload new image
2698
  test_upload = st.file_uploader("Upload test image",
@@ -2701,6 +2868,10 @@ with tab4:
2701
  if test_upload:
2702
  test_image = Image.open(test_upload).convert('RGB')
2703
 
 
 
 
 
2704
  # Style selection for testing
2705
  if 'trained_style_images' in st.session_state and len(st.session_state['trained_style_images']) > 1:
2706
  style_idx = st.selectbox("Select style",
@@ -2716,36 +2887,125 @@ with tab4:
2716
  # Alpha blending control
2717
  alpha = st.slider("Style Strength (Alpha)", 0.0, 1.0, 1.0, 0.1,
2718
  help="0 = original content, 1 = full style transfer")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2719
 
2720
  with test_col2:
2721
- st.subheader("Original")
2722
- if test_image:
2723
- st.image(test_image, caption="Content Image", use_column_width=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2724
  if test_style:
 
2725
  st.image(test_style, caption="Style Image", use_column_width=True)
 
 
 
 
 
 
 
2726
 
2727
  with test_col3:
2728
  st.subheader("Result")
2729
- if test_image and test_style:
 
 
 
 
2730
  with st.spinner("Applying style..."):
2731
- result = system.apply_adain_style(
2732
- test_image,
2733
- test_style,
2734
- st.session_state['trained_adain_model'],
2735
- alpha=alpha
2736
- )
2737
- if result:
2738
- st.image(result, caption="Styled Result", use_column_width=True)
2739
-
2740
- # Download button
2741
- buf = io.BytesIO()
2742
- result.save(buf, format='PNG')
2743
- st.download_button(
2744
- label="📥 Download Result",
2745
- data=buf.getvalue(),
2746
- file_name=f"adain_styled_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.png",
2747
- mime="image/png"
 
 
2748
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2749
 
2750
  # Model download section
2751
  st.markdown("---")
@@ -2756,13 +3016,73 @@ with tab4:
2756
  st.download_button(
2757
  label="Download Trained AdaIN Model",
2758
  data=f.read(),
2759
- file_name=f"{model_name}_adain_final.pth",
2760
  mime="application/octet-stream",
2761
  use_container_width=True
2762
  )
2763
  with col_dl2:
2764
  st.info("This model can be loaded and used for real-time style transfer")
2765
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2766
  # TAB 5: Batch Processing
2767
  with tab5:
2768
  st.header("Batch Processing")
@@ -2915,7 +3235,7 @@ with tab6:
2915
  - Supports all style combinations and blend modes
2916
  - Enhanced codec compatibility
2917
 
2918
- #### 🔧 Custom Training
2919
  - Train on any artistic style with minimal data (1-50 images)
2920
  - Automatic data augmentation for small datasets
2921
  - Adjustable model complexity (3-12 residual blocks)
@@ -2976,4 +3296,4 @@ with tab6:
2976
 
2977
  # Footer
2978
  st.markdown("---")
2979
- st.markdown("Style transfer system with state-of-the-art CycleGAN models and regional painting capabilities.")
 
46
  # Set page config
47
  st.set_page_config(
48
  page_title="Style Transfer Studio",
49
+ # page_icon="",
50
  layout="wide",
51
  initial_sidebar_state="expanded"
52
  )
 
1344
  return model
1345
  except:
1346
  return None
1347
+ # Inside the StyleTransferSystem class, add these methods:
1348
+
1349
+ def _create_linear_weight(self, width, height, overlap):
1350
+ """Create linear blending weights for tile edges"""
1351
+ weight = np.ones((height, width, 1), dtype=np.float32)
1352
+
1353
+ if overlap > 0:
1354
+ # Create gradients for each edge
1355
+ for i in range(overlap):
1356
+ alpha = i / overlap
1357
+ # Top edge
1358
+ weight[i, :] *= alpha
1359
+ # Bottom edge
1360
+ weight[-i-1, :] *= alpha
1361
+ # Left edge
1362
+ weight[:, i] *= alpha
1363
+ # Right edge
1364
+ weight[:, -i-1] *= alpha
1365
+
1366
+ return weight
1367
+
1368
+ def _create_gaussian_weight(self, width, height, overlap):
1369
+ """Create Gaussian blending weights for smoother transitions"""
1370
+ weight = np.ones((height, width), dtype=np.float32)
1371
+
1372
+ # Create 2D Gaussian centered in the tile
1373
+ y, x = np.ogrid[:height, :width]
1374
+ center_y, center_x = height / 2, width / 2
1375
+
1376
+ # Distance from center
1377
+ dist_from_center = np.sqrt((x - center_x)**2 + (y - center_y)**2)
1378
+
1379
+ # Gaussian falloff starting from the edges
1380
+ max_dist = min(height, width) / 2
1381
+ sigma = max_dist / 2 # Adjust for smoother/sharper transitions
1382
+
1383
+ # Apply Gaussian only near edges
1384
+ edge_dist = np.minimum(
1385
+ np.minimum(y, height - 1 - y),
1386
+ np.minimum(x, width - 1 - x)
1387
+ )
1388
+
1389
+ # Weight is 1 in center, Gaussian falloff near edges
1390
+ weight = np.where(
1391
+ edge_dist < overlap,
1392
+ np.exp(-0.5 * ((overlap - edge_dist) / (overlap/3))**2),
1393
+ 1.0
1394
+ )
1395
+
1396
+ return weight.reshape(height, width, 1)
1397
 
1398
  def apply_lightweight_style(self, image, model, intensity=1.0):
1399
  """Apply style using a lightweight model"""
 
1704
 
1705
  return model
1706
 
1707
+ def apply_adain_style(self, content_image, style_image, model, alpha=1.0, use_tiling=False):
1708
+ """Apply AdaIN-based style transfer with optional tiling"""
1709
+ if use_tiling and (content_image.width > 512 or content_image.height > 512):
1710
+ # Use tiling for large images
1711
+ return self.apply_adain_style_tiled(
1712
+ content_image, style_image, model, alpha,
1713
+ tile_size=256, # Match training size
1714
+ overlap=32,
1715
+ blend_mode='gaussian'
1716
+ )
1717
+
1718
+ # Original implementation for small images
1719
  if content_image is None or style_image is None or model is None:
1720
  return None
1721
 
 
1725
 
1726
  original_size = content_image.size
1727
 
1728
+ # Transform for AdaIN (VGG normalization)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1729
  transform = transforms.Compose([
1730
+ transforms.Resize((256, 256)), # Direct resize, no cropping
1731
  transforms.ToTensor(),
1732
  transforms.Normalize(mean=[0.485, 0.456, 0.406],
1733
  std=[0.229, 0.224, 0.225])
 
1737
  style_tensor = transform(style_image).unsqueeze(0).to(self.device)
1738
 
1739
  with torch.no_grad():
 
 
 
 
 
1740
  output = model(content_tensor, style_tensor, alpha=alpha)
1741
 
1742
+ # Denormalize
1743
  output = output.squeeze(0).cpu()
1744
+ output = output * torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
1745
+ output = output + torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
 
 
 
 
 
1746
  output = torch.clamp(output, 0, 1)
1747
 
1748
+ # Convert to PIL
1749
  output_img = transforms.ToPILImage()(output)
1750
+ output_img = output_img.resize(original_size, Image.LANCZOS)
 
 
 
 
 
 
 
 
1751
 
1752
  return output_img
1753
 
1754
+ except Exception as e:
1755
+ print(f"Error applying AdaIN style: {e}")
1756
+ traceback.print_exc()
1757
+ return None
1758
+
1759
+ def apply_adain_style_tiled(self, content_image, style_image, model, alpha=1.0,
1760
+ tile_size=256, overlap=32, blend_mode='linear'):
1761
+ """
1762
+ Apply AdaIN style transfer using tiling for high-quality results.
1763
+ Processes image in overlapping tiles to maintain quality.
1764
+ """
1765
+ if content_image is None or style_image is None or model is None:
1766
+ return None
1767
+
1768
+ try:
1769
+ model = model.to(self.device)
1770
+ model.eval()
1771
+
1772
+ # Prepare transforms
1773
+ transform = transforms.Compose([
1774
+ transforms.Resize((tile_size, tile_size)),
1775
+ transforms.ToTensor(),
1776
+ transforms.Normalize(mean=[0.485, 0.456, 0.406],
1777
+ std=[0.229, 0.224, 0.225])
1778
+ ])
1779
+
1780
+ # Process style image once (at tile size)
1781
+ style_tensor = transform(style_image).unsqueeze(0).to(self.device)
1782
+
1783
+ # Get dimensions
1784
+ w, h = content_image.size
1785
+
1786
+ # Calculate tile positions with overlap
1787
+ stride = tile_size - overlap
1788
+ tiles_x = list(range(0, w - tile_size + 1, stride))
1789
+ tiles_y = list(range(0, h - tile_size + 1, stride))
1790
+
1791
+ # Ensure we cover the entire image
1792
+ if tiles_x[-1] + tile_size < w:
1793
+ tiles_x.append(w - tile_size)
1794
+ if tiles_y[-1] + tile_size < h:
1795
+ tiles_y.append(h - tile_size)
1796
+
1797
+ # If image is smaller than tile size, just process normally
1798
+ if w <= tile_size and h <= tile_size:
1799
+ return self.apply_adain_style(content_image, style_image, model, alpha, use_tiling=False)
1800
+
1801
+ print(f"Processing {len(tiles_x) * len(tiles_y)} tiles of size {tile_size}x{tile_size}")
1802
+
1803
+ # Initialize output and weight arrays
1804
+ output_array = np.zeros((h, w, 3), dtype=np.float32)
1805
+ weight_array = np.zeros((h, w, 1), dtype=np.float32)
1806
+
1807
+ # Process each tile
1808
+ with torch.no_grad():
1809
+ for y_idx, y in enumerate(tiles_y):
1810
+ for x_idx, x in enumerate(tiles_x):
1811
+ # Extract tile
1812
+ tile = content_image.crop((x, y, x + tile_size, y + tile_size))
1813
+
1814
+ # Transform tile
1815
+ tile_tensor = transform(tile).unsqueeze(0).to(self.device)
1816
+
1817
+ # Apply AdaIN to tile
1818
+ styled_tensor = model(tile_tensor, style_tensor, alpha=alpha)
1819
+
1820
+ # Denormalize
1821
+ styled_tensor = styled_tensor.squeeze(0).cpu()
1822
+ denorm_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
1823
+ denorm_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
1824
+ styled_tensor = styled_tensor * denorm_std + denorm_mean
1825
+ styled_tensor = torch.clamp(styled_tensor, 0, 1)
1826
+
1827
+ # Convert to numpy
1828
+ styled_tile = styled_tensor.permute(1, 2, 0).numpy() * 255
1829
+
1830
+ # Create weight mask for blending
1831
+ if blend_mode == 'gaussian':
1832
+ weight = self._create_gaussian_weight(tile_size, tile_size, overlap)
1833
+ else:
1834
+ weight = self._create_linear_weight(tile_size, tile_size, overlap)
1835
+
1836
+ # Add to output with weights
1837
+ output_array[y:y+tile_size, x:x+tile_size] += styled_tile * weight
1838
+ weight_array[y:y+tile_size, x:x+tile_size] += weight
1839
+
1840
+ # Normalize by weights
1841
+ output_array = output_array / (weight_array + 1e-8)
1842
+ output_array = np.clip(output_array, 0, 255).astype(np.uint8)
1843
+
1844
+ return Image.fromarray(output_array)
1845
+
1846
+ except Exception as e:
1847
+ print(f"Error in tiled AdaIN processing: {e}")
1848
+ traceback.print_exc()
1849
+ # Fallback to standard processing
1850
+ return self.apply_adain_style(content_image, style_image, model, alpha, use_tiling=False)
1851
 
1852
 
1853
  # ===========================
 
1884
 
1885
  return combined_mask
1886
 
1887
+ def apply_adain_regional(content_image, style_image, model, canvas_result, alpha=1.0, feather_radius=10, use_tiling=False):
1888
+ """Apply AdaIN style transfer to a painted region only"""
1889
+ if content_image is None or style_image is None or model is None:
1890
+ return None
1891
+
1892
+ try:
1893
+ # Get the mask from canvas
1894
+ if canvas_result is None or canvas_result.image_data is None:
1895
+ # No mask painted, apply to whole image
1896
+ return system.apply_adain_style(content_image, style_image, model, alpha, use_tiling=use_tiling)
1897
+
1898
+ # Extract mask from canvas
1899
+ mask_data = canvas_result.image_data[:, :, 3] # Alpha channel
1900
+ mask = mask_data > 0
1901
+
1902
+ # Resize mask to match original image size
1903
+ original_size = content_image.size
1904
+ display_size = (canvas_result.image_data.shape[1], canvas_result.image_data.shape[0])
1905
+
1906
+ if original_size != display_size:
1907
+ # Convert mask to PIL image for resizing
1908
+ mask_pil = Image.fromarray((mask * 255).astype(np.uint8), mode='L')
1909
+ mask_pil = mask_pil.resize(original_size, Image.NEAREST)
1910
+ mask = np.array(mask_pil) > 128
1911
+
1912
+ # Apply feathering to mask edges if requested
1913
+ if feather_radius > 0:
1914
+ from scipy.ndimage import gaussian_filter
1915
+ mask_float = mask.astype(np.float32)
1916
+ mask_float = gaussian_filter(mask_float, sigma=feather_radius)
1917
+ mask_float = np.clip(mask_float, 0, 1)
1918
+ else:
1919
+ mask_float = mask.astype(np.float32)
1920
+
1921
+ # Apply style to entire image with tiling option
1922
+ styled_full = system.apply_adain_style(content_image, style_image, model, alpha, use_tiling=use_tiling)
1923
+
1924
+ if styled_full is None:
1925
+ return None
1926
+
1927
+ # Blend original and styled based on mask
1928
+ original_array = np.array(content_image, dtype=np.float32)
1929
+ styled_array = np.array(styled_full, dtype=np.float32)
1930
+
1931
+ # Expand mask to 3 channels
1932
+ mask_3ch = np.stack([mask_float] * 3, axis=2)
1933
+
1934
+ # Blend
1935
+ result_array = original_array * (1 - mask_3ch) + styled_array * mask_3ch
1936
+ result_array = np.clip(result_array, 0, 255).astype(np.uint8)
1937
+
1938
+ return Image.fromarray(result_array)
1939
+
1940
+ except Exception as e:
1941
+ print(f"Error applying regional AdaIN style: {e}")
1942
+ traceback.print_exc()
1943
+ return None
1944
+
1945
  # ===========================
1946
  # INITIALIZE SYSTEM AND API
1947
  # ===========================
 
2007
  st.caption("For faster processing, use a GPU-enabled environment")
2008
 
2009
  st.markdown("---")
2010
+ st.markdown("### Quick Guide")
2011
  st.markdown("""
2012
  - **Style Transfer**: Apply artistic styles to images
2013
  - **Regional Transform**: Paint areas for local effects
 
2277
  # Region management
2278
  col_btn1, col_btn2, col_btn3 = st.columns(3)
2279
  with col_btn1:
2280
+ if st.button("Add Region", use_container_width=True):
2281
  new_region = {
2282
  'id': len(st.session_state.regions),
2283
  'style': style_choices[0] if style_choices else None,
 
2325
  key=f"region_intensity_{i}"
2326
  )
2327
 
2328
+ if st.button(f"Remove Region {i+1}", key=f"remove_region_{i}"):
2329
  # Remove the region
2330
  st.session_state.regions.pop(i)
2331
 
 
2357
  # Show workflow status
2358
  if 'regional_result' in st.session_state:
2359
  if st.session_state.canvas_ready:
2360
+ st.success("Edit Mode - Paint your regions and click 'Apply Regional Styles' when ready")
2361
  else:
2362
+ st.info("Preview Mode - Click 'Continue Editing' to modify regions")
2363
  else:
2364
  st.info("Paint on the canvas below to define regions for each style")
2365
 
 
2387
 
2388
  current_region = st.session_state.regions[current_region_idx]
2389
 
 
2390
  col_draw1, col_draw2, col_draw3 = st.columns(3)
2391
 
2392
  with col_draw1:
 
2487
 
2488
  st.session_state['regional_result'] = result
2489
 
2490
+ # Show result with fixed size
2491
  if 'regional_result' in st.session_state:
2492
  st.subheader("Result")
2493
+
2494
+ # Add display size control
2495
+ display_size = st.slider("Display Size", 300, 800, 600, 50, key="regional_display_size")
2496
+
2497
+ # Fixed size display
2498
+ result_display = resize_image_for_display(
2499
+ st.session_state['regional_result'],
2500
+ max_width=display_size,
2501
+ max_height=display_size
2502
+ )
2503
+ st.image(result_display, caption="Regional Styled Image")
2504
+
2505
+ # Show actual dimensions
2506
+ st.caption(f"Original size: {st.session_state['regional_result'].size[0]}x{st.session_state['regional_result'].size[1]} pixels")
2507
 
2508
  # Download button
2509
  buf = io.BytesIO()
 
2514
  file_name=f"regional_styled_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.png",
2515
  mime="image/png"
2516
  )
2517
+
 
 
 
2518
  # TAB 3: Video Processing
2519
  with tab3:
2520
  st.header("Video Processing")
 
2705
 
2706
 
2707
 
2708
+ # TAB 4: Training with AdaIN and Regional Application
2709
  with tab4:
2710
  st.header("Train Custom Style with AdaIN")
2711
  st.markdown("Train your own style transfer model using Adaptive Instance Normalization")
 
2713
  # Initialize session state for content images
2714
  if 'content_images_list' not in st.session_state:
2715
  st.session_state.content_images_list = []
2716
+ if 'adain_canvas_result' not in st.session_state:
2717
+ st.session_state.adain_canvas_result = None
2718
+ if 'adain_test_image' not in st.session_state:
2719
+ st.session_state.adain_test_image = None
2720
 
2721
  col1, col2, col3 = st.columns([1, 1, 1])
2722
 
 
2757
  st.caption(f"... and {len(content_imgs) - 3} more")
2758
 
2759
  with col3:
2760
+ st.subheader("Training Settings")
2761
 
2762
  model_name = st.text_input("Model Name",
2763
  value=f"adain_style_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}")
 
2818
  st.session_state['trained_adain_model'] = model
2819
  st.session_state['trained_style_images'] = style_images
2820
  st.session_state['model_path'] = f'/tmp/trained_models/{model_name}_final.pth'
2821
+ st.success("AdaIN training complete")
2822
 
2823
  progress_bar.empty()
2824
  status_text.empty()
2825
  else:
2826
  st.error("Please upload both style and content images")
2827
 
2828
+ # Testing section with regional application
2829
  if 'trained_adain_model' in st.session_state:
2830
  st.markdown("---")
2831
  st.header("Test Your AdaIN Model")
2832
 
2833
+ # Application mode selection
2834
+ application_mode = st.radio("Application Mode",
2835
+ ["Whole Image", "Paint Region"],
2836
+ horizontal=True,
2837
+ help="Choose whether to apply style to entire image or paint specific regions")
2838
+
2839
  test_col1, test_col2, test_col3 = st.columns([1, 1, 1])
2840
 
2841
  with test_col1:
 
2843
 
2844
  # Test image selection
2845
  test_source = st.radio("Test Image Source",
2846
+ ["Use Content Image", "Upload New", "Use Unsplash Image"],
2847
  horizontal=True)
2848
 
2849
  test_image = None
 
2853
  range(len(st.session_state.content_images_list)),
2854
  format_func=lambda x: f"Content Image {x+1}")
2855
  test_image = Image.open(st.session_state.content_images_list[content_idx]).convert('RGB')
2856
+ elif test_source == "Use Unsplash Image":
2857
+ # Use current Unsplash image if available
2858
+ if 'current_image' in st.session_state and st.session_state.get('image_source', '').startswith('Unsplash'):
2859
+ test_image = st.session_state['current_image']
2860
+ st.success("Using Unsplash image")
2861
+ else:
2862
+ st.info("Please search and select an image from the Style Transfer tab first")
2863
  else:
2864
  # Upload new image
2865
  test_upload = st.file_uploader("Upload test image",
 
2868
  if test_upload:
2869
  test_image = Image.open(test_upload).convert('RGB')
2870
 
2871
+ # Store test image in session state
2872
+ if test_image:
2873
+ st.session_state['adain_test_image'] = test_image
2874
+
2875
  # Style selection for testing
2876
  if 'trained_style_images' in st.session_state and len(st.session_state['trained_style_images']) > 1:
2877
  style_idx = st.selectbox("Select style",
 
2887
  # Alpha blending control
2888
  alpha = st.slider("Style Strength (Alpha)", 0.0, 1.0, 1.0, 0.1,
2889
  help="0 = original content, 1 = full style transfer")
2890
+
2891
+ # Add tiling option
2892
+ use_tiling = st.checkbox("🔲 Use Tiled Processing",
2893
+ value=True,
2894
+ help="Process large images in tiles for better quality. Recommended for images larger than 512x512.")
2895
+
2896
+ # Regional painting options (only show if in paint mode)
2897
+ if application_mode == "🖌️ Paint Region":
2898
+ st.markdown("---")
2899
+ st.subheader("🖌️ Painting Options")
2900
+
2901
+ brush_size = st.slider("Brush Size", 5, 100, 30)
2902
+ drawing_mode = st.selectbox("Drawing Tool",
2903
+ ["freedraw", "line", "rect", "circle", "polygon"],
2904
+ index=0)
2905
+
2906
+ # Feather/blur the mask edges
2907
+ feather_radius = st.slider("Edge Softness", 0, 50, 10,
2908
+ help="Blur mask edges for smoother transitions")
2909
+
2910
+ col_btn1, col_btn2 = st.columns(2)
2911
+ with col_btn1:
2912
+ if st.button("Clear Canvas", use_container_width=True):
2913
+ st.session_state['adain_canvas_result'] = None
2914
+ st.rerun()
2915
+
2916
+ with col_btn2:
2917
+ if st.button("Reset Result", use_container_width=True):
2918
+ if 'adain_styled_result' in st.session_state:
2919
+ del st.session_state['adain_styled_result']
2920
+ st.rerun()
2921
 
2922
  with test_col2:
2923
+ st.subheader("Canvas / Original")
2924
+
2925
+ if application_mode == "Paint Region" and test_image:
2926
+ # Show canvas for painting
2927
+ display_img = resize_image_for_display(test_image, max_width=400, max_height=400)
2928
+ canvas_width, canvas_height = display_img.size
2929
+
2930
+ st.info("🖌️ Paint the areas where you want to apply the style")
2931
+
2932
+ # Canvas for painting mask
2933
+ canvas_result = st_canvas(
2934
+ fill_color="rgba(255, 0, 0, 0.3)", # Red with transparency
2935
+ stroke_width=brush_size,
2936
+ stroke_color="rgba(255, 0, 0, 0.5)",
2937
+ background_image=display_img,
2938
+ update_streamlit=True,
2939
+ height=canvas_height,
2940
+ width=canvas_width,
2941
+ drawing_mode=drawing_mode,
2942
+ display_toolbar=True,
2943
+ key=f"adain_canvas_{brush_size}_{drawing_mode}"
2944
+ )
2945
+
2946
+ # Save canvas result
2947
+ if canvas_result:
2948
+ st.session_state['adain_canvas_result'] = canvas_result
2949
+
2950
+ # Show style image below canvas
2951
  if test_style:
2952
+ st.markdown("---")
2953
  st.image(test_style, caption="Style Image", use_column_width=True)
2954
+
2955
+ else:
2956
+ # Show original images
2957
+ if test_image:
2958
+ st.image(test_image, caption="Content Image", use_column_width=True)
2959
+ if test_style:
2960
+ st.image(test_style, caption="Style Image", use_column_width=True)
2961
 
2962
with test_col3:
    st.subheader("Result")

    # Trigger for running the (possibly slow) style transfer.
    apply_button = st.button("Apply Style", type="primary", use_container_width=True)

    if apply_button and test_image and test_style:
        with st.spinner("Applying style..."):
            # NOTE(review): the sibling column compares application_mode against
            # "Paint Region" (no emoji) while this branch originally compared
            # against the emoji label "🖼️ Whole Image" — the two styles cannot
            # both match the same radio options. A substring test matches either
            # label style, so the two checks can no longer disagree.
            if "Whole Image" in application_mode:
                # Stylize the entire content image.
                result = system.apply_adain_style(
                    test_image,
                    test_style,
                    st.session_state['trained_adain_model'],
                    alpha=alpha,
                    use_tiling=use_tiling
                )
            else:
                # Stylize only the region painted on the canvas.
                # NOTE(review): apply_adain_regional is defined further down this
                # script; Streamlit executes top-to-bottom, so the helper should
                # be moved above this tab (as its own comment suggests) to avoid
                # a NameError on the first click — TODO confirm placement.
                result = apply_adain_regional(
                    test_image,
                    test_style,
                    st.session_state['trained_adain_model'],
                    st.session_state.get('adain_canvas_result'),
                    alpha=alpha,
                    feather_radius=feather_radius,
                    use_tiling=use_tiling
                )

            if result:
                st.session_state['adain_styled_result'] = result
            else:
                # Surface failures instead of silently keeping a stale result.
                st.error("Style transfer failed - check the logs for details.")

    # Show the most recent result; session state persists it across reruns.
    if 'adain_styled_result' in st.session_state:
        st.image(st.session_state['adain_styled_result'],
                 caption="Styled Result",
                 use_column_width=True)

        # Offer the styled image as a timestamped PNG download.
        buf = io.BytesIO()
        st.session_state['adain_styled_result'].save(buf, format='PNG')
        st.download_button(
            label="Download Result",
            data=buf.getvalue(),
            file_name=f"adain_styled_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.png",
            mime="image/png"
        )
3009
 
3010
  # Model download section
3011
  st.markdown("---")
 
3016
  st.download_button(
3017
  label="Download Trained AdaIN Model",
3018
  data=f.read(),
3019
+ file_name=f"{st.session_state.get('model_name', 'adain')}_final.pth",
3020
  mime="application/octet-stream",
3021
  use_container_width=True
3022
  )
3023
  with col_dl2:
3024
  st.info("This model can be loaded and used for real-time style transfer")
3025
 
3026
+
3027
+ # Add this helper function (place it before the tab or with other helper functions)
3028
def apply_adain_regional(content_image, style_image, model, canvas_result, alpha=1.0, feather_radius=10, use_tiling=False):
    """Apply AdaIN style transfer only inside a user-painted canvas region.

    Args:
        content_image: PIL image to be partially stylized.
        style_image: PIL image providing the style.
        model: trained AdaIN model, forwarded to system.apply_adain_style.
        canvas_result: streamlit-drawable-canvas result whose RGBA
            ``image_data`` alpha channel encodes the painted mask
            (None / no data -> stylize the whole image).
        alpha: style strength forwarded to the AdaIN model.
        feather_radius: Gaussian sigma in pixels used to soften the mask
            edges; 0 keeps a hard edge.
        use_tiling: forwarded to system.apply_adain_style for large images.

    Returns:
        A PIL.Image with the styled region blended into the original,
        or None on failure or missing inputs.
    """
    if content_image is None or style_image is None or model is None:
        return None

    try:
        # No canvas data at all -> behave like whole-image transfer.
        if canvas_result is None or canvas_result.image_data is None:
            return system.apply_adain_style(content_image, style_image, model, alpha, use_tiling=use_tiling)

        # The painted strokes live in the alpha channel of the RGBA canvas.
        mask = canvas_result.image_data[:, :, 3] > 0

        # Nothing painted: the blend below would reproduce the original
        # anyway, so skip the expensive style pass entirely.
        if not mask.any():
            return content_image

        # The canvas is shown at display resolution; rescale the mask to the
        # full-resolution content image. PIL size is (width, height) while
        # numpy shape is (height, width), hence the swapped indices.
        original_size = content_image.size
        display_size = (canvas_result.image_data.shape[1], canvas_result.image_data.shape[0])
        if original_size != display_size:
            mask_pil = Image.fromarray((mask * 255).astype(np.uint8), mode='L')
            mask_pil = mask_pil.resize(original_size, Image.NEAREST)
            mask = np.array(mask_pil) > 128

        # Optionally feather the mask so the styled region blends smoothly
        # into its surroundings instead of showing a hard seam.
        if feather_radius > 0:
            from scipy.ndimage import gaussian_filter
            mask_float = np.clip(gaussian_filter(mask.astype(np.float32), sigma=feather_radius), 0.0, 1.0)
        else:
            mask_float = mask.astype(np.float32)

        # Style the whole image once, then keep only the masked portion.
        styled_full = system.apply_adain_style(content_image, style_image, model, alpha, use_tiling=use_tiling)
        if styled_full is None:
            return None

        # Work in RGB so both arrays have matching channel counts even when
        # the content is RGBA or grayscale, and force the styled image back
        # to the content resolution in case the model changed it.
        original_array = np.array(content_image.convert('RGB'), dtype=np.float32)
        styled_array = np.array(styled_full.convert('RGB').resize(original_size), dtype=np.float32)

        # Per-pixel linear blend: mask==1 -> styled, mask==0 -> original.
        mask_3ch = mask_float[:, :, None]
        result_array = original_array * (1 - mask_3ch) + styled_array * mask_3ch
        result_array = np.clip(result_array, 0, 255).astype(np.uint8)

        return Image.fromarray(result_array)

    except Exception as e:
        print(f"Error applying regional AdaIN style: {e}")
        traceback.print_exc()
        return None
3085
+
3086
  # TAB 5: Batch Processing
3087
  with tab5:
3088
  st.header("Batch Processing")
 
3235
  - Supports all style combinations and blend modes
3236
  - Enhanced codec compatibility
3237
 
3238
+ #### Custom Training
3239
  - Train on any artistic style with minimal data (1-50 images)
3240
  - Automatic data augmentation for small datasets
3241
  - Adjustable model complexity (3-12 residual blocks)
 
3296
 
3297
  # Footer
3298
  st.markdown("---")
3299
+ st.markdown("Style transfer system with CycleGAN models and regional painting capabilities.")