Daniel Roxas committed on
Commit
d731f27
·
verified ·
1 Parent(s): b7d23c3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +305 -65
app.py CHANGED
@@ -1570,14 +1570,17 @@ class StyleTransferSystem:
1570
  return Image.fromarray(np.clip(result, 0, 255).astype(np.uint8))
1571
 
1572
  def train_adain_model(self, style_images, content_dir, model_name,
1573
- epochs=30, batch_size=4, lr=1e-4,
1574
- save_interval=5, style_weight=10.0, content_weight=1.0,
1575
- progress_callback=None):
1576
  """Train an AdaIN-based style transfer model"""
1577
 
1578
  model = AdaINStyleTransfer().to(self.device)
1579
  optimizer = torch.optim.Adam(model.decoder.parameters(), lr=lr)
1580
 
 
 
 
1581
  print(f"Training AdaIN model")
1582
  print(f"Training device: {self.device}")
1583
 
@@ -1586,24 +1589,32 @@ class StyleTransferSystem:
1586
  print(f"Model on GPU: {next(model.decoder.parameters()).device}")
1587
  print(f"GPU memory before training: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
1588
 
1589
- # Prepare style images
1590
  style_transform = transforms.Compose([
1591
- transforms.Resize(512),
1592
- transforms.RandomCrop(256),
 
1593
  transforms.ToTensor(),
1594
  transforms.Normalize(mean=[0.485, 0.456, 0.406],
1595
  std=[0.229, 0.224, 0.225])
1596
  ])
1597
 
1598
  style_tensors = []
 
1599
  for style_img in style_images:
1600
- style_tensor = style_transform(style_img).unsqueeze(0).to(self.device)
1601
- style_tensors.append(style_tensor)
 
 
 
 
1602
 
1603
- # Prepare content dataset
1604
  content_transform = transforms.Compose([
1605
- transforms.Resize(512),
1606
- transforms.RandomCrop(256),
 
 
1607
  transforms.ToTensor(),
1608
  transforms.Normalize(mean=[0.485, 0.456, 0.406],
1609
  std=[0.229, 0.224, 0.225])
@@ -1617,9 +1628,28 @@ class StyleTransferSystem:
1617
  print(f" - Content images: {len(dataset)}")
1618
  print(f" - Batch size: {batch_size}")
1619
  print(f" - Epochs: {epochs}")
1620
-
1621
- # Loss network (VGG for perceptual loss)
1622
- loss_network = VGGEncoder().to(self.device).eval()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1623
  mse_loss = nn.MSELoss()
1624
 
1625
  # Training loop
@@ -1627,6 +1657,9 @@ class StyleTransferSystem:
1627
  model.encoder.eval() # Keep encoder frozen
1628
  total_steps = 0
1629
 
 
 
 
1630
  for epoch in range(epochs):
1631
  epoch_loss = 0
1632
 
@@ -1643,33 +1676,42 @@ class StyleTransferSystem:
1643
  # Forward pass
1644
  output = model(content_batch, batch_style)
1645
 
1646
- # Content loss
1647
  with torch.no_grad():
1648
- content_feat = loss_network.encode(content_batch)
1649
- output_feat = loss_network.encode(output)
1650
- content_loss = mse_loss(output_feat, content_feat)
1651
 
1652
- # Style loss
1653
- with torch.no_grad():
1654
- style_feat = loss_network.encode(batch_style)
 
 
 
1655
 
1656
- # Compute style loss using Gram matrices
1657
  def gram_matrix(feat):
1658
  b, c, h, w = feat.size()
1659
  feat = feat.view(b, c, h * w)
1660
  gram = torch.bmm(feat, feat.transpose(1, 2))
1661
  return gram / (c * h * w)
1662
 
1663
- output_gram = gram_matrix(output_feat)
1664
- style_gram = gram_matrix(style_feat)
1665
- style_loss = mse_loss(output_gram, style_gram)
 
 
 
1666
 
1667
  # Total loss
1668
- loss = content_weight * content_loss + style_weight * style_loss
1669
 
1670
  # Backward pass
1671
  optimizer.zero_grad()
1672
  loss.backward()
 
 
 
 
1673
  optimizer.step()
1674
 
1675
  epoch_loss += loss.item()
@@ -1678,7 +1720,12 @@ class StyleTransferSystem:
1678
  # Progress callback
1679
  if progress_callback and total_steps % 10 == 0:
1680
  progress = (epoch + (batch_idx + 1) / len(dataloader)) / epochs
1681
- progress_callback(progress, f"Epoch {epoch+1}/{epochs}, Loss: {loss.item():.4f}")
 
 
 
 
 
1682
 
1683
  # Save checkpoint
1684
  if (epoch + 1) % save_interval == 0:
@@ -1687,6 +1734,7 @@ class StyleTransferSystem:
1687
  'epoch': epoch + 1,
1688
  'model_state_dict': model.state_dict(),
1689
  'optimizer_state_dict': optimizer.state_dict(),
 
1690
  'loss': epoch_loss / len(dataloader),
1691
  'model_type': 'adain'
1692
  }, checkpoint_path)
@@ -1704,19 +1752,21 @@ class StyleTransferSystem:
1704
  self.lightweight_models[model_name] = model
1705
 
1706
  return model
 
 
 
1707
 
1708
  def apply_adain_style(self, content_image, style_image, model, alpha=1.0, use_tiling=False):
1709
  """Apply AdaIN-based style transfer with optional tiling"""
1710
- if use_tiling and (content_image.width > 512 or content_image.height > 512):
1711
- # Use tiling for large images
1712
  return self.apply_adain_style_tiled(
1713
  content_image, style_image, model, alpha,
1714
- tile_size=256, # Match training size
1715
- overlap=32,
1716
  blend_mode='gaussian'
1717
  )
1718
 
1719
- # Original implementation for small images
1720
  if content_image is None or style_image is None or model is None:
1721
  return None
1722
 
@@ -1726,9 +1776,23 @@ class StyleTransferSystem:
1726
 
1727
  original_size = content_image.size
1728
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1729
  # Transform for AdaIN (VGG normalization)
1730
  transform = transforms.Compose([
1731
- transforms.Resize((256, 256)), # Direct resize, no cropping
1732
  transforms.ToTensor(),
1733
  transforms.Normalize(mean=[0.485, 0.456, 0.406],
1734
  std=[0.229, 0.224, 0.225])
@@ -1756,7 +1820,7 @@ class StyleTransferSystem:
1756
  print(f"Error applying AdaIN style: {e}")
1757
  traceback.print_exc()
1758
  return None
1759
-
1760
  def apply_adain_style_tiled(self, content_image, style_image, model, alpha=1.0,
1761
  tile_size=256, overlap=32, blend_mode='linear'):
1762
  """
@@ -1770,6 +1834,10 @@ class StyleTransferSystem:
1770
  model = model.to(self.device)
1771
  model.eval()
1772
 
 
 
 
 
1773
  # Prepare transforms
1774
  transform = transforms.Compose([
1775
  transforms.Resize((tile_size, tile_size)),
@@ -1790,10 +1858,10 @@ class StyleTransferSystem:
1790
  tiles_y = list(range(0, h - tile_size + 1, stride))
1791
 
1792
  # Ensure we cover the entire image
1793
- if tiles_x[-1] + tile_size < w:
1794
- tiles_x.append(w - tile_size)
1795
- if tiles_y[-1] + tile_size < h:
1796
- tiles_y.append(h - tile_size)
1797
 
1798
  # If image is smaller than tile size, just process normally
1799
  if w <= tile_size and h <= tile_size:
@@ -1828,11 +1896,8 @@ class StyleTransferSystem:
1828
  # Convert to numpy
1829
  styled_tile = styled_tensor.permute(1, 2, 0).numpy() * 255
1830
 
1831
- # Create weight mask for blending
1832
- if blend_mode == 'gaussian':
1833
- weight = self._create_gaussian_weight(tile_size, tile_size, overlap)
1834
- else:
1835
- weight = self._create_linear_weight(tile_size, tile_size, overlap)
1836
 
1837
  # Add to output with weights
1838
  output_array[y:y+tile_size, x:x+tile_size] += styled_tile * weight
@@ -2706,6 +2771,7 @@ with tab3:
2706
 
2707
 
2708
 
 
2709
  # TAB 4: Training with AdaIN and Regional Application
2710
  with tab4:
2711
  st.header("Train Custom Style with AdaIN")
@@ -2739,7 +2805,7 @@ with tab4:
2739
 
2740
  with col2:
2741
  st.subheader("Content Images")
2742
- content_imgs = st.file_uploader("Upload content images (5-50 recommended)",
2743
  type=['png', 'jpg', 'jpeg'],
2744
  accept_multiple_files=True,
2745
  key="train_content_adain")
@@ -2763,22 +2829,26 @@ with tab4:
2763
  model_name = st.text_input("Model Name",
2764
  value=f"adain_style_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}")
2765
 
2766
- epochs = st.slider("Training Epochs", 5, 50, 15, 5)
 
2767
  batch_size = st.slider("Batch Size", 1, 8, 4)
2768
  learning_rate = st.number_input("Learning Rate", 0.00001, 0.001, 0.0001, format="%.5f")
2769
 
2770
  with st.expander("Advanced Settings"):
2771
- style_weight = st.number_input("Style Weight", 1.0, 100.0, 10.0, 1.0)
 
2772
  content_weight = st.number_input("Content Weight", 0.1, 10.0, 1.0, 0.1)
2773
- save_interval = st.slider("Save Checkpoint Every N Epochs", 5, 20, 5, 5)
 
 
2774
 
2775
  st.markdown("---")
2776
 
2777
  # Training button
2778
  if st.button("Start AdaIN Training", type="primary", use_container_width=True):
2779
  if style_imgs and content_imgs:
2780
- if len(content_imgs) < 5:
2781
- st.warning("For best results, use at least 5 content images")
2782
 
2783
  with st.spinner("Training AdaIN model..."):
2784
  progress_bar = st.progress(0)
@@ -2803,14 +2873,170 @@ with tab4:
2803
  style_img = Image.open(style_file).convert('RGB')
2804
  style_images.append(style_img)
2805
 
2806
- # Train model
2807
- model = system.train_adain_model(
2808
- style_images, temp_content_dir, model_name,
2809
- epochs=epochs, lr=learning_rate, batch_size=batch_size,
2810
- save_interval=save_interval, style_weight=style_weight,
2811
- content_weight=content_weight,
2812
- progress_callback=progress_callback
2813
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2814
 
2815
  # Cleanup
2816
  shutil.rmtree(temp_content_dir)
@@ -2818,8 +3044,11 @@ with tab4:
2818
  if model:
2819
  st.session_state['trained_adain_model'] = model
2820
  st.session_state['trained_style_images'] = style_images
2821
- st.session_state['model_path'] = f'/tmp/trained_models/{model_name}_final.pth'
2822
- st.success("AdaIN training complete")
 
 
 
2823
 
2824
  progress_bar.empty()
2825
  status_text.empty()
@@ -2885,14 +3114,15 @@ with tab4:
2885
  else:
2886
  test_style = None
2887
 
 
2888
  # Alpha blending control
2889
- alpha = st.slider("Style Strength (Alpha)", 0.0, 1.0, 1.0, 0.1,
2890
- help="0 = original content, 1 = full style transfer")
2891
 
2892
- # Add tiling option
2893
  use_tiling = st.checkbox("Use Tiled Processing",
2894
- value=True,
2895
- help="Process large images in tiles for better quality. Recommended for images larger than 512x512.")
2896
 
2897
  # Initialize variables with default values
2898
  brush_size = 30
@@ -3003,6 +3233,16 @@ with tab4:
3003
  caption="Styled Result",
3004
  use_column_width=True)
3005
 
 
 
 
 
 
 
 
 
 
 
3006
  # Download button
3007
  buf = io.BytesIO()
3008
  st.session_state['adain_styled_result'].save(buf, format='PNG')
@@ -3022,7 +3262,7 @@ with tab4:
3022
  st.download_button(
3023
  label="Download Trained AdaIN Model",
3024
  data=f.read(),
3025
- file_name=f"{st.session_state.get('model_name', 'adain')}_final.pth",
3026
  mime="application/octet-stream",
3027
  use_container_width=True
3028
  )
 
1570
  return Image.fromarray(np.clip(result, 0, 255).astype(np.uint8))
1571
 
1572
  def train_adain_model(self, style_images, content_dir, model_name,
1573
+ epochs=30, batch_size=4, lr=1e-4,
1574
+ save_interval=5, style_weight=10.0, content_weight=1.0,
1575
+ progress_callback=None):
1576
  """Train an AdaIN-based style transfer model"""
1577
 
1578
  model = AdaINStyleTransfer().to(self.device)
1579
  optimizer = torch.optim.Adam(model.decoder.parameters(), lr=lr)
1580
 
1581
+ # Add learning rate scheduler
1582
+ scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.8)
1583
+
1584
  print(f"Training AdaIN model")
1585
  print(f"Training device: {self.device}")
1586
 
 
1589
  print(f"Model on GPU: {next(model.decoder.parameters()).device}")
1590
  print(f"GPU memory before training: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
1591
 
1592
+ # Prepare style images - INCREASED SIZE
1593
  style_transform = transforms.Compose([
1594
+ transforms.Resize(600), # Increased from 512
1595
+ transforms.RandomCrop(512), # Increased from 256
1596
+ transforms.RandomHorizontalFlip(p=0.5), # Add augmentation
1597
  transforms.ToTensor(),
1598
  transforms.Normalize(mean=[0.485, 0.456, 0.406],
1599
  std=[0.229, 0.224, 0.225])
1600
  ])
1601
 
1602
  style_tensors = []
1603
+ # Create multiple augmented versions of each style image
1604
  for style_img in style_images:
1605
+ # Generate 5 augmented versions per style image
1606
+ for _ in range(5):
1607
+ style_tensor = style_transform(style_img).unsqueeze(0).to(self.device)
1608
+ style_tensors.append(style_tensor)
1609
+
1610
+ print(f"Created {len(style_tensors)} augmented style samples from {len(style_images)} images")
1611
 
1612
+ # Prepare content dataset - INCREASED SIZE
1613
  content_transform = transforms.Compose([
1614
+ transforms.Resize(600), # Increased from 512
1615
+ transforms.RandomCrop(512), # Increased from 256
1616
+ transforms.RandomHorizontalFlip(),
1617
+ transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.05),
1618
  transforms.ToTensor(),
1619
  transforms.Normalize(mean=[0.485, 0.456, 0.406],
1620
  std=[0.229, 0.224, 0.225])
 
1628
  print(f" - Content images: {len(dataset)}")
1629
  print(f" - Batch size: {batch_size}")
1630
  print(f" - Epochs: {epochs}")
1631
+ print(f" - Training resolution: 512x512") # Updated
1632
+
1633
+ # Loss network (VGG for perceptual loss) - USE MULTIPLE LAYERS
1634
+ class MultiLayerVGG(nn.Module):
1635
+ def __init__(self):
1636
+ super().__init__()
1637
+ vgg = models.vgg19(weights=models.VGG19_Weights.DEFAULT).features
1638
+ self.slice1 = nn.Sequential(*list(vgg.children())[:2]) # relu1_1
1639
+ self.slice2 = nn.Sequential(*list(vgg.children())[2:7]) # relu2_1
1640
+ self.slice3 = nn.Sequential(*list(vgg.children())[7:12]) # relu3_1
1641
+ self.slice4 = nn.Sequential(*list(vgg.children())[12:21]) # relu4_1
1642
+ for param in self.parameters():
1643
+ param.requires_grad = False
1644
+
1645
+ def forward(self, x):
1646
+ h1 = self.slice1(x)
1647
+ h2 = self.slice2(h1)
1648
+ h3 = self.slice3(h2)
1649
+ h4 = self.slice4(h3)
1650
+ return [h1, h2, h3, h4]
1651
+
1652
+ loss_network = MultiLayerVGG().to(self.device).eval()
1653
  mse_loss = nn.MSELoss()
1654
 
1655
  # Training loop
 
1657
  model.encoder.eval() # Keep encoder frozen
1658
  total_steps = 0
1659
 
1660
+ # Adjust style weight for better quality
1661
+ actual_style_weight = style_weight * 10 # Multiply by 10 for better style transfer
1662
+
1663
  for epoch in range(epochs):
1664
  epoch_loss = 0
1665
 
 
1676
  # Forward pass
1677
  output = model(content_batch, batch_style)
1678
 
1679
+ # Multi-layer content and style loss
1680
  with torch.no_grad():
1681
+ content_feats = loss_network(content_batch)
1682
+ style_feats = loss_network(batch_style)
1683
+ output_feats = loss_network(output)
1684
 
1685
+ # Content loss - only from relu4_1
1686
+ content_loss = mse_loss(output_feats[-1], content_feats[-1])
1687
+
1688
+ # Style loss - from multiple layers
1689
+ style_loss = 0
1690
+ style_weights = [0.2, 0.3, 0.5, 1.0] # Give more weight to higher layers
1691
 
 
1692
  def gram_matrix(feat):
1693
  b, c, h, w = feat.size()
1694
  feat = feat.view(b, c, h * w)
1695
  gram = torch.bmm(feat, feat.transpose(1, 2))
1696
  return gram / (c * h * w)
1697
 
1698
+ for i, (output_feat, style_feat, weight) in enumerate(zip(output_feats, style_feats, style_weights)):
1699
+ output_gram = gram_matrix(output_feat)
1700
+ style_gram = gram_matrix(style_feat)
1701
+ style_loss += weight * mse_loss(output_gram, style_gram)
1702
+
1703
+ style_loss /= len(style_weights)
1704
 
1705
  # Total loss
1706
+ loss = content_weight * content_loss + actual_style_weight * style_loss
1707
 
1708
  # Backward pass
1709
  optimizer.zero_grad()
1710
  loss.backward()
1711
+
1712
+ # Gradient clipping for stability
1713
+ torch.nn.utils.clip_grad_norm_(model.decoder.parameters(), max_norm=5.0)
1714
+
1715
  optimizer.step()
1716
 
1717
  epoch_loss += loss.item()
 
1720
  # Progress callback
1721
  if progress_callback and total_steps % 10 == 0:
1722
  progress = (epoch + (batch_idx + 1) / len(dataloader)) / epochs
1723
+ progress_callback(progress,
1724
+ f"Epoch {epoch+1}/{epochs}, Loss: {loss.item():.4f} "
1725
+ f"(Content: {content_loss.item():.4f}, Style: {style_loss.item():.4f})")
1726
+
1727
+ # Step scheduler
1728
+ scheduler.step()
1729
 
1730
  # Save checkpoint
1731
  if (epoch + 1) % save_interval == 0:
 
1734
  'epoch': epoch + 1,
1735
  'model_state_dict': model.state_dict(),
1736
  'optimizer_state_dict': optimizer.state_dict(),
1737
+ 'scheduler_state_dict': scheduler.state_dict(),
1738
  'loss': epoch_loss / len(dataloader),
1739
  'model_type': 'adain'
1740
  }, checkpoint_path)
 
1752
  self.lightweight_models[model_name] = model
1753
 
1754
  return model
1755
+
1756
+
1757
+ # Update these methods in your StyleTransferSystem class:
1758
 
1759
  def apply_adain_style(self, content_image, style_image, model, alpha=1.0, use_tiling=False):
1760
  """Apply AdaIN-based style transfer with optional tiling"""
1761
+ # Use tiling for large images to maintain quality
1762
+ if use_tiling and (content_image.width > 768 or content_image.height > 768):
1763
  return self.apply_adain_style_tiled(
1764
  content_image, style_image, model, alpha,
1765
+ tile_size=512, # Increased from 256
1766
+ overlap=64, # Increased overlap
1767
  blend_mode='gaussian'
1768
  )
1769
 
 
1770
  if content_image is None or style_image is None or model is None:
1771
  return None
1772
 
 
1776
 
1777
  original_size = content_image.size
1778
 
1779
+ # Use higher resolution - find optimal size while maintaining aspect ratio
1780
+ max_dim = 768 # Increased from 256
1781
+ w, h = content_image.size
1782
+ if w > h:
1783
+ new_w = min(w, max_dim)
1784
+ new_h = int(h * new_w / w)
1785
+ else:
1786
+ new_h = min(h, max_dim)
1787
+ new_w = int(w * new_h / h)
1788
+
1789
+ # Ensure dimensions are divisible by 8 for better compatibility
1790
+ new_w = (new_w // 8) * 8
1791
+ new_h = (new_h // 8) * 8
1792
+
1793
  # Transform for AdaIN (VGG normalization)
1794
  transform = transforms.Compose([
1795
+ transforms.Resize((new_h, new_w)),
1796
  transforms.ToTensor(),
1797
  transforms.Normalize(mean=[0.485, 0.456, 0.406],
1798
  std=[0.229, 0.224, 0.225])
 
1820
  print(f"Error applying AdaIN style: {e}")
1821
  traceback.print_exc()
1822
  return None
1823
+
1824
  def apply_adain_style_tiled(self, content_image, style_image, model, alpha=1.0,
1825
  tile_size=256, overlap=32, blend_mode='linear'):
1826
  """
 
1834
  model = model.to(self.device)
1835
  model.eval()
1836
 
1837
+ # INCREASED TILE SIZE FOR BETTER QUALITY
1838
+ tile_size = 512 # Override input to use 512
1839
+ overlap = 64 # Increase overlap proportionally
1840
+
1841
  # Prepare transforms
1842
  transform = transforms.Compose([
1843
  transforms.Resize((tile_size, tile_size)),
 
1858
  tiles_y = list(range(0, h - tile_size + 1, stride))
1859
 
1860
  # Ensure we cover the entire image
1861
+ if not tiles_x or tiles_x[-1] + tile_size < w:
1862
+ tiles_x.append(max(0, w - tile_size))
1863
+ if not tiles_y or tiles_y[-1] + tile_size < h:
1864
+ tiles_y.append(max(0, h - tile_size))
1865
 
1866
  # If image is smaller than tile size, just process normally
1867
  if w <= tile_size and h <= tile_size:
 
1896
  # Convert to numpy
1897
  styled_tile = styled_tensor.permute(1, 2, 0).numpy() * 255
1898
 
1899
+ # Create weight mask for blending - use gaussian by default for better quality
1900
+ weight = self._create_gaussian_weight(tile_size, tile_size, overlap)
 
 
 
1901
 
1902
  # Add to output with weights
1903
  output_array[y:y+tile_size, x:x+tile_size] += styled_tile * weight
 
2771
 
2772
 
2773
 
2774
+ # TAB 4: Training with AdaIN and Regional Application
2775
  # TAB 4: Training with AdaIN and Regional Application
2776
  with tab4:
2777
  st.header("Train Custom Style with AdaIN")
 
2805
 
2806
  with col2:
2807
  st.subheader("Content Images")
2808
+ content_imgs = st.file_uploader("Upload content images (10-50 recommended)",
2809
  type=['png', 'jpg', 'jpeg'],
2810
  accept_multiple_files=True,
2811
  key="train_content_adain")
 
2829
  model_name = st.text_input("Model Name",
2830
  value=f"adain_style_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}")
2831
 
2832
+ # IMPROVED DEFAULT VALUES
2833
+ epochs = st.slider("Training Epochs", 10, 100, 50, 5) # Increased default
2834
  batch_size = st.slider("Batch Size", 1, 8, 4)
2835
  learning_rate = st.number_input("Learning Rate", 0.00001, 0.001, 0.0001, format="%.5f")
2836
 
2837
  with st.expander("Advanced Settings"):
2838
+ # MUCH HIGHER STYLE WEIGHT BY DEFAULT
2839
+ style_weight = st.number_input("Style Weight", 1.0, 1000.0, 100.0, 10.0)
2840
  content_weight = st.number_input("Content Weight", 0.1, 10.0, 1.0, 0.1)
2841
+ save_interval = st.slider("Save Checkpoint Every N Epochs", 5, 20, 10, 5)
2842
+
2843
+ st.info("💡 **Pro tip**: For better quality, use Style Weight 100-500x higher than Content Weight")
2844
 
2845
  st.markdown("---")
2846
 
2847
  # Training button
2848
  if st.button("Start AdaIN Training", type="primary", use_container_width=True):
2849
  if style_imgs and content_imgs:
2850
+ if len(content_imgs) < 10:
2851
+ st.warning("For best results, use at least 10 content images")
2852
 
2853
  with st.spinner("Training AdaIN model..."):
2854
  progress_bar = st.progress(0)
 
2873
  style_img = Image.open(style_file).convert('RGB')
2874
  style_images.append(style_img)
2875
 
2876
+ # IMPROVED TRAINING FUNCTION
2877
+ # Multi-layer VGG loss for better quality
2878
+ class MultiLayerVGG(nn.Module):
2879
+ def __init__(self):
2880
+ super().__init__()
2881
+ vgg = models.vgg19(weights=models.VGG19_Weights.DEFAULT).features
2882
+ self.slice1 = nn.Sequential(*list(vgg.children())[:2]) # relu1_1
2883
+ self.slice2 = nn.Sequential(*list(vgg.children())[2:7]) # relu2_1
2884
+ self.slice3 = nn.Sequential(*list(vgg.children())[7:12]) # relu3_1
2885
+ self.slice4 = nn.Sequential(*list(vgg.children())[12:21]) # relu4_1
2886
+ for param in self.parameters():
2887
+ param.requires_grad = False
2888
+
2889
+ def forward(self, x):
2890
+ h1 = self.slice1(x)
2891
+ h2 = self.slice2(h1)
2892
+ h3 = self.slice3(h2)
2893
+ h4 = self.slice4(h3)
2894
+ return [h1, h2, h3, h4]
2895
+
2896
+ # Create model
2897
+ model = AdaINStyleTransfer().to(system.device)
2898
+ optimizer = torch.optim.Adam(model.decoder.parameters(), lr=learning_rate)
2899
+ scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.8)
2900
+
2901
+ print(f"Training AdaIN model at 512x512 resolution")
2902
+ print(f"Training device: {system.device}")
2903
+
2904
+ # Prepare style images - LARGER SIZE
2905
+ style_transform = transforms.Compose([
2906
+ transforms.Resize(600), # Increased size
2907
+ transforms.RandomCrop(512), # Larger crops
2908
+ transforms.RandomHorizontalFlip(p=0.5),
2909
+ transforms.ToTensor(),
2910
+ transforms.Normalize(mean=[0.485, 0.456, 0.406],
2911
+ std=[0.229, 0.224, 0.225])
2912
+ ])
2913
+
2914
+ style_tensors = []
2915
+ # Create multiple augmented versions
2916
+ for style_img in style_images:
2917
+ for _ in range(5): # 5 augmented versions per style
2918
+ style_tensor = style_transform(style_img).unsqueeze(0).to(system.device)
2919
+ style_tensors.append(style_tensor)
2920
+
2921
+ # Prepare content dataset - LARGER SIZE
2922
+ content_transform = transforms.Compose([
2923
+ transforms.Resize(600),
2924
+ transforms.RandomCrop(512),
2925
+ transforms.RandomHorizontalFlip(),
2926
+ transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.05),
2927
+ transforms.ToTensor(),
2928
+ transforms.Normalize(mean=[0.485, 0.456, 0.406],
2929
+ std=[0.229, 0.224, 0.225])
2930
+ ])
2931
+
2932
+ dataset = StyleTransferDataset(temp_content_dir, transform=content_transform)
2933
+ dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0)
2934
+
2935
+ # Multi-layer loss network
2936
+ loss_network = MultiLayerVGG().to(system.device).eval()
2937
+ mse_loss = nn.MSELoss()
2938
+
2939
+ # Training loop
2940
+ model.train()
2941
+ model.encoder.eval()
2942
+ total_steps = 0
2943
+
2944
+ # Multiply style weight for better results
2945
+ actual_style_weight = style_weight * 10
2946
+
2947
+ for epoch in range(epochs):
2948
+ epoch_loss = 0
2949
+ epoch_content_loss = 0
2950
+ epoch_style_loss = 0
2951
+
2952
+ for batch_idx, content_batch in enumerate(dataloader):
2953
+ content_batch = content_batch.to(system.device)
2954
+
2955
+ # Randomly select style images
2956
+ batch_style = []
2957
+ for _ in range(content_batch.size(0)):
2958
+ style_idx = np.random.randint(0, len(style_tensors))
2959
+ batch_style.append(style_tensors[style_idx])
2960
+ batch_style = torch.cat(batch_style, dim=0)
2961
+
2962
+ # Forward pass
2963
+ output = model(content_batch, batch_style)
2964
+
2965
+ # Multi-layer loss
2966
+ with torch.no_grad():
2967
+ content_feats = loss_network(content_batch)
2968
+ style_feats = loss_network(batch_style)
2969
+ output_feats = loss_network(output)
2970
+
2971
+ # Content loss from relu4_1
2972
+ content_loss = mse_loss(output_feats[-1], content_feats[-1])
2973
+
2974
+ # Style loss from multiple layers
2975
+ style_loss = 0
2976
+ style_weights = [0.2, 0.3, 0.5, 1.0]
2977
+
2978
+ def gram_matrix(feat):
2979
+ b, c, h, w = feat.size()
2980
+ feat = feat.view(b, c, h * w)
2981
+ gram = torch.bmm(feat, feat.transpose(1, 2))
2982
+ return gram / (c * h * w)
2983
+
2984
+ for i, (output_feat, style_feat, weight) in enumerate(zip(output_feats, style_feats, style_weights)):
2985
+ output_gram = gram_matrix(output_feat)
2986
+ style_gram = gram_matrix(style_feat)
2987
+ style_loss += weight * mse_loss(output_gram, style_gram)
2988
+
2989
+ style_loss /= len(style_weights)
2990
+
2991
+ # Total loss
2992
+ loss = content_weight * content_loss + actual_style_weight * style_loss
2993
+
2994
+ # Backward pass
2995
+ optimizer.zero_grad()
2996
+ loss.backward()
2997
+ torch.nn.utils.clip_grad_norm_(model.decoder.parameters(), max_norm=5.0)
2998
+ optimizer.step()
2999
+
3000
+ epoch_loss += loss.item()
3001
+ epoch_content_loss += content_loss.item()
3002
+ epoch_style_loss += style_loss.item()
3003
+ total_steps += 1
3004
+
3005
+ # Progress callback
3006
+ if progress_callback and total_steps % 10 == 0:
3007
+ progress = (epoch + (batch_idx + 1) / len(dataloader)) / epochs
3008
+ progress_callback(progress,
3009
+ f"Epoch {epoch+1}/{epochs}, Loss: {loss.item():.4f} "
3010
+ f"(C: {content_loss.item():.4f}, S: {style_loss.item():.4f})")
3011
+
3012
+ # Step scheduler
3013
+ scheduler.step()
3014
+
3015
+ # Print epoch stats
3016
+ avg_loss = epoch_loss / len(dataloader)
3017
+ print(f"Epoch {epoch+1}: Loss={avg_loss:.4f}, "
3018
+ f"Content={epoch_content_loss/len(dataloader):.4f}, "
3019
+ f"Style={epoch_style_loss/len(dataloader):.4f}")
3020
+
3021
+ # Save checkpoint
3022
+ if (epoch + 1) % save_interval == 0:
3023
+ checkpoint_path = f'{system.models_dir}/{model_name}_epoch_{epoch+1}.pth'
3024
+ torch.save({
3025
+ 'epoch': epoch + 1,
3026
+ 'model_state_dict': model.state_dict(),
3027
+ 'optimizer_state_dict': optimizer.state_dict(),
3028
+ 'scheduler_state_dict': scheduler.state_dict(),
3029
+ 'loss': avg_loss,
3030
+ 'model_type': 'adain'
3031
+ }, checkpoint_path)
3032
+ print(f"Saved checkpoint: {checkpoint_path}")
3033
+
3034
+ # Save final model
3035
+ final_path = f'{system.models_dir}/{model_name}_final.pth'
3036
+ torch.save({
3037
+ 'model_state_dict': model.state_dict(),
3038
+ 'model_type': 'adain'
3039
+ }, final_path)
3040
 
3041
  # Cleanup
3042
  shutil.rmtree(temp_content_dir)
 
3044
  if model:
3045
  st.session_state['trained_adain_model'] = model
3046
  st.session_state['trained_style_images'] = style_images
3047
+ st.session_state['model_path'] = final_path
3048
+ st.success("AdaIN training complete! 🎉")
3049
+
3050
+ # Add to system's models
3051
+ system.lightweight_models[model_name] = model
3052
 
3053
  progress_bar.empty()
3054
  status_text.empty()
 
3114
  else:
3115
  test_style = None
3116
 
3117
+ # IMPROVED DEFAULTS
3118
  # Alpha blending control
3119
+ alpha = st.slider("Style Strength (Alpha)", 0.0, 2.0, 1.2, 0.1,
3120
+ help="0 = original content, 1 = full style transfer, >1 = stronger style")
3121
 
3122
+ # Add tiling option - DEFAULT TO TRUE
3123
  use_tiling = st.checkbox("Use Tiled Processing",
3124
+ value=True, # Default to True
3125
+ help="Process images in tiles for better quality. Recommended for ALL images.")
3126
 
3127
  # Initialize variables with default values
3128
  brush_size = 30
 
3233
  caption="Styled Result",
3234
  use_column_width=True)
3235
 
3236
+ # Quality tips
3237
+ with st.expander("💡 Tips for Better Quality"):
3238
+ st.markdown("""
3239
+ - **Always use tiling** for best quality
3240
+ - Try **alpha > 1.0** (1.2-1.5) for stronger style
3241
+ - Use **multiple style images** when training
3242
+ - Train for **50+ epochs** for best results
3243
+ - If quality is still poor, retrain with **style weight = 200-500**
3244
+ """)
3245
+
3246
  # Download button
3247
  buf = io.BytesIO()
3248
  st.session_state['adain_styled_result'].save(buf, format='PNG')
 
3262
  st.download_button(
3263
  label="Download Trained AdaIN Model",
3264
  data=f.read(),
3265
+ file_name=f"{model_name}_final.pth",
3266
  mime="application/octet-stream",
3267
  use_container_width=True
3268
  )