# Source: StyleTransfer / app.py
# alexscottcodes — commit 7f57474: "Add app files."
import gradio as gr
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.models as models
from PIL import Image
import numpy as np
import os
import requests
import base64
import io
def resolve_num_threads(default=4):
    """Return the number of CPU threads torch should use.

    Reads the optional ``TORCH_NUM_THREADS`` environment variable. Falls back to
    *default* when the variable is unset, is not an integer, or is not positive,
    so a bad value never prevents the app from starting.
    """
    raw = os.environ.get("TORCH_NUM_THREADS", "").strip()
    try:
        count = int(raw)
    except ValueError:
        return default
    return count if count > 0 else default


# CPU-only execution: the model and all tensors are placed on the CPU device.
# The thread count defaults to 4; set TORCH_NUM_THREADS to match your CPU.
torch.set_num_threads(resolve_num_threads())
device = torch.device("cpu")
# QuickCloud API endpoint from the environment. An empty string disables
# QuickCloud. Surrounding whitespace (e.g. a pasted trailing newline) is removed
# so it cannot yield a "configured" but invalid URL.
QUICKCLOUD_API_URL = os.environ.get("QUICKCLOUD_API_URL", "").strip()
class LightingStyleTransfer:
    """Gram-matrix (Gatys-style) neural style transfer on CPU with a frozen VGG16.

    The target image is optimized in ImageNet-normalized space, which is the
    space the VGG backbone expects. ``deprocess`` inverts that normalization
    before producing a PIL image. Clamping during optimization keeps every
    de-normalized pixel inside [0, 1].
    """

    # ImageNet channel statistics used to normalize inputs for the VGG backbone.
    _MEAN = (0.485, 0.456, 0.406)
    _STD = (0.229, 0.224, 0.225)

    def __init__(self):
        # Use VGG16 for feature extraction (lighter than VGG19).
        # torchvision >= 0.13 replaced `pretrained=True` with a weights enum;
        # IMAGENET1K_V1 is the checkpoint that `pretrained=True` refers to.
        weights_enum = getattr(models, "VGG16_Weights", None)
        if weights_enum is not None:
            backbone = models.vgg16(weights=weights_enum.IMAGENET1K_V1)
        else:
            backbone = models.vgg16(pretrained=True)
        vgg = backbone.features.to(device).eval()
        # Freeze the parameters: only the target image is optimized.
        for param in vgg.parameters():
            param.requires_grad = False
        self.model = vgg
        # Indices into `vgg.features` whose outputs feed the style (Gram) and
        # content losses. Reduced set of layers for CPU. Presumably conv1_1,
        # conv2_1, conv3_1 and conv4_1 in torchvision's VGG16 layout; verify
        # if the backbone changes.
        self.style_layers = [0, 5, 10, 17]
        self.content_layers = [17]

    def _channel_stats(self, like):
        """Return ImageNet (mean, std) as (C, 1, 1) tensors matching *like*'s dtype and device."""
        mean = torch.tensor(self._MEAN, dtype=like.dtype, device=like.device).view(-1, 1, 1)
        std = torch.tensor(self._STD, dtype=like.dtype, device=like.device).view(-1, 1, 1)
        return mean, std

    def _denormalize(self, tensor):
        """Invert the mean/std normalization applied by ``preprocess``.

        Accepts a (3, H, W) or (N, 3, H, W) tensor and returns a new tensor of
        the same shape in pixel space. Values are not clamped.
        """
        mean, std = self._channel_stats(tensor)
        return tensor * std + mean

    def _clamp_to_pixel_range(self, tensor):
        """Clamp a normalized tensor in place so that its pixels stay within [0, 1].

        The valid range in normalized space is per channel:
        [(0 - mean) / std, (1 - mean) / std], not [0, 1]. Returns *tensor*.
        """
        mean, std = self._channel_stats(tensor)
        low = (0 - mean) / std
        high = (1 - mean) / std
        with torch.no_grad():
            tensor.copy_(torch.max(torch.min(tensor, high), low))
        return tensor

    def preprocess(self, img, max_size=512):
        """Resize *img* so its longer side is *max_size* and normalize it for VGG.

        Args:
            img: PIL image in any mode; it is converted to RGB first.
            max_size: target length in pixels of the longer side.

        Returns:
            A (1, 3, H, W) float tensor on ``device``, normalized with the
            ImageNet mean/std.
        """
        # VGG needs exactly three channels. RGBA, grayscale or palette images
        # would otherwise produce 4- or 1-channel tensors and fail in Normalize.
        img = img.convert("RGB")
        # CPU optimization: work at a small, fixed resolution.
        w, h = img.size
        scale = max_size / max(w, h)
        # max(1, ...) avoids a zero-sized edge for extreme aspect ratios.
        new_size = (max(1, int(w * scale)), max(1, int(h * scale)))
        img = img.resize(new_size, Image.LANCZOS)
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=list(self._MEAN), std=list(self._STD)),
        ])
        return transform(img).unsqueeze(0).to(device)

    def deprocess(self, tensor):
        """Convert a normalized (1, 3, H, W) tensor from ``preprocess`` into a PIL image."""
        img = tensor.detach().cpu().squeeze(0)
        # Undo the normalization before clamping. The optimization runs in
        # normalized space, so the raw values are not pixel intensities.
        img = self._denormalize(img).clamp(0, 1)
        return transforms.ToPILImage()(img)

    def gram_matrix(self, tensor):
        """Compute the Gram matrix of a (B, C, H, W) activation.

        Returns a (B*C, B*C) matrix of channel inner products, divided by
        B*C*H*W so that its scale does not depend on the layer size.
        """
        b, c, h, w = tensor.size()
        features = tensor.view(b * c, h * w)
        G = torch.mm(features, features.t())
        return G.div(b * c * h * w)

    def get_features(self, image):
        """Run *image* through the backbone and collect activations at the configured layers.

        Returns a dict keyed ``style_<idx>`` / ``content_<idx>``. Layers after the
        deepest configured index are never evaluated because their outputs are
        unused, which saves a large part of the forward pass on CPU.
        """
        wanted = set(self.style_layers) | set(self.content_layers)
        last_idx = max(wanted) if wanted else -1
        features = {}
        x = image
        for idx, layer in enumerate(self.model):
            if idx > last_idx:
                break
            x = layer(x)
            if idx in self.style_layers:
                features[f'style_{idx}'] = x
            if idx in self.content_layers:
                features[f'content_{idx}'] = x
        return features

    def transfer(self, content_img, style_img, steps=150, style_weight=1e6,
                 content_weight=1):
        """Perform lighting style transfer from *style_img* onto *content_img*.

        Args:
            content_img: PIL image whose structure is preserved.
            style_img: PIL image whose style (Gram statistics) is copied.
            steps: approximate budget of loss/gradient evaluations (closure calls).
            style_weight: multiplier for the style (Gram) loss.
            content_weight: multiplier for the content loss.

        Returns:
            A PIL image at the resized resolution produced by ``preprocess``.
        """
        # Gradio sliders may deliver floats; range() below needs an int.
        steps = int(steps)
        content = self.preprocess(content_img)
        style = self.preprocess(style_img)
        # Initialize the target as the content image; it is the only tensor optimized.
        target = content.clone().requires_grad_(True)
        # Reference features are constants, so no autograd graph is needed.
        with torch.no_grad():
            content_features = self.get_features(content)
            style_features = self.get_features(style)
            style_grams = {k: self.gram_matrix(v) for k, v in style_features.items()
                           if 'style' in k}
        # CPU optimization: use the LBFGS optimizer (faster convergence).
        optimizer = torch.optim.LBFGS([target], max_iter=20)
        step = [0]  # list so that the closure can mutate the counter

        def closure():
            # Keep the de-normalized pixels within [0, 1] (per-channel bounds in normalized space).
            self._clamp_to_pixel_range(target.data)
            optimizer.zero_grad()
            target_features = self.get_features(target)
            # Content loss: MSE between activations.
            content_loss = 0
            for k in content_features:
                if 'content' in k:
                    content_loss += torch.mean((target_features[k] -
                                                content_features[k]) ** 2)
            # Style loss: MSE between Gram matrices.
            style_loss = 0
            for k, style_gram in style_grams.items():
                target_gram = self.gram_matrix(target_features[k])
                style_loss += torch.mean((target_gram - style_gram) ** 2)
            total_loss = content_weight * content_loss + style_weight * style_loss
            total_loss.backward()
            step[0] += 1
            if step[0] % 30 == 0:
                print(f"Step {step[0]}, Loss: {total_loss.item():.2f}")
            return total_loss

        # Each LBFGS.step runs up to max_iter=20 inner iterations. Round up so
        # that budgets which are not a multiple of 20 (e.g. steps < 20) still
        # optimize.
        epochs = (steps + 19) // 20
        for _ in range(epochs):
            optimizer.step(closure)
            if step[0] >= steps:
                break
        # Final clamp and return.
        self._clamp_to_pixel_range(target.data)
        return self.deprocess(target)
