Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -46,7 +46,7 @@ warnings.filterwarnings("ignore")
|
|
| 46 |
# Set page config
|
| 47 |
st.set_page_config(
|
| 48 |
page_title="Style Transfer Studio",
|
| 49 |
-
# page_icon="
|
| 50 |
layout="wide",
|
| 51 |
initial_sidebar_state="expanded"
|
| 52 |
)
|
|
@@ -1344,6 +1344,56 @@ class StyleTransferSystem:
|
|
| 1344 |
return model
|
| 1345 |
except:
|
| 1346 |
return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1347 |
|
| 1348 |
def apply_lightweight_style(self, image, model, intensity=1.0):
|
| 1349 |
"""Apply style using a lightweight model"""
|
|
@@ -1654,8 +1704,18 @@ class StyleTransferSystem:
|
|
| 1654 |
|
| 1655 |
return model
|
| 1656 |
|
| 1657 |
-
def apply_adain_style(self, content_image, style_image, model, alpha=1.0):
|
| 1658 |
-
"""Apply AdaIN-based style transfer with
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1659 |
if content_image is None or style_image is None or model is None:
|
| 1660 |
return None
|
| 1661 |
|
|
@@ -1665,32 +1725,9 @@ class StyleTransferSystem:
|
|
| 1665 |
|
| 1666 |
original_size = content_image.size
|
| 1667 |
|
| 1668 |
-
#
|
| 1669 |
-
if self.device.type == 'cuda':
|
| 1670 |
-
# Try to use higher resolution on GPU
|
| 1671 |
-
max_dimension = 768 # Increased from 256
|
| 1672 |
-
|
| 1673 |
-
# Calculate memory-efficient size
|
| 1674 |
-
aspect_ratio = content_image.width / content_image.height
|
| 1675 |
-
if content_image.width > content_image.height:
|
| 1676 |
-
process_width = min(max_dimension, content_image.width)
|
| 1677 |
-
process_height = int(process_width / aspect_ratio)
|
| 1678 |
-
else:
|
| 1679 |
-
process_height = min(max_dimension, content_image.height)
|
| 1680 |
-
process_width = int(process_height * aspect_ratio)
|
| 1681 |
-
|
| 1682 |
-
# Round to multiples of 32 for better GPU efficiency
|
| 1683 |
-
process_width = ((process_width + 31) // 32) * 32
|
| 1684 |
-
process_height = ((process_height + 31) // 32) * 32
|
| 1685 |
-
else:
|
| 1686 |
-
# Lower resolution for CPU
|
| 1687 |
-
process_width, process_height = 512, 512
|
| 1688 |
-
|
| 1689 |
-
print(f"Processing at {process_width}x{process_height} (original: {original_size})")
|
| 1690 |
-
|
| 1691 |
-
# Transform without aggressive cropping
|
| 1692 |
transform = transforms.Compose([
|
| 1693 |
-
transforms.Resize((
|
| 1694 |
transforms.ToTensor(),
|
| 1695 |
transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
| 1696 |
std=[0.229, 0.224, 0.225])
|
|
@@ -1700,71 +1737,117 @@ class StyleTransferSystem:
|
|
| 1700 |
style_tensor = transform(style_image).unsqueeze(0).to(self.device)
|
| 1701 |
|
| 1702 |
with torch.no_grad():
|
| 1703 |
-
# Process in chunks if image is very large
|
| 1704 |
-
if process_width * process_height > 512 * 512 and self.device.type == 'cuda':
|
| 1705 |
-
# Clear cache before processing large image
|
| 1706 |
-
torch.cuda.empty_cache()
|
| 1707 |
-
|
| 1708 |
output = model(content_tensor, style_tensor, alpha=alpha)
|
| 1709 |
|
| 1710 |
-
#
|
| 1711 |
output = output.squeeze(0).cpu()
|
| 1712 |
-
|
| 1713 |
-
|
| 1714 |
-
denorm_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
|
| 1715 |
-
denorm_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
|
| 1716 |
-
output = output * denorm_std + denorm_mean
|
| 1717 |
-
|
| 1718 |
-
# Ensure values are in valid range
|
| 1719 |
output = torch.clamp(output, 0, 1)
|
| 1720 |
|
| 1721 |
-
# Convert to PIL
|
| 1722 |
output_img = transforms.ToPILImage()(output)
|
| 1723 |
-
|
| 1724 |
-
# Use LANCZOS for high-quality resizing
|
| 1725 |
-
if output_img.size != original_size:
|
| 1726 |
-
output_img = output_img.resize(original_size, Image.LANCZOS)
|
| 1727 |
-
|
| 1728 |
-
# Optional: Sharpen slightly to compensate for softness
|
| 1729 |
-
if hasattr(Image, 'UnsharpMask'):
|
| 1730 |
-
from PIL import ImageFilter
|
| 1731 |
-
output_img = output_img.filter(ImageFilter.UnsharpMask(radius=1, percent=50, threshold=0))
|
| 1732 |
|
| 1733 |
return output_img
|
| 1734 |
|
| 1735 |
-
except
|
| 1736 |
-
|
| 1737 |
-
|
| 1738 |
-
|
| 1739 |
-
|
| 1740 |
-
|
| 1741 |
-
|
| 1742 |
-
|
| 1743 |
-
|
| 1744 |
-
|
| 1745 |
-
|
| 1746 |
-
|
| 1747 |
-
|
| 1748 |
-
|
| 1749 |
-
|
| 1750 |
-
|
| 1751 |
-
|
| 1752 |
-
|
| 1753 |
-
|
| 1754 |
-
|
| 1755 |
-
|
| 1756 |
-
|
| 1757 |
-
|
| 1758 |
-
|
| 1759 |
-
|
| 1760 |
-
|
| 1761 |
-
|
| 1762 |
-
|
| 1763 |
-
|
| 1764 |
-
|
| 1765 |
-
|
| 1766 |
-
|
| 1767 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1768 |
|
| 1769 |
|
| 1770 |
# ===========================
|
|
@@ -1801,6 +1884,64 @@ def combine_region_masks(canvas_results, canvas_size):
|
|
| 1801 |
|
| 1802 |
return combined_mask
|
| 1803 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1804 |
# ===========================
|
| 1805 |
# INITIALIZE SYSTEM AND API
|
| 1806 |
# ===========================
|
|
@@ -1866,7 +2007,7 @@ with st.sidebar:
|
|
| 1866 |
st.caption("For faster processing, use a GPU-enabled environment")
|
| 1867 |
|
| 1868 |
st.markdown("---")
|
| 1869 |
-
st.markdown("###
|
| 1870 |
st.markdown("""
|
| 1871 |
- **Style Transfer**: Apply artistic styles to images
|
| 1872 |
- **Regional Transform**: Paint areas for local effects
|
|
@@ -2136,7 +2277,7 @@ with tab2:
|
|
| 2136 |
# Region management
|
| 2137 |
col_btn1, col_btn2, col_btn3 = st.columns(3)
|
| 2138 |
with col_btn1:
|
| 2139 |
-
if st.button("
|
| 2140 |
new_region = {
|
| 2141 |
'id': len(st.session_state.regions),
|
| 2142 |
'style': style_choices[0] if style_choices else None,
|
|
@@ -2184,7 +2325,7 @@ with tab2:
|
|
| 2184 |
key=f"region_intensity_{i}"
|
| 2185 |
)
|
| 2186 |
|
| 2187 |
-
if st.button(f"
|
| 2188 |
# Remove the region
|
| 2189 |
st.session_state.regions.pop(i)
|
| 2190 |
|
|
@@ -2216,9 +2357,9 @@ with tab2:
|
|
| 2216 |
# Show workflow status
|
| 2217 |
if 'regional_result' in st.session_state:
|
| 2218 |
if st.session_state.canvas_ready:
|
| 2219 |
-
st.success("
|
| 2220 |
else:
|
| 2221 |
-
st.info("
|
| 2222 |
else:
|
| 2223 |
st.info("Paint on the canvas below to define regions for each style")
|
| 2224 |
|
|
@@ -2246,7 +2387,6 @@ with tab2:
|
|
| 2246 |
|
| 2247 |
current_region = st.session_state.regions[current_region_idx]
|
| 2248 |
|
| 2249 |
-
# THIS IS THE FIX: The following line was added.
|
| 2250 |
col_draw1, col_draw2, col_draw3 = st.columns(3)
|
| 2251 |
|
| 2252 |
with col_draw1:
|
|
@@ -2347,10 +2487,23 @@ with tab2:
|
|
| 2347 |
|
| 2348 |
st.session_state['regional_result'] = result
|
| 2349 |
|
| 2350 |
-
# Show result
|
| 2351 |
if 'regional_result' in st.session_state:
|
| 2352 |
st.subheader("Result")
|
| 2353 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2354 |
|
| 2355 |
# Download button
|
| 2356 |
buf = io.BytesIO()
|
|
@@ -2361,10 +2514,7 @@ with tab2:
|
|
| 2361 |
file_name=f"regional_styled_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.png",
|
| 2362 |
mime="image/png"
|
| 2363 |
)
|
| 2364 |
-
|
| 2365 |
-
# TAB 3: Video Processing
|
| 2366 |
-
# TAB 3: Video Processing
|
| 2367 |
-
# TAB 3: Video Processing
|
| 2368 |
# TAB 3: Video Processing
|
| 2369 |
with tab3:
|
| 2370 |
st.header("Video Processing")
|
|
@@ -2555,7 +2705,7 @@ with tab3:
|
|
| 2555 |
|
| 2556 |
|
| 2557 |
|
| 2558 |
-
# TAB 4: Training with AdaIN
|
| 2559 |
with tab4:
|
| 2560 |
st.header("Train Custom Style with AdaIN")
|
| 2561 |
st.markdown("Train your own style transfer model using Adaptive Instance Normalization")
|
|
@@ -2563,6 +2713,10 @@ with tab4:
|
|
| 2563 |
# Initialize session state for content images
|
| 2564 |
if 'content_images_list' not in st.session_state:
|
| 2565 |
st.session_state.content_images_list = []
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2566 |
|
| 2567 |
col1, col2, col3 = st.columns([1, 1, 1])
|
| 2568 |
|
|
@@ -2603,7 +2757,7 @@ with tab4:
|
|
| 2603 |
st.caption(f"... and {len(content_imgs) - 3} more")
|
| 2604 |
|
| 2605 |
with col3:
|
| 2606 |
-
st.subheader("
|
| 2607 |
|
| 2608 |
model_name = st.text_input("Model Name",
|
| 2609 |
value=f"adain_style_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}")
|
|
@@ -2664,18 +2818,24 @@ with tab4:
|
|
| 2664 |
st.session_state['trained_adain_model'] = model
|
| 2665 |
st.session_state['trained_style_images'] = style_images
|
| 2666 |
st.session_state['model_path'] = f'/tmp/trained_models/{model_name}_final.pth'
|
| 2667 |
-
st.success("AdaIN training complete
|
| 2668 |
|
| 2669 |
progress_bar.empty()
|
| 2670 |
status_text.empty()
|
| 2671 |
else:
|
| 2672 |
st.error("Please upload both style and content images")
|
| 2673 |
|
| 2674 |
-
# Testing section
|
| 2675 |
if 'trained_adain_model' in st.session_state:
|
| 2676 |
st.markdown("---")
|
| 2677 |
st.header("Test Your AdaIN Model")
|
| 2678 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2679 |
test_col1, test_col2, test_col3 = st.columns([1, 1, 1])
|
| 2680 |
|
| 2681 |
with test_col1:
|
|
@@ -2683,7 +2843,7 @@ with tab4:
|
|
| 2683 |
|
| 2684 |
# Test image selection
|
| 2685 |
test_source = st.radio("Test Image Source",
|
| 2686 |
-
["Use Content Image", "Upload New"],
|
| 2687 |
horizontal=True)
|
| 2688 |
|
| 2689 |
test_image = None
|
|
@@ -2693,6 +2853,13 @@ with tab4:
|
|
| 2693 |
range(len(st.session_state.content_images_list)),
|
| 2694 |
format_func=lambda x: f"Content Image {x+1}")
|
| 2695 |
test_image = Image.open(st.session_state.content_images_list[content_idx]).convert('RGB')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2696 |
else:
|
| 2697 |
# Upload new image
|
| 2698 |
test_upload = st.file_uploader("Upload test image",
|
|
@@ -2701,6 +2868,10 @@ with tab4:
|
|
| 2701 |
if test_upload:
|
| 2702 |
test_image = Image.open(test_upload).convert('RGB')
|
| 2703 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2704 |
# Style selection for testing
|
| 2705 |
if 'trained_style_images' in st.session_state and len(st.session_state['trained_style_images']) > 1:
|
| 2706 |
style_idx = st.selectbox("Select style",
|
|
@@ -2716,36 +2887,125 @@ with tab4:
|
|
| 2716 |
# Alpha blending control
|
| 2717 |
alpha = st.slider("Style Strength (Alpha)", 0.0, 1.0, 1.0, 0.1,
|
| 2718 |
help="0 = original content, 1 = full style transfer")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2719 |
|
| 2720 |
with test_col2:
|
| 2721 |
-
st.subheader("Original")
|
| 2722 |
-
|
| 2723 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2724 |
if test_style:
|
|
|
|
| 2725 |
st.image(test_style, caption="Style Image", use_column_width=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2726 |
|
| 2727 |
with test_col3:
|
| 2728 |
st.subheader("Result")
|
| 2729 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2730 |
with st.spinner("Applying style..."):
|
| 2731 |
-
|
| 2732 |
-
|
| 2733 |
-
|
| 2734 |
-
|
| 2735 |
-
|
| 2736 |
-
|
| 2737 |
-
|
| 2738 |
-
|
| 2739 |
-
|
| 2740 |
-
|
| 2741 |
-
|
| 2742 |
-
result
|
| 2743 |
-
|
| 2744 |
-
|
| 2745 |
-
|
| 2746 |
-
|
| 2747 |
-
|
|
|
|
|
|
|
| 2748 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2749 |
|
| 2750 |
# Model download section
|
| 2751 |
st.markdown("---")
|
|
@@ -2756,13 +3016,73 @@ with tab4:
|
|
| 2756 |
st.download_button(
|
| 2757 |
label="Download Trained AdaIN Model",
|
| 2758 |
data=f.read(),
|
| 2759 |
-
file_name=f"{model_name}
|
| 2760 |
mime="application/octet-stream",
|
| 2761 |
use_container_width=True
|
| 2762 |
)
|
| 2763 |
with col_dl2:
|
| 2764 |
st.info("This model can be loaded and used for real-time style transfer")
|
| 2765 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2766 |
# TAB 5: Batch Processing
|
| 2767 |
with tab5:
|
| 2768 |
st.header("Batch Processing")
|
|
@@ -2915,7 +3235,7 @@ with tab6:
|
|
| 2915 |
- Supports all style combinations and blend modes
|
| 2916 |
- Enhanced codec compatibility
|
| 2917 |
|
| 2918 |
-
####
|
| 2919 |
- Train on any artistic style with minimal data (1-50 images)
|
| 2920 |
- Automatic data augmentation for small datasets
|
| 2921 |
- Adjustable model complexity (3-12 residual blocks)
|
|
@@ -2976,4 +3296,4 @@ with tab6:
|
|
| 2976 |
|
| 2977 |
# Footer
|
| 2978 |
st.markdown("---")
|
| 2979 |
-
st.markdown("Style transfer system with
|
|
|
|
| 46 |
# Set page config
|
| 47 |
st.set_page_config(
|
| 48 |
page_title="Style Transfer Studio",
|
| 49 |
+
# page_icon="",
|
| 50 |
layout="wide",
|
| 51 |
initial_sidebar_state="expanded"
|
| 52 |
)
|
|
|
|
| 1344 |
return model
|
| 1345 |
except:
|
| 1346 |
return None
|
| 1347 |
+
# Inside the StyleTransferSystem class, add these methods:
|
| 1348 |
+
|
| 1349 |
+
def _create_linear_weight(self, width, height, overlap):
|
| 1350 |
+
"""Create linear blending weights for tile edges"""
|
| 1351 |
+
weight = np.ones((height, width, 1), dtype=np.float32)
|
| 1352 |
+
|
| 1353 |
+
if overlap > 0:
|
| 1354 |
+
# Create gradients for each edge
|
| 1355 |
+
for i in range(overlap):
|
| 1356 |
+
alpha = i / overlap
|
| 1357 |
+
# Top edge
|
| 1358 |
+
weight[i, :] *= alpha
|
| 1359 |
+
# Bottom edge
|
| 1360 |
+
weight[-i-1, :] *= alpha
|
| 1361 |
+
# Left edge
|
| 1362 |
+
weight[:, i] *= alpha
|
| 1363 |
+
# Right edge
|
| 1364 |
+
weight[:, -i-1] *= alpha
|
| 1365 |
+
|
| 1366 |
+
return weight
|
| 1367 |
+
|
| 1368 |
+
def _create_gaussian_weight(self, width, height, overlap):
|
| 1369 |
+
"""Create Gaussian blending weights for smoother transitions"""
|
| 1370 |
+
weight = np.ones((height, width), dtype=np.float32)
|
| 1371 |
+
|
| 1372 |
+
# Create 2D Gaussian centered in the tile
|
| 1373 |
+
y, x = np.ogrid[:height, :width]
|
| 1374 |
+
center_y, center_x = height / 2, width / 2
|
| 1375 |
+
|
| 1376 |
+
# Distance from center
|
| 1377 |
+
dist_from_center = np.sqrt((x - center_x)**2 + (y - center_y)**2)
|
| 1378 |
+
|
| 1379 |
+
# Gaussian falloff starting from the edges
|
| 1380 |
+
max_dist = min(height, width) / 2
|
| 1381 |
+
sigma = max_dist / 2 # Adjust for smoother/sharper transitions
|
| 1382 |
+
|
| 1383 |
+
# Apply Gaussian only near edges
|
| 1384 |
+
edge_dist = np.minimum(
|
| 1385 |
+
np.minimum(y, height - 1 - y),
|
| 1386 |
+
np.minimum(x, width - 1 - x)
|
| 1387 |
+
)
|
| 1388 |
+
|
| 1389 |
+
# Weight is 1 in center, Gaussian falloff near edges
|
| 1390 |
+
weight = np.where(
|
| 1391 |
+
edge_dist < overlap,
|
| 1392 |
+
np.exp(-0.5 * ((overlap - edge_dist) / (overlap/3))**2),
|
| 1393 |
+
1.0
|
| 1394 |
+
)
|
| 1395 |
+
|
| 1396 |
+
return weight.reshape(height, width, 1)
|
| 1397 |
|
| 1398 |
def apply_lightweight_style(self, image, model, intensity=1.0):
|
| 1399 |
"""Apply style using a lightweight model"""
|
|
|
|
| 1704 |
|
| 1705 |
return model
|
| 1706 |
|
| 1707 |
+
def apply_adain_style(self, content_image, style_image, model, alpha=1.0, use_tiling=False):
|
| 1708 |
+
"""Apply AdaIN-based style transfer with optional tiling"""
|
| 1709 |
+
if use_tiling and (content_image.width > 512 or content_image.height > 512):
|
| 1710 |
+
# Use tiling for large images
|
| 1711 |
+
return self.apply_adain_style_tiled(
|
| 1712 |
+
content_image, style_image, model, alpha,
|
| 1713 |
+
tile_size=256, # Match training size
|
| 1714 |
+
overlap=32,
|
| 1715 |
+
blend_mode='gaussian'
|
| 1716 |
+
)
|
| 1717 |
+
|
| 1718 |
+
# Original implementation for small images
|
| 1719 |
if content_image is None or style_image is None or model is None:
|
| 1720 |
return None
|
| 1721 |
|
|
|
|
| 1725 |
|
| 1726 |
original_size = content_image.size
|
| 1727 |
|
| 1728 |
+
# Transform for AdaIN (VGG normalization)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1729 |
transform = transforms.Compose([
|
| 1730 |
+
transforms.Resize((256, 256)), # Direct resize, no cropping
|
| 1731 |
transforms.ToTensor(),
|
| 1732 |
transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
| 1733 |
std=[0.229, 0.224, 0.225])
|
|
|
|
| 1737 |
style_tensor = transform(style_image).unsqueeze(0).to(self.device)
|
| 1738 |
|
| 1739 |
with torch.no_grad():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1740 |
output = model(content_tensor, style_tensor, alpha=alpha)
|
| 1741 |
|
| 1742 |
+
# Denormalize
|
| 1743 |
output = output.squeeze(0).cpu()
|
| 1744 |
+
output = output * torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
|
| 1745 |
+
output = output + torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1746 |
output = torch.clamp(output, 0, 1)
|
| 1747 |
|
| 1748 |
+
# Convert to PIL
|
| 1749 |
output_img = transforms.ToPILImage()(output)
|
| 1750 |
+
output_img = output_img.resize(original_size, Image.LANCZOS)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1751 |
|
| 1752 |
return output_img
|
| 1753 |
|
| 1754 |
+
except Exception as e:
|
| 1755 |
+
print(f"Error applying AdaIN style: {e}")
|
| 1756 |
+
traceback.print_exc()
|
| 1757 |
+
return None
|
| 1758 |
+
|
| 1759 |
+
def apply_adain_style_tiled(self, content_image, style_image, model, alpha=1.0,
|
| 1760 |
+
tile_size=256, overlap=32, blend_mode='linear'):
|
| 1761 |
+
"""
|
| 1762 |
+
Apply AdaIN style transfer using tiling for high-quality results.
|
| 1763 |
+
Processes image in overlapping tiles to maintain quality.
|
| 1764 |
+
"""
|
| 1765 |
+
if content_image is None or style_image is None or model is None:
|
| 1766 |
+
return None
|
| 1767 |
+
|
| 1768 |
+
try:
|
| 1769 |
+
model = model.to(self.device)
|
| 1770 |
+
model.eval()
|
| 1771 |
+
|
| 1772 |
+
# Prepare transforms
|
| 1773 |
+
transform = transforms.Compose([
|
| 1774 |
+
transforms.Resize((tile_size, tile_size)),
|
| 1775 |
+
transforms.ToTensor(),
|
| 1776 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
| 1777 |
+
std=[0.229, 0.224, 0.225])
|
| 1778 |
+
])
|
| 1779 |
+
|
| 1780 |
+
# Process style image once (at tile size)
|
| 1781 |
+
style_tensor = transform(style_image).unsqueeze(0).to(self.device)
|
| 1782 |
+
|
| 1783 |
+
# Get dimensions
|
| 1784 |
+
w, h = content_image.size
|
| 1785 |
+
|
| 1786 |
+
# Calculate tile positions with overlap
|
| 1787 |
+
stride = tile_size - overlap
|
| 1788 |
+
tiles_x = list(range(0, w - tile_size + 1, stride))
|
| 1789 |
+
tiles_y = list(range(0, h - tile_size + 1, stride))
|
| 1790 |
+
|
| 1791 |
+
# Ensure we cover the entire image
|
| 1792 |
+
if tiles_x[-1] + tile_size < w:
|
| 1793 |
+
tiles_x.append(w - tile_size)
|
| 1794 |
+
if tiles_y[-1] + tile_size < h:
|
| 1795 |
+
tiles_y.append(h - tile_size)
|
| 1796 |
+
|
| 1797 |
+
# If image is smaller than tile size, just process normally
|
| 1798 |
+
if w <= tile_size and h <= tile_size:
|
| 1799 |
+
return self.apply_adain_style(content_image, style_image, model, alpha, use_tiling=False)
|
| 1800 |
+
|
| 1801 |
+
print(f"Processing {len(tiles_x) * len(tiles_y)} tiles of size {tile_size}x{tile_size}")
|
| 1802 |
+
|
| 1803 |
+
# Initialize output and weight arrays
|
| 1804 |
+
output_array = np.zeros((h, w, 3), dtype=np.float32)
|
| 1805 |
+
weight_array = np.zeros((h, w, 1), dtype=np.float32)
|
| 1806 |
+
|
| 1807 |
+
# Process each tile
|
| 1808 |
+
with torch.no_grad():
|
| 1809 |
+
for y_idx, y in enumerate(tiles_y):
|
| 1810 |
+
for x_idx, x in enumerate(tiles_x):
|
| 1811 |
+
# Extract tile
|
| 1812 |
+
tile = content_image.crop((x, y, x + tile_size, y + tile_size))
|
| 1813 |
+
|
| 1814 |
+
# Transform tile
|
| 1815 |
+
tile_tensor = transform(tile).unsqueeze(0).to(self.device)
|
| 1816 |
+
|
| 1817 |
+
# Apply AdaIN to tile
|
| 1818 |
+
styled_tensor = model(tile_tensor, style_tensor, alpha=alpha)
|
| 1819 |
+
|
| 1820 |
+
# Denormalize
|
| 1821 |
+
styled_tensor = styled_tensor.squeeze(0).cpu()
|
| 1822 |
+
denorm_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
|
| 1823 |
+
denorm_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
|
| 1824 |
+
styled_tensor = styled_tensor * denorm_std + denorm_mean
|
| 1825 |
+
styled_tensor = torch.clamp(styled_tensor, 0, 1)
|
| 1826 |
+
|
| 1827 |
+
# Convert to numpy
|
| 1828 |
+
styled_tile = styled_tensor.permute(1, 2, 0).numpy() * 255
|
| 1829 |
+
|
| 1830 |
+
# Create weight mask for blending
|
| 1831 |
+
if blend_mode == 'gaussian':
|
| 1832 |
+
weight = self._create_gaussian_weight(tile_size, tile_size, overlap)
|
| 1833 |
+
else:
|
| 1834 |
+
weight = self._create_linear_weight(tile_size, tile_size, overlap)
|
| 1835 |
+
|
| 1836 |
+
# Add to output with weights
|
| 1837 |
+
output_array[y:y+tile_size, x:x+tile_size] += styled_tile * weight
|
| 1838 |
+
weight_array[y:y+tile_size, x:x+tile_size] += weight
|
| 1839 |
+
|
| 1840 |
+
# Normalize by weights
|
| 1841 |
+
output_array = output_array / (weight_array + 1e-8)
|
| 1842 |
+
output_array = np.clip(output_array, 0, 255).astype(np.uint8)
|
| 1843 |
+
|
| 1844 |
+
return Image.fromarray(output_array)
|
| 1845 |
+
|
| 1846 |
+
except Exception as e:
|
| 1847 |
+
print(f"Error in tiled AdaIN processing: {e}")
|
| 1848 |
+
traceback.print_exc()
|
| 1849 |
+
# Fallback to standard processing
|
| 1850 |
+
return self.apply_adain_style(content_image, style_image, model, alpha, use_tiling=False)
|
| 1851 |
|
| 1852 |
|
| 1853 |
# ===========================
|
|
|
|
| 1884 |
|
| 1885 |
return combined_mask
|
| 1886 |
|
| 1887 |
+
def apply_adain_regional(content_image, style_image, model, canvas_result, alpha=1.0, feather_radius=10, use_tiling=False):
|
| 1888 |
+
"""Apply AdaIN style transfer to a painted region only"""
|
| 1889 |
+
if content_image is None or style_image is None or model is None:
|
| 1890 |
+
return None
|
| 1891 |
+
|
| 1892 |
+
try:
|
| 1893 |
+
# Get the mask from canvas
|
| 1894 |
+
if canvas_result is None or canvas_result.image_data is None:
|
| 1895 |
+
# No mask painted, apply to whole image
|
| 1896 |
+
return system.apply_adain_style(content_image, style_image, model, alpha, use_tiling=use_tiling)
|
| 1897 |
+
|
| 1898 |
+
# Extract mask from canvas
|
| 1899 |
+
mask_data = canvas_result.image_data[:, :, 3] # Alpha channel
|
| 1900 |
+
mask = mask_data > 0
|
| 1901 |
+
|
| 1902 |
+
# Resize mask to match original image size
|
| 1903 |
+
original_size = content_image.size
|
| 1904 |
+
display_size = (canvas_result.image_data.shape[1], canvas_result.image_data.shape[0])
|
| 1905 |
+
|
| 1906 |
+
if original_size != display_size:
|
| 1907 |
+
# Convert mask to PIL image for resizing
|
| 1908 |
+
mask_pil = Image.fromarray((mask * 255).astype(np.uint8), mode='L')
|
| 1909 |
+
mask_pil = mask_pil.resize(original_size, Image.NEAREST)
|
| 1910 |
+
mask = np.array(mask_pil) > 128
|
| 1911 |
+
|
| 1912 |
+
# Apply feathering to mask edges if requested
|
| 1913 |
+
if feather_radius > 0:
|
| 1914 |
+
from scipy.ndimage import gaussian_filter
|
| 1915 |
+
mask_float = mask.astype(np.float32)
|
| 1916 |
+
mask_float = gaussian_filter(mask_float, sigma=feather_radius)
|
| 1917 |
+
mask_float = np.clip(mask_float, 0, 1)
|
| 1918 |
+
else:
|
| 1919 |
+
mask_float = mask.astype(np.float32)
|
| 1920 |
+
|
| 1921 |
+
# Apply style to entire image with tiling option
|
| 1922 |
+
styled_full = system.apply_adain_style(content_image, style_image, model, alpha, use_tiling=use_tiling)
|
| 1923 |
+
|
| 1924 |
+
if styled_full is None:
|
| 1925 |
+
return None
|
| 1926 |
+
|
| 1927 |
+
# Blend original and styled based on mask
|
| 1928 |
+
original_array = np.array(content_image, dtype=np.float32)
|
| 1929 |
+
styled_array = np.array(styled_full, dtype=np.float32)
|
| 1930 |
+
|
| 1931 |
+
# Expand mask to 3 channels
|
| 1932 |
+
mask_3ch = np.stack([mask_float] * 3, axis=2)
|
| 1933 |
+
|
| 1934 |
+
# Blend
|
| 1935 |
+
result_array = original_array * (1 - mask_3ch) + styled_array * mask_3ch
|
| 1936 |
+
result_array = np.clip(result_array, 0, 255).astype(np.uint8)
|
| 1937 |
+
|
| 1938 |
+
return Image.fromarray(result_array)
|
| 1939 |
+
|
| 1940 |
+
except Exception as e:
|
| 1941 |
+
print(f"Error applying regional AdaIN style: {e}")
|
| 1942 |
+
traceback.print_exc()
|
| 1943 |
+
return None
|
| 1944 |
+
|
| 1945 |
# ===========================
|
| 1946 |
# INITIALIZE SYSTEM AND API
|
| 1947 |
# ===========================
|
|
|
|
| 2007 |
st.caption("For faster processing, use a GPU-enabled environment")
|
| 2008 |
|
| 2009 |
st.markdown("---")
|
| 2010 |
+
st.markdown("### Quick Guide")
|
| 2011 |
st.markdown("""
|
| 2012 |
- **Style Transfer**: Apply artistic styles to images
|
| 2013 |
- **Regional Transform**: Paint areas for local effects
|
|
|
|
| 2277 |
# Region management
|
| 2278 |
col_btn1, col_btn2, col_btn3 = st.columns(3)
|
| 2279 |
with col_btn1:
|
| 2280 |
+
if st.button("Add Region", use_container_width=True):
|
| 2281 |
new_region = {
|
| 2282 |
'id': len(st.session_state.regions),
|
| 2283 |
'style': style_choices[0] if style_choices else None,
|
|
|
|
| 2325 |
key=f"region_intensity_{i}"
|
| 2326 |
)
|
| 2327 |
|
| 2328 |
+
if st.button(f"Remove Region {i+1}", key=f"remove_region_{i}"):
|
| 2329 |
# Remove the region
|
| 2330 |
st.session_state.regions.pop(i)
|
| 2331 |
|
|
|
|
| 2357 |
# Show workflow status
|
| 2358 |
if 'regional_result' in st.session_state:
|
| 2359 |
if st.session_state.canvas_ready:
|
| 2360 |
+
st.success("Edit Mode - Paint your regions and click 'Apply Regional Styles' when ready")
|
| 2361 |
else:
|
| 2362 |
+
st.info("Preview Mode - Click 'Continue Editing' to modify regions")
|
| 2363 |
else:
|
| 2364 |
st.info("Paint on the canvas below to define regions for each style")
|
| 2365 |
|
|
|
|
| 2387 |
|
| 2388 |
current_region = st.session_state.regions[current_region_idx]
|
| 2389 |
|
|
|
|
| 2390 |
col_draw1, col_draw2, col_draw3 = st.columns(3)
|
| 2391 |
|
| 2392 |
with col_draw1:
|
|
|
|
| 2487 |
|
| 2488 |
st.session_state['regional_result'] = result
|
| 2489 |
|
| 2490 |
+
# Show result with fixed size
|
| 2491 |
if 'regional_result' in st.session_state:
|
| 2492 |
st.subheader("Result")
|
| 2493 |
+
|
| 2494 |
+
# Add display size control
|
| 2495 |
+
display_size = st.slider("Display Size", 300, 800, 600, 50, key="regional_display_size")
|
| 2496 |
+
|
| 2497 |
+
# Fixed size display
|
| 2498 |
+
result_display = resize_image_for_display(
|
| 2499 |
+
st.session_state['regional_result'],
|
| 2500 |
+
max_width=display_size,
|
| 2501 |
+
max_height=display_size
|
| 2502 |
+
)
|
| 2503 |
+
st.image(result_display, caption="Regional Styled Image")
|
| 2504 |
+
|
| 2505 |
+
# Show actual dimensions
|
| 2506 |
+
st.caption(f"Original size: {st.session_state['regional_result'].size[0]}x{st.session_state['regional_result'].size[1]} pixels")
|
| 2507 |
|
| 2508 |
# Download button
|
| 2509 |
buf = io.BytesIO()
|
|
|
|
| 2514 |
file_name=f"regional_styled_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.png",
|
| 2515 |
mime="image/png"
|
| 2516 |
)
|
| 2517 |
+
|
|
|
|
|
|
|
|
|
|
| 2518 |
# TAB 3: Video Processing
|
| 2519 |
with tab3:
|
| 2520 |
st.header("Video Processing")
|
|
|
|
| 2705 |
|
| 2706 |
|
| 2707 |
|
| 2708 |
+
# TAB 4: Training with AdaIN and Regional Application
|
| 2709 |
with tab4:
|
| 2710 |
st.header("Train Custom Style with AdaIN")
|
| 2711 |
st.markdown("Train your own style transfer model using Adaptive Instance Normalization")
|
|
|
|
| 2713 |
# Initialize session state for content images
|
| 2714 |
if 'content_images_list' not in st.session_state:
|
| 2715 |
st.session_state.content_images_list = []
|
| 2716 |
+
if 'adain_canvas_result' not in st.session_state:
|
| 2717 |
+
st.session_state.adain_canvas_result = None
|
| 2718 |
+
if 'adain_test_image' not in st.session_state:
|
| 2719 |
+
st.session_state.adain_test_image = None
|
| 2720 |
|
| 2721 |
col1, col2, col3 = st.columns([1, 1, 1])
|
| 2722 |
|
|
|
|
| 2757 |
st.caption(f"... and {len(content_imgs) - 3} more")
|
| 2758 |
|
| 2759 |
with col3:
|
| 2760 |
+
st.subheader("Training Settings")
|
| 2761 |
|
| 2762 |
model_name = st.text_input("Model Name",
|
| 2763 |
value=f"adain_style_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}")
|
|
|
|
| 2818 |
st.session_state['trained_adain_model'] = model
|
| 2819 |
st.session_state['trained_style_images'] = style_images
|
| 2820 |
st.session_state['model_path'] = f'/tmp/trained_models/{model_name}_final.pth'
|
| 2821 |
+
st.success("AdaIN training complete")
|
| 2822 |
|
| 2823 |
progress_bar.empty()
|
| 2824 |
status_text.empty()
|
| 2825 |
else:
|
| 2826 |
st.error("Please upload both style and content images")
|
| 2827 |
|
| 2828 |
+
# Testing section with regional application
|
| 2829 |
if 'trained_adain_model' in st.session_state:
|
| 2830 |
st.markdown("---")
|
| 2831 |
st.header("Test Your AdaIN Model")
|
| 2832 |
|
| 2833 |
+
# Application mode selection
|
| 2834 |
+
application_mode = st.radio("Application Mode",
|
| 2835 |
+
["Whole Image", "Paint Region"],
|
| 2836 |
+
horizontal=True,
|
| 2837 |
+
help="Choose whether to apply style to entire image or paint specific regions")
|
| 2838 |
+
|
| 2839 |
test_col1, test_col2, test_col3 = st.columns([1, 1, 1])
|
| 2840 |
|
| 2841 |
with test_col1:
|
|
|
|
| 2843 |
|
| 2844 |
# Test image selection
|
| 2845 |
test_source = st.radio("Test Image Source",
|
| 2846 |
+
["Use Content Image", "Upload New", "Use Unsplash Image"],
|
| 2847 |
horizontal=True)
|
| 2848 |
|
| 2849 |
test_image = None
|
|
|
|
| 2853 |
range(len(st.session_state.content_images_list)),
|
| 2854 |
format_func=lambda x: f"Content Image {x+1}")
|
| 2855 |
test_image = Image.open(st.session_state.content_images_list[content_idx]).convert('RGB')
|
| 2856 |
+
elif test_source == "Use Unsplash Image":
|
| 2857 |
+
# Use current Unsplash image if available
|
| 2858 |
+
if 'current_image' in st.session_state and st.session_state.get('image_source', '').startswith('Unsplash'):
|
| 2859 |
+
test_image = st.session_state['current_image']
|
| 2860 |
+
st.success("Using Unsplash image")
|
| 2861 |
+
else:
|
| 2862 |
+
st.info("Please search and select an image from the Style Transfer tab first")
|
| 2863 |
else:
|
| 2864 |
# Upload new image
|
| 2865 |
test_upload = st.file_uploader("Upload test image",
|
|
|
|
| 2868 |
if test_upload:
|
| 2869 |
test_image = Image.open(test_upload).convert('RGB')
|
| 2870 |
|
| 2871 |
+
# Store test image in session state
|
| 2872 |
+
if test_image:
|
| 2873 |
+
st.session_state['adain_test_image'] = test_image
|
| 2874 |
+
|
| 2875 |
# Style selection for testing
|
| 2876 |
if 'trained_style_images' in st.session_state and len(st.session_state['trained_style_images']) > 1:
|
| 2877 |
style_idx = st.selectbox("Select style",
|
|
|
|
| 2887 |
# Alpha blending control
|
| 2888 |
alpha = st.slider("Style Strength (Alpha)", 0.0, 1.0, 1.0, 0.1,
|
| 2889 |
help="0 = original content, 1 = full style transfer")
|
| 2890 |
+
|
| 2891 |
+
# Add tiling option
|
| 2892 |
+
use_tiling = st.checkbox("🔲 Use Tiled Processing",
|
| 2893 |
+
value=True,
|
| 2894 |
+
help="Process large images in tiles for better quality. Recommended for images larger than 512x512.")
|
| 2895 |
+
|
| 2896 |
+
# Regional painting options (only show if in paint mode)
|
| 2897 |
+
if application_mode == "🖌️ Paint Region":
|
| 2898 |
+
st.markdown("---")
|
| 2899 |
+
st.subheader("🖌️ Painting Options")
|
| 2900 |
+
|
| 2901 |
+
brush_size = st.slider("Brush Size", 5, 100, 30)
|
| 2902 |
+
drawing_mode = st.selectbox("Drawing Tool",
|
| 2903 |
+
["freedraw", "line", "rect", "circle", "polygon"],
|
| 2904 |
+
index=0)
|
| 2905 |
+
|
| 2906 |
+
# Feather/blur the mask edges
|
| 2907 |
+
feather_radius = st.slider("Edge Softness", 0, 50, 10,
|
| 2908 |
+
help="Blur mask edges for smoother transitions")
|
| 2909 |
+
|
| 2910 |
+
col_btn1, col_btn2 = st.columns(2)
|
| 2911 |
+
with col_btn1:
|
| 2912 |
+
if st.button("Clear Canvas", use_container_width=True):
|
| 2913 |
+
st.session_state['adain_canvas_result'] = None
|
| 2914 |
+
st.rerun()
|
| 2915 |
+
|
| 2916 |
+
with col_btn2:
|
| 2917 |
+
if st.button("Reset Result", use_container_width=True):
|
| 2918 |
+
if 'adain_styled_result' in st.session_state:
|
| 2919 |
+
del st.session_state['adain_styled_result']
|
| 2920 |
+
st.rerun()
|
| 2921 |
|
| 2922 |
with test_col2:
|
| 2923 |
+
st.subheader("Canvas / Original")
|
| 2924 |
+
|
| 2925 |
+
if application_mode == "Paint Region" and test_image:
|
| 2926 |
+
# Show canvas for painting
|
| 2927 |
+
display_img = resize_image_for_display(test_image, max_width=400, max_height=400)
|
| 2928 |
+
canvas_width, canvas_height = display_img.size
|
| 2929 |
+
|
| 2930 |
+
st.info("🖌️ Paint the areas where you want to apply the style")
|
| 2931 |
+
|
| 2932 |
+
# Canvas for painting mask
|
| 2933 |
+
canvas_result = st_canvas(
|
| 2934 |
+
fill_color="rgba(255, 0, 0, 0.3)", # Red with transparency
|
| 2935 |
+
stroke_width=brush_size,
|
| 2936 |
+
stroke_color="rgba(255, 0, 0, 0.5)",
|
| 2937 |
+
background_image=display_img,
|
| 2938 |
+
update_streamlit=True,
|
| 2939 |
+
height=canvas_height,
|
| 2940 |
+
width=canvas_width,
|
| 2941 |
+
drawing_mode=drawing_mode,
|
| 2942 |
+
display_toolbar=True,
|
| 2943 |
+
key=f"adain_canvas_{brush_size}_{drawing_mode}"
|
| 2944 |
+
)
|
| 2945 |
+
|
| 2946 |
+
# Save canvas result
|
| 2947 |
+
if canvas_result:
|
| 2948 |
+
st.session_state['adain_canvas_result'] = canvas_result
|
| 2949 |
+
|
| 2950 |
+
# Show style image below canvas
|
| 2951 |
if test_style:
|
| 2952 |
+
st.markdown("---")
|
| 2953 |
st.image(test_style, caption="Style Image", use_column_width=True)
|
| 2954 |
+
|
| 2955 |
+
else:
|
| 2956 |
+
# Show original images
|
| 2957 |
+
if test_image:
|
| 2958 |
+
st.image(test_image, caption="Content Image", use_column_width=True)
|
| 2959 |
+
if test_style:
|
| 2960 |
+
st.image(test_style, caption="Style Image", use_column_width=True)
|
| 2961 |
|
| 2962 |
with test_col3:
|
| 2963 |
st.subheader("Result")
|
| 2964 |
+
|
| 2965 |
+
# Apply button
|
| 2966 |
+
apply_button = st.button("Apply Style", type="primary", use_container_width=True)
|
| 2967 |
+
|
| 2968 |
+
if apply_button and test_image and test_style:
|
| 2969 |
with st.spinner("Applying style..."):
|
| 2970 |
+
if application_mode == "🖼️ Whole Image":
|
| 2971 |
+
# Apply to whole image
|
| 2972 |
+
result = system.apply_adain_style(
|
| 2973 |
+
test_image,
|
| 2974 |
+
test_style,
|
| 2975 |
+
st.session_state['trained_adain_model'],
|
| 2976 |
+
alpha=alpha,
|
| 2977 |
+
use_tiling=use_tiling
|
| 2978 |
+
)
|
| 2979 |
+
else:
|
| 2980 |
+
# Apply to painted region
|
| 2981 |
+
result = apply_adain_regional(
|
| 2982 |
+
test_image,
|
| 2983 |
+
test_style,
|
| 2984 |
+
st.session_state['trained_adain_model'],
|
| 2985 |
+
st.session_state.get('adain_canvas_result'),
|
| 2986 |
+
alpha=alpha,
|
| 2987 |
+
feather_radius=feather_radius,
|
| 2988 |
+
use_tiling=use_tiling
|
| 2989 |
)
|
| 2990 |
+
|
| 2991 |
+
if result:
|
| 2992 |
+
st.session_state['adain_styled_result'] = result
|
| 2993 |
+
|
| 2994 |
+
# Show result if available
|
| 2995 |
+
if 'adain_styled_result' in st.session_state:
|
| 2996 |
+
st.image(st.session_state['adain_styled_result'],
|
| 2997 |
+
caption="Styled Result",
|
| 2998 |
+
use_column_width=True)
|
| 2999 |
+
|
| 3000 |
+
# Download button
|
| 3001 |
+
buf = io.BytesIO()
|
| 3002 |
+
st.session_state['adain_styled_result'].save(buf, format='PNG')
|
| 3003 |
+
st.download_button(
|
| 3004 |
+
label="Download Result",
|
| 3005 |
+
data=buf.getvalue(),
|
| 3006 |
+
file_name=f"adain_styled_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.png",
|
| 3007 |
+
mime="image/png"
|
| 3008 |
+
)
|
| 3009 |
|
| 3010 |
# Model download section
|
| 3011 |
st.markdown("---")
|
|
|
|
| 3016 |
st.download_button(
|
| 3017 |
label="Download Trained AdaIN Model",
|
| 3018 |
data=f.read(),
|
| 3019 |
+
file_name=f"{st.session_state.get('model_name', 'adain')}_final.pth",
|
| 3020 |
mime="application/octet-stream",
|
| 3021 |
use_container_width=True
|
| 3022 |
)
|
| 3023 |
with col_dl2:
|
| 3024 |
st.info("This model can be loaded and used for real-time style transfer")
|
| 3025 |
|
| 3026 |
+
|
| 3027 |
+
# Add this helper function (place it before the tab or with other helper functions)
|
| 3028 |
+
def apply_adain_regional(content_image, style_image, model, canvas_result, alpha=1.0, feather_radius=10, use_tiling=False):
|
| 3029 |
+
"""Apply AdaIN style transfer to a painted region only"""
|
| 3030 |
+
if content_image is None or style_image is None or model is None:
|
| 3031 |
+
return None
|
| 3032 |
+
|
| 3033 |
+
try:
|
| 3034 |
+
# Get the mask from canvas
|
| 3035 |
+
if canvas_result is None or canvas_result.image_data is None:
|
| 3036 |
+
# No mask painted, apply to whole image
|
| 3037 |
+
return system.apply_adain_style(content_image, style_image, model, alpha, use_tiling=use_tiling)
|
| 3038 |
+
|
| 3039 |
+
# Extract mask from canvas
|
| 3040 |
+
mask_data = canvas_result.image_data[:, :, 3] # Alpha channel
|
| 3041 |
+
mask = mask_data > 0
|
| 3042 |
+
|
| 3043 |
+
# Resize mask to match original image size
|
| 3044 |
+
original_size = content_image.size
|
| 3045 |
+
display_size = (canvas_result.image_data.shape[1], canvas_result.image_data.shape[0])
|
| 3046 |
+
|
| 3047 |
+
if original_size != display_size:
|
| 3048 |
+
# Convert mask to PIL image for resizing
|
| 3049 |
+
mask_pil = Image.fromarray((mask * 255).astype(np.uint8), mode='L')
|
| 3050 |
+
mask_pil = mask_pil.resize(original_size, Image.NEAREST)
|
| 3051 |
+
mask = np.array(mask_pil) > 128
|
| 3052 |
+
|
| 3053 |
+
# Apply feathering to mask edges if requested
|
| 3054 |
+
if feather_radius > 0:
|
| 3055 |
+
from scipy.ndimage import gaussian_filter
|
| 3056 |
+
mask_float = mask.astype(np.float32)
|
| 3057 |
+
mask_float = gaussian_filter(mask_float, sigma=feather_radius)
|
| 3058 |
+
mask_float = np.clip(mask_float, 0, 1)
|
| 3059 |
+
else:
|
| 3060 |
+
mask_float = mask.astype(np.float32)
|
| 3061 |
+
|
| 3062 |
+
# Apply style to entire image with tiling option
|
| 3063 |
+
styled_full = system.apply_adain_style(content_image, style_image, model, alpha, use_tiling=use_tiling)
|
| 3064 |
+
|
| 3065 |
+
if styled_full is None:
|
| 3066 |
+
return None
|
| 3067 |
+
|
| 3068 |
+
# Blend original and styled based on mask
|
| 3069 |
+
original_array = np.array(content_image, dtype=np.float32)
|
| 3070 |
+
styled_array = np.array(styled_full, dtype=np.float32)
|
| 3071 |
+
|
| 3072 |
+
# Expand mask to 3 channels
|
| 3073 |
+
mask_3ch = np.stack([mask_float] * 3, axis=2)
|
| 3074 |
+
|
| 3075 |
+
# Blend
|
| 3076 |
+
result_array = original_array * (1 - mask_3ch) + styled_array * mask_3ch
|
| 3077 |
+
result_array = np.clip(result_array, 0, 255).astype(np.uint8)
|
| 3078 |
+
|
| 3079 |
+
return Image.fromarray(result_array)
|
| 3080 |
+
|
| 3081 |
+
except Exception as e:
|
| 3082 |
+
print(f"Error applying regional AdaIN style: {e}")
|
| 3083 |
+
traceback.print_exc()
|
| 3084 |
+
return None
|
| 3085 |
+
|
| 3086 |
# TAB 5: Batch Processing
|
| 3087 |
with tab5:
|
| 3088 |
st.header("Batch Processing")
|
|
|
|
| 3235 |
- Supports all style combinations and blend modes
|
| 3236 |
- Enhanced codec compatibility
|
| 3237 |
|
| 3238 |
+
#### Custom Training
|
| 3239 |
- Train on any artistic style with minimal data (1-50 images)
|
| 3240 |
- Automatic data augmentation for small datasets
|
| 3241 |
- Adjustable model complexity (3-12 residual blocks)
|
|
|
|
| 3296 |
|
| 3297 |
# Footer
|
| 3298 |
st.markdown("---")
|
| 3299 |
+
st.markdown("Style transfer system with CycleGAN models and regional painting capabilities.")
|