Spaces:
Build error
Build error
Daniel Roxas committed on
Update app.py
Browse files
app.py
CHANGED
|
@@ -1570,14 +1570,17 @@ class StyleTransferSystem:
|
|
| 1570 |
return Image.fromarray(np.clip(result, 0, 255).astype(np.uint8))
|
| 1571 |
|
| 1572 |
def train_adain_model(self, style_images, content_dir, model_name,
|
| 1573 |
-
|
| 1574 |
-
|
| 1575 |
-
|
| 1576 |
"""Train an AdaIN-based style transfer model"""
|
| 1577 |
|
| 1578 |
model = AdaINStyleTransfer().to(self.device)
|
| 1579 |
optimizer = torch.optim.Adam(model.decoder.parameters(), lr=lr)
|
| 1580 |
|
|
|
|
|
|
|
|
|
|
| 1581 |
print(f"Training AdaIN model")
|
| 1582 |
print(f"Training device: {self.device}")
|
| 1583 |
|
|
@@ -1586,24 +1589,32 @@ class StyleTransferSystem:
|
|
| 1586 |
print(f"Model on GPU: {next(model.decoder.parameters()).device}")
|
| 1587 |
print(f"GPU memory before training: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
|
| 1588 |
|
| 1589 |
-
# Prepare style images
|
| 1590 |
style_transform = transforms.Compose([
|
| 1591 |
-
transforms.Resize(
|
| 1592 |
-
transforms.RandomCrop(
|
|
|
|
| 1593 |
transforms.ToTensor(),
|
| 1594 |
transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
| 1595 |
std=[0.229, 0.224, 0.225])
|
| 1596 |
])
|
| 1597 |
|
| 1598 |
style_tensors = []
|
|
|
|
| 1599 |
for style_img in style_images:
|
| 1600 |
-
|
| 1601 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1602 |
|
| 1603 |
-
# Prepare content dataset
|
| 1604 |
content_transform = transforms.Compose([
|
| 1605 |
-
transforms.Resize(
|
| 1606 |
-
transforms.RandomCrop(
|
|
|
|
|
|
|
| 1607 |
transforms.ToTensor(),
|
| 1608 |
transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
| 1609 |
std=[0.229, 0.224, 0.225])
|
|
@@ -1617,9 +1628,28 @@ class StyleTransferSystem:
|
|
| 1617 |
print(f" - Content images: {len(dataset)}")
|
| 1618 |
print(f" - Batch size: {batch_size}")
|
| 1619 |
print(f" - Epochs: {epochs}")
|
| 1620 |
-
|
| 1621 |
-
|
| 1622 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1623 |
mse_loss = nn.MSELoss()
|
| 1624 |
|
| 1625 |
# Training loop
|
|
@@ -1627,6 +1657,9 @@ class StyleTransferSystem:
|
|
| 1627 |
model.encoder.eval() # Keep encoder frozen
|
| 1628 |
total_steps = 0
|
| 1629 |
|
|
|
|
|
|
|
|
|
|
| 1630 |
for epoch in range(epochs):
|
| 1631 |
epoch_loss = 0
|
| 1632 |
|
|
@@ -1643,33 +1676,42 @@ class StyleTransferSystem:
|
|
| 1643 |
# Forward pass
|
| 1644 |
output = model(content_batch, batch_style)
|
| 1645 |
|
| 1646 |
-
#
|
| 1647 |
with torch.no_grad():
|
| 1648 |
-
|
| 1649 |
-
|
| 1650 |
-
|
| 1651 |
|
| 1652 |
-
#
|
| 1653 |
-
|
| 1654 |
-
|
|
|
|
|
|
|
|
|
|
| 1655 |
|
| 1656 |
-
# Compute style loss using Gram matrices
|
| 1657 |
def gram_matrix(feat):
    """Return the normalized Gram matrix of every feature map in a batch.

    Args:
        feat: 4-D tensor, unpacked as ``(b, c, h, w)``.

    Returns:
        Tensor of shape ``(b, c, c)``: the product of the spatially
        flattened features with their own transpose, divided by
        ``c * h * w``.
    """
    b, c, h, w = feat.size()
    # reshape instead of view: view raises on non-contiguous inputs
    # (e.g. permuted or sliced tensors), reshape copies only when needed.
    flat = feat.reshape(b, c, h * w)
    gram = torch.bmm(flat, flat.transpose(1, 2))
    return gram / (c * h * w)
|
| 1662 |
|
| 1663 |
-
|
| 1664 |
-
|
| 1665 |
-
|
|
|
|
|
|
|
|
|
|
| 1666 |
|
| 1667 |
# Total loss
|
| 1668 |
-
loss = content_weight * content_loss +
|
| 1669 |
|
| 1670 |
# Backward pass
|
| 1671 |
optimizer.zero_grad()
|
| 1672 |
loss.backward()
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1673 |
optimizer.step()
|
| 1674 |
|
| 1675 |
epoch_loss += loss.item()
|
|
@@ -1678,7 +1720,12 @@ class StyleTransferSystem:
|
|
| 1678 |
# Progress callback
|
| 1679 |
if progress_callback and total_steps % 10 == 0:
|
| 1680 |
progress = (epoch + (batch_idx + 1) / len(dataloader)) / epochs
|
| 1681 |
-
progress_callback(progress,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1682 |
|
| 1683 |
# Save checkpoint
|
| 1684 |
if (epoch + 1) % save_interval == 0:
|
|
@@ -1687,6 +1734,7 @@ class StyleTransferSystem:
|
|
| 1687 |
'epoch': epoch + 1,
|
| 1688 |
'model_state_dict': model.state_dict(),
|
| 1689 |
'optimizer_state_dict': optimizer.state_dict(),
|
|
|
|
| 1690 |
'loss': epoch_loss / len(dataloader),
|
| 1691 |
'model_type': 'adain'
|
| 1692 |
}, checkpoint_path)
|
|
@@ -1704,19 +1752,21 @@ class StyleTransferSystem:
|
|
| 1704 |
self.lightweight_models[model_name] = model
|
| 1705 |
|
| 1706 |
return model
|
|
|
|
|
|
|
|
|
|
| 1707 |
|
| 1708 |
def apply_adain_style(self, content_image, style_image, model, alpha=1.0, use_tiling=False):
|
| 1709 |
"""Apply AdaIN-based style transfer with optional tiling"""
|
| 1710 |
-
|
| 1711 |
-
|
| 1712 |
return self.apply_adain_style_tiled(
|
| 1713 |
content_image, style_image, model, alpha,
|
| 1714 |
-
tile_size=
|
| 1715 |
-
overlap=
|
| 1716 |
blend_mode='gaussian'
|
| 1717 |
)
|
| 1718 |
|
| 1719 |
-
# Original implementation for small images
|
| 1720 |
if content_image is None or style_image is None or model is None:
|
| 1721 |
return None
|
| 1722 |
|
|
@@ -1726,9 +1776,23 @@ class StyleTransferSystem:
|
|
| 1726 |
|
| 1727 |
original_size = content_image.size
|
| 1728 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1729 |
# Transform for AdaIN (VGG normalization)
|
| 1730 |
transform = transforms.Compose([
|
| 1731 |
-
transforms.Resize((
|
| 1732 |
transforms.ToTensor(),
|
| 1733 |
transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
| 1734 |
std=[0.229, 0.224, 0.225])
|
|
@@ -1756,7 +1820,7 @@ class StyleTransferSystem:
|
|
| 1756 |
print(f"Error applying AdaIN style: {e}")
|
| 1757 |
traceback.print_exc()
|
| 1758 |
return None
|
| 1759 |
-
|
| 1760 |
def apply_adain_style_tiled(self, content_image, style_image, model, alpha=1.0,
|
| 1761 |
tile_size=256, overlap=32, blend_mode='linear'):
|
| 1762 |
"""
|
|
@@ -1770,6 +1834,10 @@ class StyleTransferSystem:
|
|
| 1770 |
model = model.to(self.device)
|
| 1771 |
model.eval()
|
| 1772 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1773 |
# Prepare transforms
|
| 1774 |
transform = transforms.Compose([
|
| 1775 |
transforms.Resize((tile_size, tile_size)),
|
|
@@ -1790,10 +1858,10 @@ class StyleTransferSystem:
|
|
| 1790 |
tiles_y = list(range(0, h - tile_size + 1, stride))
|
| 1791 |
|
| 1792 |
# Ensure we cover the entire image
|
| 1793 |
-
if tiles_x[-1] + tile_size < w:
|
| 1794 |
-
tiles_x.append(w - tile_size)
|
| 1795 |
-
if tiles_y[-1] + tile_size < h:
|
| 1796 |
-
tiles_y.append(h - tile_size)
|
| 1797 |
|
| 1798 |
# If image is smaller than tile size, just process normally
|
| 1799 |
if w <= tile_size and h <= tile_size:
|
|
@@ -1828,11 +1896,8 @@ class StyleTransferSystem:
|
|
| 1828 |
# Convert to numpy
|
| 1829 |
styled_tile = styled_tensor.permute(1, 2, 0).numpy() * 255
|
| 1830 |
|
| 1831 |
-
# Create weight mask for blending
|
| 1832 |
-
|
| 1833 |
-
weight = self._create_gaussian_weight(tile_size, tile_size, overlap)
|
| 1834 |
-
else:
|
| 1835 |
-
weight = self._create_linear_weight(tile_size, tile_size, overlap)
|
| 1836 |
|
| 1837 |
# Add to output with weights
|
| 1838 |
output_array[y:y+tile_size, x:x+tile_size] += styled_tile * weight
|
|
@@ -2706,6 +2771,7 @@ with tab3:
|
|
| 2706 |
|
| 2707 |
|
| 2708 |
|
|
|
|
| 2709 |
# TAB 4: Training with AdaIN and Regional Application
|
| 2710 |
with tab4:
|
| 2711 |
st.header("Train Custom Style with AdaIN")
|
|
@@ -2739,7 +2805,7 @@ with tab4:
|
|
| 2739 |
|
| 2740 |
with col2:
|
| 2741 |
st.subheader("Content Images")
|
| 2742 |
-
content_imgs = st.file_uploader("Upload content images (
|
| 2743 |
type=['png', 'jpg', 'jpeg'],
|
| 2744 |
accept_multiple_files=True,
|
| 2745 |
key="train_content_adain")
|
|
@@ -2763,22 +2829,26 @@ with tab4:
|
|
| 2763 |
model_name = st.text_input("Model Name",
|
| 2764 |
value=f"adain_style_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}")
|
| 2765 |
|
| 2766 |
-
|
|
|
|
| 2767 |
batch_size = st.slider("Batch Size", 1, 8, 4)
|
| 2768 |
learning_rate = st.number_input("Learning Rate", 0.00001, 0.001, 0.0001, format="%.5f")
|
| 2769 |
|
| 2770 |
with st.expander("Advanced Settings"):
|
| 2771 |
-
|
|
|
|
| 2772 |
content_weight = st.number_input("Content Weight", 0.1, 10.0, 1.0, 0.1)
|
| 2773 |
-
save_interval = st.slider("Save Checkpoint Every N Epochs", 5, 20,
|
|
|
|
|
|
|
| 2774 |
|
| 2775 |
st.markdown("---")
|
| 2776 |
|
| 2777 |
# Training button
|
| 2778 |
if st.button("Start AdaIN Training", type="primary", use_container_width=True):
|
| 2779 |
if style_imgs and content_imgs:
|
| 2780 |
-
if len(content_imgs) <
|
| 2781 |
-
st.warning("For best results, use at least
|
| 2782 |
|
| 2783 |
with st.spinner("Training AdaIN model..."):
|
| 2784 |
progress_bar = st.progress(0)
|
|
@@ -2803,14 +2873,170 @@ with tab4:
|
|
| 2803 |
style_img = Image.open(style_file).convert('RGB')
|
| 2804 |
style_images.append(style_img)
|
| 2805 |
|
| 2806 |
-
#
|
| 2807 |
-
|
| 2808 |
-
|
| 2809 |
-
|
| 2810 |
-
|
| 2811 |
-
|
| 2812 |
-
|
| 2813 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2814 |
|
| 2815 |
# Cleanup
|
| 2816 |
shutil.rmtree(temp_content_dir)
|
|
@@ -2818,8 +3044,11 @@ with tab4:
|
|
| 2818 |
if model:
|
| 2819 |
st.session_state['trained_adain_model'] = model
|
| 2820 |
st.session_state['trained_style_images'] = style_images
|
| 2821 |
-
st.session_state['model_path'] =
|
| 2822 |
-
st.success("AdaIN training complete")
|
|
|
|
|
|
|
|
|
|
| 2823 |
|
| 2824 |
progress_bar.empty()
|
| 2825 |
status_text.empty()
|
|
@@ -2885,14 +3114,15 @@ with tab4:
|
|
| 2885 |
else:
|
| 2886 |
test_style = None
|
| 2887 |
|
|
|
|
| 2888 |
# Alpha blending control
|
| 2889 |
-
alpha = st.slider("Style Strength (Alpha)", 0.0,
|
| 2890 |
-
help="0 = original content, 1 = full style transfer")
|
| 2891 |
|
| 2892 |
-
# Add tiling option
|
| 2893 |
use_tiling = st.checkbox("Use Tiled Processing",
|
| 2894 |
-
value=True,
|
| 2895 |
-
help="Process
|
| 2896 |
|
| 2897 |
# Initialize variables with default values
|
| 2898 |
brush_size = 30
|
|
@@ -3003,6 +3233,16 @@ with tab4:
|
|
| 3003 |
caption="Styled Result",
|
| 3004 |
use_column_width=True)
|
| 3005 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3006 |
# Download button
|
| 3007 |
buf = io.BytesIO()
|
| 3008 |
st.session_state['adain_styled_result'].save(buf, format='PNG')
|
|
@@ -3022,7 +3262,7 @@ with tab4:
|
|
| 3022 |
st.download_button(
|
| 3023 |
label="Download Trained AdaIN Model",
|
| 3024 |
data=f.read(),
|
| 3025 |
-
file_name=f"{
|
| 3026 |
mime="application/octet-stream",
|
| 3027 |
use_container_width=True
|
| 3028 |
)
|
|
|
|
| 1570 |
return Image.fromarray(np.clip(result, 0, 255).astype(np.uint8))
|
| 1571 |
|
| 1572 |
def train_adain_model(self, style_images, content_dir, model_name,
|
| 1573 |
+
epochs=30, batch_size=4, lr=1e-4,
|
| 1574 |
+
save_interval=5, style_weight=10.0, content_weight=1.0,
|
| 1575 |
+
progress_callback=None):
|
| 1576 |
"""Train an AdaIN-based style transfer model"""
|
| 1577 |
|
| 1578 |
model = AdaINStyleTransfer().to(self.device)
|
| 1579 |
optimizer = torch.optim.Adam(model.decoder.parameters(), lr=lr)
|
| 1580 |
|
| 1581 |
+
# Add learning rate scheduler
|
| 1582 |
+
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.8)
|
| 1583 |
+
|
| 1584 |
print(f"Training AdaIN model")
|
| 1585 |
print(f"Training device: {self.device}")
|
| 1586 |
|
|
|
|
| 1589 |
print(f"Model on GPU: {next(model.decoder.parameters()).device}")
|
| 1590 |
print(f"GPU memory before training: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
|
| 1591 |
|
| 1592 |
+
# Prepare style images - INCREASED SIZE
|
| 1593 |
style_transform = transforms.Compose([
|
| 1594 |
+
transforms.Resize(600), # Increased from 512
|
| 1595 |
+
transforms.RandomCrop(512), # Increased from 256
|
| 1596 |
+
transforms.RandomHorizontalFlip(p=0.5), # Add augmentation
|
| 1597 |
transforms.ToTensor(),
|
| 1598 |
transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
| 1599 |
std=[0.229, 0.224, 0.225])
|
| 1600 |
])
|
| 1601 |
|
| 1602 |
style_tensors = []
|
| 1603 |
+
# Create multiple augmented versions of each style image
|
| 1604 |
for style_img in style_images:
|
| 1605 |
+
# Generate 5 augmented versions per style image
|
| 1606 |
+
for _ in range(5):
|
| 1607 |
+
style_tensor = style_transform(style_img).unsqueeze(0).to(self.device)
|
| 1608 |
+
style_tensors.append(style_tensor)
|
| 1609 |
+
|
| 1610 |
+
print(f"Created {len(style_tensors)} augmented style samples from {len(style_images)} images")
|
| 1611 |
|
| 1612 |
+
# Prepare content dataset - INCREASED SIZE
|
| 1613 |
content_transform = transforms.Compose([
|
| 1614 |
+
transforms.Resize(600), # Increased from 512
|
| 1615 |
+
transforms.RandomCrop(512), # Increased from 256
|
| 1616 |
+
transforms.RandomHorizontalFlip(),
|
| 1617 |
+
transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.05),
|
| 1618 |
transforms.ToTensor(),
|
| 1619 |
transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
| 1620 |
std=[0.229, 0.224, 0.225])
|
|
|
|
| 1628 |
print(f" - Content images: {len(dataset)}")
|
| 1629 |
print(f" - Batch size: {batch_size}")
|
| 1630 |
print(f" - Epochs: {epochs}")
|
| 1631 |
+
print(f" - Training resolution: 512x512") # Updated
|
| 1632 |
+
|
| 1633 |
+
# Loss network (VGG for perceptual loss) - USE MULTIPLE LAYERS
|
| 1634 |
+
class MultiLayerVGG(nn.Module):
    """Frozen VGG-19 feature extractor exposing four intermediate outputs.

    The ``features`` stack of torchvision's VGG-19 (loaded with
    ``VGG19_Weights.DEFAULT``) is cut into four consecutive stages whose
    ends are labelled relu1_1, relu2_1, relu3_1 and relu4_1. Every
    parameter has ``requires_grad`` switched off.
    """

    def __init__(self):
        super().__init__()
        # Materialize the layer list once and carve the stages out of it.
        layers = list(
            models.vgg19(weights=models.VGG19_Weights.DEFAULT).features.children()
        )
        self.slice1 = nn.Sequential(*layers[:2])     # relu1_1
        self.slice2 = nn.Sequential(*layers[2:7])    # relu2_1
        self.slice3 = nn.Sequential(*layers[7:12])   # relu3_1
        self.slice4 = nn.Sequential(*layers[12:21])  # relu4_1
        for frozen in self.parameters():
            frozen.requires_grad = False

    def forward(self, x):
        """Feed ``x`` through the stages in order and return all four outputs.

        Returns:
            List ``[h1, h2, h3, h4]`` where each entry is the output of
            one stage and the input of the next.
        """
        activations = []
        for stage in (self.slice1, self.slice2, self.slice3, self.slice4):
            x = stage(x)
            activations.append(x)
        return activations
|
| 1651 |
+
|
| 1652 |
+
loss_network = MultiLayerVGG().to(self.device).eval()
|
| 1653 |
mse_loss = nn.MSELoss()
|
| 1654 |
|
| 1655 |
# Training loop
|
|
|
|
| 1657 |
model.encoder.eval() # Keep encoder frozen
|
| 1658 |
total_steps = 0
|
| 1659 |
|
| 1660 |
+
# Adjust style weight for better quality
|
| 1661 |
+
actual_style_weight = style_weight * 10 # Multiply by 10 for better style transfer
|
| 1662 |
+
|
| 1663 |
for epoch in range(epochs):
|
| 1664 |
epoch_loss = 0
|
| 1665 |
|
|
|
|
| 1676 |
# Forward pass
|
| 1677 |
output = model(content_batch, batch_style)
|
| 1678 |
|
| 1679 |
+
# Multi-layer content and style loss
|
| 1680 |
with torch.no_grad():
|
| 1681 |
+
content_feats = loss_network(content_batch)
|
| 1682 |
+
style_feats = loss_network(batch_style)
|
| 1683 |
+
output_feats = loss_network(output)
|
| 1684 |
|
| 1685 |
+
# Content loss - only from relu4_1
|
| 1686 |
+
content_loss = mse_loss(output_feats[-1], content_feats[-1])
|
| 1687 |
+
|
| 1688 |
+
# Style loss - from multiple layers
|
| 1689 |
+
style_loss = 0
|
| 1690 |
+
style_weights = [0.2, 0.3, 0.5, 1.0] # Give more weight to higher layers
|
| 1691 |
|
|
|
|
| 1692 |
def gram_matrix(feat):
    """Return the normalized Gram matrix of every feature map in a batch.

    Args:
        feat: 4-D tensor, unpacked as ``(b, c, h, w)``.

    Returns:
        Tensor of shape ``(b, c, c)``: the product of the spatially
        flattened features with their own transpose, divided by
        ``c * h * w``.
    """
    b, c, h, w = feat.size()
    # reshape instead of view: view raises on non-contiguous inputs
    # (e.g. permuted or sliced tensors), reshape copies only when needed.
    flat = feat.reshape(b, c, h * w)
    gram = torch.bmm(flat, flat.transpose(1, 2))
    return gram / (c * h * w)
|
| 1697 |
|
| 1698 |
+
for i, (output_feat, style_feat, weight) in enumerate(zip(output_feats, style_feats, style_weights)):
|
| 1699 |
+
output_gram = gram_matrix(output_feat)
|
| 1700 |
+
style_gram = gram_matrix(style_feat)
|
| 1701 |
+
style_loss += weight * mse_loss(output_gram, style_gram)
|
| 1702 |
+
|
| 1703 |
+
style_loss /= len(style_weights)
|
| 1704 |
|
| 1705 |
# Total loss
|
| 1706 |
+
loss = content_weight * content_loss + actual_style_weight * style_loss
|
| 1707 |
|
| 1708 |
# Backward pass
|
| 1709 |
optimizer.zero_grad()
|
| 1710 |
loss.backward()
|
| 1711 |
+
|
| 1712 |
+
# Gradient clipping for stability
|
| 1713 |
+
torch.nn.utils.clip_grad_norm_(model.decoder.parameters(), max_norm=5.0)
|
| 1714 |
+
|
| 1715 |
optimizer.step()
|
| 1716 |
|
| 1717 |
epoch_loss += loss.item()
|
|
|
|
| 1720 |
# Progress callback
|
| 1721 |
if progress_callback and total_steps % 10 == 0:
|
| 1722 |
progress = (epoch + (batch_idx + 1) / len(dataloader)) / epochs
|
| 1723 |
+
progress_callback(progress,
|
| 1724 |
+
f"Epoch {epoch+1}/{epochs}, Loss: {loss.item():.4f} "
|
| 1725 |
+
f"(Content: {content_loss.item():.4f}, Style: {style_loss.item():.4f})")
|
| 1726 |
+
|
| 1727 |
+
# Step scheduler
|
| 1728 |
+
scheduler.step()
|
| 1729 |
|
| 1730 |
# Save checkpoint
|
| 1731 |
if (epoch + 1) % save_interval == 0:
|
|
|
|
| 1734 |
'epoch': epoch + 1,
|
| 1735 |
'model_state_dict': model.state_dict(),
|
| 1736 |
'optimizer_state_dict': optimizer.state_dict(),
|
| 1737 |
+
'scheduler_state_dict': scheduler.state_dict(),
|
| 1738 |
'loss': epoch_loss / len(dataloader),
|
| 1739 |
'model_type': 'adain'
|
| 1740 |
}, checkpoint_path)
|
|
|
|
| 1752 |
self.lightweight_models[model_name] = model
|
| 1753 |
|
| 1754 |
return model
|
| 1755 |
+
|
| 1756 |
+
|
| 1757 |
+
# Update these methods in your StyleTransferSystem class:
|
| 1758 |
|
| 1759 |
def apply_adain_style(self, content_image, style_image, model, alpha=1.0, use_tiling=False):
|
| 1760 |
"""Apply AdaIN-based style transfer with optional tiling"""
|
| 1761 |
+
# Use tiling for large images to maintain quality
|
| 1762 |
+
if use_tiling and (content_image.width > 768 or content_image.height > 768):
|
| 1763 |
return self.apply_adain_style_tiled(
|
| 1764 |
content_image, style_image, model, alpha,
|
| 1765 |
+
tile_size=512, # Increased from 256
|
| 1766 |
+
overlap=64, # Increased overlap
|
| 1767 |
blend_mode='gaussian'
|
| 1768 |
)
|
| 1769 |
|
|
|
|
| 1770 |
if content_image is None or style_image is None or model is None:
|
| 1771 |
return None
|
| 1772 |
|
|
|
|
| 1776 |
|
| 1777 |
original_size = content_image.size
|
| 1778 |
|
| 1779 |
+
# Use higher resolution - find optimal size while maintaining aspect ratio
|
| 1780 |
+
max_dim = 768 # Increased from 256
|
| 1781 |
+
w, h = content_image.size
|
| 1782 |
+
if w > h:
|
| 1783 |
+
new_w = min(w, max_dim)
|
| 1784 |
+
new_h = int(h * new_w / w)
|
| 1785 |
+
else:
|
| 1786 |
+
new_h = min(h, max_dim)
|
| 1787 |
+
new_w = int(w * new_h / h)
|
| 1788 |
+
|
| 1789 |
+
# Ensure dimensions are divisible by 8 for better compatibility
|
| 1790 |
+
new_w = (new_w // 8) * 8
|
| 1791 |
+
new_h = (new_h // 8) * 8
|
| 1792 |
+
|
| 1793 |
# Transform for AdaIN (VGG normalization)
|
| 1794 |
transform = transforms.Compose([
|
| 1795 |
+
transforms.Resize((new_h, new_w)),
|
| 1796 |
transforms.ToTensor(),
|
| 1797 |
transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
| 1798 |
std=[0.229, 0.224, 0.225])
|
|
|
|
| 1820 |
print(f"Error applying AdaIN style: {e}")
|
| 1821 |
traceback.print_exc()
|
| 1822 |
return None
|
| 1823 |
+
|
| 1824 |
def apply_adain_style_tiled(self, content_image, style_image, model, alpha=1.0,
|
| 1825 |
tile_size=256, overlap=32, blend_mode='linear'):
|
| 1826 |
"""
|
|
|
|
| 1834 |
model = model.to(self.device)
|
| 1835 |
model.eval()
|
| 1836 |
|
| 1837 |
+
# INCREASED TILE SIZE FOR BETTER QUALITY
|
| 1838 |
+
tile_size = 512 # Override input to use 512
|
| 1839 |
+
overlap = 64 # Increase overlap proportionally
|
| 1840 |
+
|
| 1841 |
# Prepare transforms
|
| 1842 |
transform = transforms.Compose([
|
| 1843 |
transforms.Resize((tile_size, tile_size)),
|
|
|
|
| 1858 |
tiles_y = list(range(0, h - tile_size + 1, stride))
|
| 1859 |
|
| 1860 |
# Ensure we cover the entire image
|
| 1861 |
+
if not tiles_x or tiles_x[-1] + tile_size < w:
|
| 1862 |
+
tiles_x.append(max(0, w - tile_size))
|
| 1863 |
+
if not tiles_y or tiles_y[-1] + tile_size < h:
|
| 1864 |
+
tiles_y.append(max(0, h - tile_size))
|
| 1865 |
|
| 1866 |
# If image is smaller than tile size, just process normally
|
| 1867 |
if w <= tile_size and h <= tile_size:
|
|
|
|
| 1896 |
# Convert to numpy
|
| 1897 |
styled_tile = styled_tensor.permute(1, 2, 0).numpy() * 255
|
| 1898 |
|
| 1899 |
+
# Create weight mask for blending - use gaussian by default for better quality
|
| 1900 |
+
weight = self._create_gaussian_weight(tile_size, tile_size, overlap)
|
|
|
|
|
|
|
|
|
|
| 1901 |
|
| 1902 |
# Add to output with weights
|
| 1903 |
output_array[y:y+tile_size, x:x+tile_size] += styled_tile * weight
|
|
|
|
| 2771 |
|
| 2772 |
|
| 2773 |
|
| 2774 |
+
# TAB 4: Training with AdaIN and Regional Application
|
| 2775 |
# TAB 4: Training with AdaIN and Regional Application
|
| 2776 |
with tab4:
|
| 2777 |
st.header("Train Custom Style with AdaIN")
|
|
|
|
| 2805 |
|
| 2806 |
with col2:
|
| 2807 |
st.subheader("Content Images")
|
| 2808 |
+
content_imgs = st.file_uploader("Upload content images (10-50 recommended)",
|
| 2809 |
type=['png', 'jpg', 'jpeg'],
|
| 2810 |
accept_multiple_files=True,
|
| 2811 |
key="train_content_adain")
|
|
|
|
| 2829 |
model_name = st.text_input("Model Name",
|
| 2830 |
value=f"adain_style_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}")
|
| 2831 |
|
| 2832 |
+
# IMPROVED DEFAULT VALUES
|
| 2833 |
+
epochs = st.slider("Training Epochs", 10, 100, 50, 5) # Increased default
|
| 2834 |
batch_size = st.slider("Batch Size", 1, 8, 4)
|
| 2835 |
learning_rate = st.number_input("Learning Rate", 0.00001, 0.001, 0.0001, format="%.5f")
|
| 2836 |
|
| 2837 |
with st.expander("Advanced Settings"):
|
| 2838 |
+
# MUCH HIGHER STYLE WEIGHT BY DEFAULT
|
| 2839 |
+
style_weight = st.number_input("Style Weight", 1.0, 1000.0, 100.0, 10.0)
|
| 2840 |
content_weight = st.number_input("Content Weight", 0.1, 10.0, 1.0, 0.1)
|
| 2841 |
+
save_interval = st.slider("Save Checkpoint Every N Epochs", 5, 20, 10, 5)
|
| 2842 |
+
|
| 2843 |
+
st.info("💡 **Pro tip**: For better quality, use Style Weight 100-500x higher than Content Weight")
|
| 2844 |
|
| 2845 |
st.markdown("---")
|
| 2846 |
|
| 2847 |
# Training button
|
| 2848 |
if st.button("Start AdaIN Training", type="primary", use_container_width=True):
|
| 2849 |
if style_imgs and content_imgs:
|
| 2850 |
+
if len(content_imgs) < 10:
|
| 2851 |
+
st.warning("For best results, use at least 10 content images")
|
| 2852 |
|
| 2853 |
with st.spinner("Training AdaIN model..."):
|
| 2854 |
progress_bar = st.progress(0)
|
|
|
|
| 2873 |
style_img = Image.open(style_file).convert('RGB')
|
| 2874 |
style_images.append(style_img)
|
| 2875 |
|
| 2876 |
+
# IMPROVED TRAINING FUNCTION
|
| 2877 |
+
# Multi-layer VGG loss for better quality
|
| 2878 |
+
class MultiLayerVGG(nn.Module):
    """Frozen VGG-19 feature extractor exposing four intermediate outputs.

    The ``features`` stack of torchvision's VGG-19 (loaded with
    ``VGG19_Weights.DEFAULT``) is cut into four consecutive stages whose
    ends are labelled relu1_1, relu2_1, relu3_1 and relu4_1. Every
    parameter has ``requires_grad`` switched off.
    """

    def __init__(self):
        super().__init__()
        # Materialize the layer list once and carve the stages out of it.
        layers = list(
            models.vgg19(weights=models.VGG19_Weights.DEFAULT).features.children()
        )
        self.slice1 = nn.Sequential(*layers[:2])     # relu1_1
        self.slice2 = nn.Sequential(*layers[2:7])    # relu2_1
        self.slice3 = nn.Sequential(*layers[7:12])   # relu3_1
        self.slice4 = nn.Sequential(*layers[12:21])  # relu4_1
        for frozen in self.parameters():
            frozen.requires_grad = False

    def forward(self, x):
        """Feed ``x`` through the stages in order and return all four outputs.

        Returns:
            List ``[h1, h2, h3, h4]`` where each entry is the output of
            one stage and the input of the next.
        """
        activations = []
        for stage in (self.slice1, self.slice2, self.slice3, self.slice4):
            x = stage(x)
            activations.append(x)
        return activations
|
| 2895 |
+
|
| 2896 |
+
# Build the network. Only the decoder's parameters are given to Adam.
model = AdaINStyleTransfer().to(system.device)
optimizer = torch.optim.Adam(model.decoder.parameters(), lr=learning_rate)
# StepLR: multiply the learning rate by 0.8 after every 10 scheduler steps.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.8)

print(f"Training AdaIN model at 512x512 resolution")
print(f"Training device: {system.device}")

# Style pipeline: resize to 600, take a random 512 crop, flip half the time,
# then convert to a tensor and normalize with the mean/std listed below.
style_transform = transforms.Compose(
    [
        transforms.Resize(600),
        transforms.RandomCrop(512),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Five independently augmented samples per style image, each stored as a
# batch of one on the training device.
style_tensors = []
for style_img in style_images:
    for _ in range(5):
        style_tensor = style_transform(style_img).unsqueeze(0)
        style_tensor = style_tensor.to(system.device)
        style_tensors.append(style_tensor)

# Content pipeline: same geometry as the style pipeline plus a mild
# colour jitter.
content_transform = transforms.Compose(
    [
        transforms.Resize(600),
        transforms.RandomCrop(512),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.05),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

dataset = StyleTransferDataset(temp_content_dir, transform=content_transform)
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=0,
)

# Multi-layer loss network (kept in eval mode) and the criterion shared by
# the content and style terms.
loss_network = MultiLayerVGG().to(system.device).eval()
mse_loss = nn.MSELoss()

# Put the model in training mode, then switch the encoder back to eval.
model.train()
model.encoder.eval()
total_steps = 0

# The style weight chosen by the user is scaled by a fixed factor of 10.
actual_style_weight = style_weight * 10
|
| 2946 |
+
|
| 2947 |
+
# Per-layer multipliers for the style term, one per feature map returned by
# loss_network. Defined once instead of on every batch.
style_weights = [0.2, 0.3, 0.5, 1.0]


def gram_matrix(feat):
    """Return normalized Gram matrices ``(b, c, c)`` for a ``(b, c, h, w)`` map."""
    b, c, h, w = feat.size()
    # reshape (not view) so non-contiguous feature maps are accepted too.
    flat = feat.reshape(b, c, h * w)
    return torch.bmm(flat, flat.transpose(1, 2)) / (c * h * w)


for epoch in range(epochs):
    epoch_loss = 0
    epoch_content_loss = 0
    epoch_style_loss = 0

    for batch_idx, content_batch in enumerate(dataloader):
        content_batch = content_batch.to(system.device)

        # Pair every content image with one randomly drawn style sample.
        batch_style = torch.cat(
            [
                style_tensors[np.random.randint(0, len(style_tensors))]
                for _ in range(content_batch.size(0))
            ],
            dim=0,
        )

        # Forward pass
        output = model(content_batch, batch_style)

        # The content and style targets are constants, so their features
        # can be computed without building a graph.
        with torch.no_grad():
            content_feats = loss_network(content_batch)
            style_feats = loss_network(batch_style)
        # The stylized output must pass through the loss network with
        # autograd enabled. Under no_grad the loss would have no path back
        # to the decoder and loss.backward() could not train it.
        output_feats = loss_network(output)

        # Content loss from the last returned feature map (relu4_1).
        content_loss = mse_loss(output_feats[-1], content_feats[-1])

        # Style loss: weighted Gram-matrix MSE over all returned layers,
        # divided by the number of layers.
        style_loss = sum(
            weight * mse_loss(gram_matrix(output_feat), gram_matrix(style_feat))
            for output_feat, style_feat, weight in zip(
                output_feats, style_feats, style_weights
            )
        ) / len(style_weights)

        # Total loss
        loss = content_weight * content_loss + actual_style_weight * style_loss

        # Backward pass, with the decoder's gradient norm clipped to 5.0.
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.decoder.parameters(), max_norm=5.0)
        optimizer.step()

        epoch_loss += loss.item()
        epoch_content_loss += content_loss.item()
        epoch_style_loss += style_loss.item()
        total_steps += 1

        # Report progress every 10th optimizer step.
        if progress_callback and total_steps % 10 == 0:
            progress = (epoch + (batch_idx + 1) / len(dataloader)) / epochs
            progress_callback(progress,
                              f"Epoch {epoch+1}/{epochs}, Loss: {loss.item():.4f} "
                              f"(C: {content_loss.item():.4f}, S: {style_loss.item():.4f})")

    # Step the scheduler once per epoch.
    scheduler.step()

    # Print epoch stats
    avg_loss = epoch_loss / len(dataloader)
    print(f"Epoch {epoch+1}: Loss={avg_loss:.4f}, "
          f"Content={epoch_content_loss/len(dataloader):.4f}, "
          f"Style={epoch_style_loss/len(dataloader):.4f}")

    # Save a checkpoint every save_interval epochs.
    if (epoch + 1) % save_interval == 0:
        checkpoint_path = f'{system.models_dir}/{model_name}_epoch_{epoch+1}.pth'
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'loss': avg_loss,
            'model_type': 'adain'
        }, checkpoint_path)
        print(f"Saved checkpoint: {checkpoint_path}")

# Save final model (weights and model type only).
final_path = f'{system.models_dir}/{model_name}_final.pth'
torch.save({
    'model_state_dict': model.state_dict(),
    'model_type': 'adain'
}, final_path)
|
| 3040 |
|
| 3041 |
# Cleanup
|
| 3042 |
shutil.rmtree(temp_content_dir)
|
|
|
|
| 3044 |
if model:
|
| 3045 |
st.session_state['trained_adain_model'] = model
|
| 3046 |
st.session_state['trained_style_images'] = style_images
|
| 3047 |
+
st.session_state['model_path'] = final_path
|
| 3048 |
+
st.success("AdaIN training complete! 🎉")
|
| 3049 |
+
|
| 3050 |
+
# Add to system's models
|
| 3051 |
+
system.lightweight_models[model_name] = model
|
| 3052 |
|
| 3053 |
progress_bar.empty()
|
| 3054 |
status_text.empty()
|
|
|
|
| 3114 |
else:
|
| 3115 |
test_style = None
|
| 3116 |
|
| 3117 |
+
# IMPROVED DEFAULTS
|
| 3118 |
# Alpha blending control
|
| 3119 |
+
alpha = st.slider("Style Strength (Alpha)", 0.0, 2.0, 1.2, 0.1,
|
| 3120 |
+
help="0 = original content, 1 = full style transfer, >1 = stronger style")
|
| 3121 |
|
| 3122 |
+
# Add tiling option - DEFAULT TO TRUE
|
| 3123 |
use_tiling = st.checkbox("Use Tiled Processing",
|
| 3124 |
+
value=True, # Default to True
|
| 3125 |
+
help="Process images in tiles for better quality. Recommended for ALL images.")
|
| 3126 |
|
| 3127 |
# Initialize variables with default values
|
| 3128 |
brush_size = 30
|
|
|
|
| 3233 |
caption="Styled Result",
|
| 3234 |
use_column_width=True)
|
| 3235 |
|
| 3236 |
+
# Quality tips
|
| 3237 |
+
with st.expander("💡 Tips for Better Quality"):
|
| 3238 |
+
st.markdown("""
|
| 3239 |
+
- **Always use tiling** for best quality
|
| 3240 |
+
- Try **alpha > 1.0** (1.2-1.5) for stronger style
|
| 3241 |
+
- Use **multiple style images** when training
|
| 3242 |
+
- Train for **50+ epochs** for best results
|
| 3243 |
+
- If quality is still poor, retrain with **style weight = 200-500**
|
| 3244 |
+
""")
|
| 3245 |
+
|
| 3246 |
# Download button
|
| 3247 |
buf = io.BytesIO()
|
| 3248 |
st.session_state['adain_styled_result'].save(buf, format='PNG')
|
|
|
|
| 3262 |
st.download_button(
|
| 3263 |
label="Download Trained AdaIN Model",
|
| 3264 |
data=f.read(),
|
| 3265 |
+
file_name=f"{model_name}_final.pth",
|
| 3266 |
mime="application/octet-stream",
|
| 3267 |
use_container_width=True
|
| 3268 |
)
|