def _encode_png_b64(img):
    """Serialize a PIL image as PNG and return it as a base64-encoded str."""
    buffer = io.BytesIO()
    img.save(buffer, format='PNG')
    return base64.b64encode(buffer.getvalue()).decode()


def process_with_quickcloud(content_img, style_img, steps, style_strength):
    """Process using the QuickCloud API (powered by Modal.com).

    Sends both images (base64-encoded PNG) and the settings as JSON to
    ``QUICKCLOUD_API_URL``. The response must be JSON with a base64
    ``result_image`` field.

    Returns:
        ``(PIL image, success message)`` on success, otherwise
        ``(None, error message)``. No exception escapes; errors are reported
        through the status message shown in the UI.
    """
    if not QUICKCLOUD_API_URL:
        return None, "❌ QuickCloud API URL not configured. Please set QUICKCLOUD_API_URL environment variable."
    try:
        payload = {
            "content_image": _encode_png_b64(content_img),
            "style_image": _encode_png_b64(style_img),
            "steps": steps,
            "style_weight": style_strength * 1e6,
            "content_weight": 1.0,
            "learning_rate": 0.03,
        }
        print("Sending request to NamelessAI QuickCloud (H100 GPU)...")
        response = requests.post(QUICKCLOUD_API_URL, json=payload, timeout=300)
        response.raise_for_status()
        result_data = response.json()
        if not isinstance(result_data, dict) or "result_image" not in result_data:
            return None, "❌ API Error: response did not contain 'result_image'."
        result_bytes = base64.b64decode(result_data["result_image"])
        result_img = Image.open(io.BytesIO(result_bytes))
        # Image.open is lazy. Decode now so that corrupt or truncated data is
        # reported here rather than failing later inside Gradio.
        result_img.load()
        return result_img, "✅ Processing complete via QuickCloud (H100 GPU)!"
    except requests.exceptions.Timeout:
        return None, "❌ Request timed out. Please try again."
    except requests.exceptions.RequestException as e:
        return None, f"❌ API Error: {str(e)}"
    except Exception as e:
        return None, f"❌ Error: {str(e)}"
