# MonetStyle / app.py
# TamerTokgoz's picture
# Upload 6 files
# 0dd207d verified
import gradio as gr
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import numpy as np
# =========================================================
# DEVICE
# =========================================================
# Run on the first CUDA GPU when PyTorch can see one; otherwise fall back
# to the CPU. Both the model and the input tensors are moved here.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
# =========================================================
# MODEL ARCHITECTURE
# =========================================================
class UNetDown(nn.Module):
    """Encoder stage of the U-Net generator.

    A 4x4 stride-2 convolution (padding 1) that halves even spatial sizes,
    optionally followed by instance normalisation, then LeakyReLU(0.2) and
    optional dropout.

    The layer order inside ``self.model`` is significant: checkpoint keys
    such as ``model.0.weight`` refer to these positions.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the convolution.
        normalize: insert an ``InstanceNorm2d`` after the convolution.
        dropout: dropout probability; a falsy value (0.0) adds no dropout.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        normalize=True,
        dropout=0.0
    ):
        super().__init__()
        stack = [
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=4,
                stride=2,
                padding=1,
                bias=False,
            )
        ]
        # Optional pieces are spliced in as (possibly empty) lists so the
        # resulting Sequential indices match the checkpoint layout.
        stack += [nn.InstanceNorm2d(out_channels)] if normalize else []
        stack.append(nn.LeakyReLU(0.2))
        stack += [nn.Dropout(dropout)] if dropout else []
        self.model = nn.Sequential(*stack)

    def forward(self, x):
        """Return the downsampled feature map for ``x``."""
        return self.model(x)
class UNetUp(nn.Module):
    """Decoder stage of the U-Net generator.

    A 4x4 stride-2 transposed convolution (padding 1) that doubles the
    spatial size, then InstanceNorm, ReLU and optional dropout. The result
    is concatenated with the matching encoder activation along the channel
    axis, so the output has ``out_channels + skip_channels`` channels.

    Layer order inside ``self.model`` matches the checkpoint keys
    (``model.0.weight`` is the transposed convolution).

    Args:
        in_channels: channels of the incoming decoder feature map.
        out_channels: channels produced by the transposed convolution.
        dropout: dropout probability; a falsy value (0.0) adds no dropout.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        dropout=0.0
    ):
        super().__init__()
        pieces = [
            nn.ConvTranspose2d(
                in_channels,
                out_channels,
                kernel_size=4,
                stride=2,
                padding=1,
                bias=False,
            ),
            nn.InstanceNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        if dropout:
            pieces.append(nn.Dropout(dropout))
        self.model = nn.Sequential(*pieces)

    def forward(self, x, skip_input):
        """Upsample ``x`` and append ``skip_input`` along dim 1 (channels)."""
        upsampled = self.model(x)
        return torch.cat([upsampled, skip_input], dim=1)
class Generator(nn.Module):
    """U-Net generator used by this app to render the stylised image.

    Seven ``UNetDown`` encoder stages (64 -> 128 -> 256 -> 512 x4 channels)
    feed six ``UNetUp`` decoder stages. Each decoder stage concatenates the
    mirrored encoder activation, and a final transposed convolution + Tanh
    maps back to ``out_channels``.

    Attribute names (``down1``..``down7``, ``up1``..``up6``, ``final``) are
    the prefixes of the saved state-dict keys and must not change.

    Args:
        in_channels: channels of the input image (3 for RGB).
        out_channels: channels of the generated image (3 for RGB).
    """

    def __init__(
        self,
        in_channels=3,
        out_channels=3
    ):
        super().__init__()

        # Encoder: each stage halves H and W.
        self.down1 = UNetDown(in_channels, 64, normalize=False)
        self.down2 = UNetDown(64, 128)
        self.down3 = UNetDown(128, 256)
        self.down4 = UNetDown(256, 512, dropout=0.5)
        self.down5 = UNetDown(512, 512, dropout=0.5)
        self.down6 = UNetDown(512, 512, dropout=0.5)
        self.down7 = UNetDown(512, 512, normalize=False, dropout=0.5)

        # Decoder: input channel counts include the concatenated skip
        # tensor from the previous stage (e.g. 512 + 512 = 1024).
        self.up1 = UNetUp(512, 512, dropout=0.5)
        self.up2 = UNetUp(1024, 512, dropout=0.5)
        self.up3 = UNetUp(1024, 512, dropout=0.5)
        self.up4 = UNetUp(1024, 256)
        self.up5 = UNetUp(512, 128)
        self.up6 = UNetUp(256, 64)

        # 64 decoder channels + 64 from down1 = 128 in; Tanh bounds to [-1, 1].
        self.final = nn.Sequential(
            nn.ConvTranspose2d(128, out_channels, 4, 2, 1),
            nn.Tanh(),
        )

    def forward(self, x):
        """Encode ``x``, decode with skip connections, return the image.

        With seven stride-2 stages, H and W must be divisible by 128 for
        the skip tensors to line up in the concatenations.
        """
        encoders = (
            self.down1, self.down2, self.down3, self.down4,
            self.down5, self.down6, self.down7,
        )
        decoders = (
            self.up1, self.up2, self.up3,
            self.up4, self.up5, self.up6,
        )

        skips = []
        feature = x
        for encode in encoders:
            feature = encode(feature)
            skips.append(feature)

        # The deepest activation (down7) is the bottleneck; the remaining
        # skips are consumed deepest-first (down6 ... down1).
        feature = skips.pop()
        for decode in decoders:
            feature = decode(feature, skips.pop())

        return self.final(feature)
# =========================================================
# LOAD MODEL
# =========================================================
# Build the network on the selected device, restore the trained weights
# from the local checkpoint (tensors are remapped onto `device`), then call
# eval() so the Dropout layers behave as in inference.
generator = Generator()
generator.to(device)
generator.load_state_dict(
    torch.load("generator_model.pth", map_location=device)
)
generator.eval()
# =========================================================
# STYLE TRANSFER FUNCTION
# =========================================================
def apply_style_strength(styled, content, style_strength):
    """Blend the generator output with the original image.

    Both tensors are expected in the normalised [-1, 1] range: the input
    is normalised with mean/std 0.5 and the generator ends in Tanh.

    Args:
        styled: generator output tensor.
        content: normalised input tensor that was fed to the generator;
            must be broadcast-compatible with ``styled``.
        style_strength: percentage; clamped to [0, 100]. 100 returns
            ``styled`` unchanged and 0 returns ``content``.

    Returns:
        ``alpha * styled + (1 - alpha) * content`` with
        ``alpha = style_strength / 100``.
    """
    alpha = min(max(float(style_strength), 0.0), 100.0) / 100.0
    return alpha * styled + (1.0 - alpha) * content


def monet_style_transfer(
    input_image,
    style_strength,
    image_quality
):
    """Gradio callback: render ``input_image`` in the learned painting style.

    Args:
        input_image: PIL image from the upload widget, or ``None``.
        style_strength: slider value in percent; controls how much of the
            generator output is blended over the original photo.
        image_quality: dropdown value. NOTE(review): it is only echoed in the
            summary text and does not affect processing.

    Returns:
        ``(output_image, markdown_summary)``. ``output_image`` is a
        256x256 RGB PIL image (the input is resized to a fixed square), or
        ``None`` with a warning message when no image was uploaded.
    """
    if input_image is None:
        return None, "⚠️ Please upload an image."

    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        # Maps [0, 1] pixels to [-1, 1], the range the Tanh generator uses.
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    content = transform(input_image.convert("RGB")).unsqueeze(0).to(device)

    with torch.no_grad():
        styled = generator(content)

    # Style strength interpolates between the photo and the painting.
    # Scaling the generator output on its own would pull every pixel toward
    # mid-grey instead of back toward the original image.
    blended = apply_style_strength(styled, content, style_strength)

    # [-1, 1] -> [0, 255] uint8, laid out HWC for PIL; round (not truncate).
    pixels = (blended * 0.5 + 0.5).clamp(0.0, 1.0).mul(255.0).round()
    pixels = pixels.squeeze(0).permute(1, 2, 0).to(torch.uint8)
    output_image = Image.fromarray(pixels.contiguous().cpu().numpy())

    result_text = (
        "\n"
        "# 🎨 Monet Style Transformation Complete\n"
        f"### Style Strength: {style_strength}%\n"
        f"### Image Quality: {image_quality}\n"
        "### AI artistic rendering successfully generated.\n"
    )
    return output_image, result_text
# =========================================================
# CUSTOM CSS
# =========================================================
# The large gradient button look is scoped to the generate button through
# its elem_id ("generate-btn"). A bare `button` selector would match every
# <button> on the page, including those inside Gradio's own components.
custom_css = """
body {
background: #f5f7fb;
font-family: 'Segoe UI', sans-serif;
}
.gradio-container {
max-width: 1300px !important;
margin: auto;
}
.hero {
background: linear-gradient(135deg,#111827,#7c3aed);
padding: 40px;
border-radius: 30px;
color: white;
margin-bottom: 20px;
}
.hero h1 {
font-size: 52px;
font-weight: 800;
margin-bottom: 10px;
}
.hero p {
font-size: 18px;
opacity: 0.92;
}
.card {
background: white;
border-radius: 24px;
padding: 22px;
box-shadow: 0 6px 18px rgba(0,0,0,0.08);
}
#generate-btn {
height: 60px !important;
border-radius: 18px !important;
border: none !important;
background: linear-gradient(135deg,#7c3aed,#6d28d9) !important;
color: white !important;
font-size: 20px !important;
font-weight: 700 !important;
}
#generate-btn:hover {
background: linear-gradient(135deg,#6d28d9,#5b21b6) !important;
}
input, textarea, select {
border-radius: 16px !important;
}
@media (max-width: 768px){
.hero {
padding: 22px;
}
.hero h1 {
font-size: 32px;
}
.hero p {
font-size: 15px;
}
#generate-btn {
height: 54px !important;
font-size: 17px !important;
}
}
"""
# =========================================================
# HERO HTML
# =========================================================
# Banner at the top of the page; styled by the .hero rules above.
hero_html = """
<div class="hero">
<h1>🎨 AI Monet Style Transfer</h1>
<p>
Transform your photos into Monet-inspired paintings using deep learning.
Modern mobile-friendly interface with smart controls and AI-powered rendering.
</p>
</div>
"""
# =========================================================
# INTERFACE
# =========================================================
with gr.Blocks(
    css=custom_css,
    theme=gr.themes.Soft(
        primary_hue="violet",
        secondary_hue="purple",
    ),
) as demo:
    gr.HTML(hero_html)

    # -------------------------------------------------
    # TOP INFO CARDS (styled by the .card CSS rule)
    # -------------------------------------------------
    with gr.Row():
        with gr.Column(elem_classes=["card"]):
            gr.Markdown(
                "\n### ⚡ Fast AI Rendering\nGenerate artistic images instantly\n"
            )
        with gr.Column(elem_classes=["card"]):
            gr.Markdown(
                "\n### 📱 Mobile Responsive\nOptimized for all screen sizes\n"
            )
        with gr.Column(elem_classes=["card"]):
            gr.Markdown(
                "\n### 🧠 Deep Learning\nGAN-based artistic transformation\n"
            )

    # -------------------------------------------------
    # MAIN AREA
    # -------------------------------------------------
    with gr.Row():
        # Left panel: inputs and controls.
        with gr.Column(scale=1):
            input_image = gr.Image(
                type="pil",
                label="🖼️ Upload Photo",
            )
            style_strength = gr.Slider(
                minimum=10,
                maximum=100,
                value=80,
                step=5,
                label="🎨 Style Strength",
            )
            # NOTE(review): monet_style_transfer only echoes this value in its
            # summary text; it does not change the rendering.
            image_quality = gr.Dropdown(
                choices=[
                    "Standard",
                    "High",
                    "Ultra",
                ],
                value="High",
                label="✨ Image Quality",
            )
            # elem_id ties this button to the #generate-btn CSS rules.
            transform_button = gr.Button(
                "🚀 Generate Monet Art",
                variant="primary",
                elem_id="generate-btn",
            )

        # Right panel: results.
        with gr.Column(scale=1):
            output_image = gr.Image(
                type="pil",
                label="🖌️ Monet Styled Image",
            )
            result_text = gr.Markdown()

    # No gr.Examples section: no example images are bundled, and an empty
    # examples list would only render an empty section. Add one here once
    # sample files ship with the app.

    # -------------------------------------------------
    # BUTTON ACTION
    # -------------------------------------------------
    transform_button.click(
        fn=monet_style_transfer,
        inputs=[
            input_image,
            style_strength,
            image_quality,
        ],
        outputs=[
            output_image,
            result_text,
        ],
    )
# =========================================================
# LAUNCH
# =========================================================
# Start the web server only when this file is executed as a script
# (`python app.py`). Importing the module (e.g. from tests or other tooling)
# still builds `demo` but does not block on launch().
if __name__ == "__main__":
    demo.launch()