def process_locally(content_img, style_img, steps, style_strength):
    """Run the transfer on the module-level CPU model.

    ``style_strength`` is a user-facing multiplier; it is scaled by 1e6 to give
    the style-loss weight. Returns ``(image, success message)``, or
    ``(None, error message)`` if the transfer raises.
    """
    try:
        styled = style_transfer.transfer(
            content_img,
            style_img,
            steps=steps,
            style_weight=style_strength * 1e6,
            content_weight=1,
        )
    except Exception as exc:
        return None, f"❌ Error: {exc!s}"
    return styled, "✅ Processing complete via Local CPU!"
def process_images(content_img, style_img, steps, style_strength, use_quickcloud):
    """Validate both uploads, then route the job to QuickCloud or the local CPU.

    Returns the ``(result image or None, status message)`` pair that feeds the
    Gradio outputs.
    """
    if any(img is None for img in (content_img, style_img)):
        return None, "⚠️ Please upload both content and style images."

    def to_pil(img):
        # Numpy arrays are wrapped as PIL images; anything else passes through unchanged.
        return Image.fromarray(img) if isinstance(img, np.ndarray) else img

    content_pil, style_pil = to_pil(content_img), to_pil(style_img)
    backend = process_with_quickcloud if use_quickcloud else process_locally
    return backend(content_pil, style_pil, steps, style_strength)
# Build the local CPU model once, when the module loads, so every request reuses it.
print("Loading local model... This may take a moment.")
style_transfer = LightingStyleTransfer()
print("Local model loaded successfully!")

# QuickCloud is offered only when an API URL was configured.
quickcloud_available = bool(QUICKCLOUD_API_URL)
print(
    "✓ QuickCloud API configured and available"
    if quickcloud_available
    else "✗ QuickCloud API not configured (set QUICKCLOUD_API_URL environment variable)"
)
# Create Gradio interface
# Page layout in order: intro markdown; a row with the content/style image
# inputs (left column) and the result image plus status text (right column);
# the QuickCloud toggle; the steps and style-strength sliders; the Transfer
# button; the tips markdown. Components appear in creation order, so the
# statement order inside these `with` blocks defines the layout.
with gr.Blocks(title="AI Lighting Style Transfer") as demo:
    gr.Markdown("""
# 🎨 AI-Powered Lighting Style Transfer
Transfer the lighting and color style from one image to another using neural style transfer.
## Processing Options:
- **Local (CPU)**: Runs on your machine. Takes 1-3 minutes. Free.
- **NamelessAI QuickCloud**: Runs on H100 GPU cloud. Takes 5-10 seconds. Requires API key.
- *Powered by Modal.com*
## How to use:
1. Upload your **content image** (the image you want to transform)
2. Upload your **style image** (the image whose lighting you want to copy)
3. Choose processing mode (Local or QuickCloud)
4. Adjust settings if desired
5. Click "Transfer Style" and wait for processing
""")
    with gr.Row():
        with gr.Column():
            # type="pil" requests PIL images for the handler.
            content_input = gr.Image(label="Content Image", type="pil")
            style_input = gr.Image(label="Style Image", type="pil")
        with gr.Column():
            output = gr.Image(label="Result")
            status_text = gr.Textbox(label="Status", interactive=False)
    with gr.Row():
        # Non-interactive when QUICKCLOUD_API_URL is unset (quickcloud_available is False);
        # the info text changes to say why.
        use_quickcloud = gr.Checkbox(
            label="Use NamelessAI QuickCloud (H100 GPU - Powered by Modal.com)",
            value=False,
            interactive=quickcloud_available,
            info="5-10 seconds vs 1-3 minutes locally" if quickcloud_available else "API URL not configured"
        )
    with gr.Row():
        # Becomes `steps` in process_images.
        steps_slider = gr.Slider(
            minimum=50,
            maximum=300,
            value=150,
            step=10,
            label="Optimization Steps (more = better quality, slower)"
        )
        # Becomes `style_strength` in process_images (scaled by 1e6 into a style-loss weight).
        style_strength = gr.Slider(
            minimum=0.5,
            maximum=3.0,
            value=1.0,
            step=0.1,
            label="Style Strength"
        )
    transfer_btn = gr.Button("Transfer Style", variant="primary", size="lg")
    gr.Markdown("""
### Tips:
- **Local Mode**: Images resized to 512px, use 100-150 steps for balance
- **QuickCloud Mode**: Handles 1024px images, 300 steps recommended for best quality
- Increase style strength for more dramatic lighting effects
- Works best with images that have distinct lighting patterns
### QuickCloud Setup:
To use QuickCloud, set the `QUICKCLOUD_API_URL` environment variable to your Modal API endpoint.
""")
    # Set up the button click.
    # `inputs` are passed positionally, in this order, to
    # process_images(content_img, style_img, steps, style_strength, use_quickcloud).
    # Its (image or None, message) return value fills `outputs` in order.
    transfer_btn.click(
        fn=process_images,
        inputs=[content_input, style_input, steps_slider, style_strength, use_quickcloud],
        outputs=[output, status_text]
    )
# Launch the app only when run as a script, not when the module is imported.
if __name__ == "__main__":
    demo.launch()