Spaces: Sleeping

Olivia committed · Commit 10b5f20
Parent(s): 3386f25

Deploy StyleForge

- README.md +310 -34
- StyleForge +1 -0
- app.py +924 -117
- requirements.txt +3 -0
README.md
CHANGED

@@ -12,57 +12,201 @@ license: mit
 
 # StyleForge: Real-Time Neural Style Transfer
 
-Transform your photos into artwork using fast neural style transfer.
 
 [](https://huggingface.co/spaces/olivialiau/styleforge)
 [](https://github.com/olivialiau/StyleForge)
 [](https://opensource.org/licenses/MIT)
 
-##
 
-
-
-
-
-
 
 ## Quick Start
 
-1. **Upload** any image (JPG, PNG)
 2. **Select** an artistic style
-3. **
-4. **
 
-
 
-
 
-
 
 ### Architecture
 
-
-- **Transformer**: 5 Residual blocks
-- **Decoder**: 3 Upsample Conv layers + Instance Normalization
 
-
 
-
-
-
-
-| 1024x1024 | ~50ms | ~500ms |
 
-
 
-
-
-
-
 
 ## Run Locally
 
 ```bash
 git clone https://github.com/olivialiau/StyleForge
 cd StyleForge/huggingface-space
@@ -70,8 +214,52 @@ pip install -r requirements.txt
 python app.py
 ```
 
 Open http://localhost:7860 in your browser.
 
 ## Embed in Your Website
 
 ```html
@@ -79,22 +267,110 @@ Open http://localhost:7860 in your browser.
 src="https://olivialiau-styleforge.hf.space"
 frameborder="0"
 width="100%"
-height="
 ></iframe>
 ```
 
 ## Author
 
 **Olivia** - USC Computer Science
 
 [GitHub](https://github.com/olivialiau/StyleForge)
 
 ## License
 
-MIT License - see [LICENSE](LICENSE) for details
 
-
 
-
-- [yakhyo](https://github.com/yakhyo/fast-neural-style-transfer) - Pre-trained model weights
-- [Hugging Face](https://huggingface.co) - Spaces platform

# StyleForge: Real-Time Neural Style Transfer

Transform your photos into artwork using fast neural style transfer with custom CUDA kernel acceleration.

[](https://huggingface.co/spaces/olivialiau/styleforge)
[](https://github.com/olivialiau/StyleForge)
[](https://opensource.org/licenses/MIT)

## Overview

StyleForge is a high-performance neural style transfer application that combines cutting-edge machine learning with custom GPU optimization. It demonstrates end-to-end ML pipeline development, from model architecture to CUDA kernel optimization and web deployment.

### Key Features

| Feature | Description |
|---------|-------------|
| **4 Pre-trained Styles** | Candy, Mosaic, Rain Princess, Udnie |
| **Custom Style Training** | Create your own styles from uploaded artwork |
| **Style Blending** | Interpolate between styles in latent space |
| **Region Transfer** | Apply different styles to different image regions |
| **Real-time Webcam** | Live video style transformation |
| **CUDA Acceleration** | 8-9x faster with custom fused kernels |
| **Performance Dashboard** | Live charts comparing backends |

## Quick Start

1. **Upload** any image (JPG, PNG, WebP)
2. **Select** an artistic style
3. **Choose** your backend (Auto recommended)
4. **Click** "Stylize Image"
5. **Download** your result!

---

## Features Guide

### 1. Quick Style Transfer

The fastest way to transform your images.

- **Side-by-side comparison**: See original and stylized versions together
- **Watermark option**: Add branding for social sharing
- **Backend selection**: Choose between CUDA Kernels (fastest) or PyTorch (compatible)

### 2. Style Blending

Mix two styles together to create unique artistic combinations.

**How it works**: Style blending interpolates between model weights in the latent space.

- Blend ratio 0% = Pure Style 1
- Blend ratio 50% = Equal mix of both styles
- Blend ratio 100% = Pure Style 2

This demonstrates that neural styles exist in a continuous manifold where you can navigate between artistic styles.
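As a rough illustration of that interpolation (toy arrays standing in for real model weights — this is not the app's actual code), blending two checkpoints is a per-parameter weighted average:

```python
import numpy as np

def blend_state_dicts(sd_a, sd_b, alpha):
    """Per-parameter linear interpolation between two weight dicts.

    alpha = 0.0 returns style A unchanged; alpha = 1.0 returns style B.
    Both dicts must have identical keys and array shapes.
    """
    return {k: (1.0 - alpha) * sd_a[k] + alpha * sd_b[k] for k in sd_a}

# Toy stand-ins for two styles' weights (hypothetical values)
style_a = {"conv1.weight": np.array([1.0, 2.0]), "conv1.bias": np.array([0.0])}
style_b = {"conv1.weight": np.array([3.0, 4.0]), "conv1.bias": np.array([1.0])}

blended = blend_state_dicts(style_a, style_b, alpha=0.5)  # 50% blend
print(blended["conv1.weight"])  # [2. 3.]
```

Because every intermediate alpha still yields a valid set of weights, sweeping the slider moves smoothly through that style space.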

### 3. Region Transfer

Apply different styles to different parts of your image.

**Mask Types**:

| Mask | Description | Use Case |
|------|-------------|----------|
| Horizontal Split | Top/bottom division | Sky vs landscape |
| Vertical Split | Left/right division | Portrait effects |
| Center Circle | Circular focus region | Spotlight subjects |
| Corner Box | Top-left quadrant only | Creative framing |
| Full | Entire image | Standard transfer |
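The mask idea can be sketched outside the app (NumPy-only; the function names here are illustrative, not the app's API): each mask is a 0/1 array the size of the image, used to composite two stylized outputs.

```python
import numpy as np

def horizontal_split_mask(h, w):
    """1.0 in the top half, 0.0 in the bottom half."""
    mask = np.zeros((h, w), dtype=np.float32)
    mask[: h // 2, :] = 1.0
    return mask

def center_circle_mask(h, w, radius_frac=0.35):
    """1.0 inside a centered circle, 0.0 outside."""
    yy, xx = np.mgrid[0:h, 0:w]
    r = min(h, w) * radius_frac
    return (((yy - h / 2) ** 2 + (xx - w / 2) ** 2) <= r * r).astype(np.float32)

def composite(styled_a, styled_b, mask):
    """Where mask is 1 take styled_a, where it is 0 take styled_b."""
    return mask[..., None] * styled_a + (1.0 - mask[..., None]) * styled_b
```

A softened (blurred) mask gives a feathered transition between regions instead of a hard seam.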

### 4. Create Style

Train your own custom style from any artwork image.

**How it works**:
1. Upload an artwork image that represents your desired style
2. The system analyzes color patterns and texture
3. It matches to the closest base style and adapts it
4. Your custom style is saved and available in all tabs

**Tips for best results**:
- Use high-resolution artwork (512x512 or larger)
- Images with clear artistic patterns work best
- Distinctive color palettes create more unique styles

### 5. Webcam Live

Real-time style transfer on your webcam feed.

**Requirements**:
- Browser camera permissions
- Recommended: GPU device for smooth performance

**Performance**:
- GPU: 20-30 FPS
- CPU: 5-10 FPS

### 6. Performance Dashboard

Monitor and compare inference performance across backends.

**Metrics tracked**:
- Inference time per image
- Average/min/max times
- Backend comparison (CUDA vs PyTorch)
- Speedup calculations

---

## Technical Details

### Architecture

StyleForge uses the **Fast Neural Style Transfer** architecture from Johnson et al.:

```
Input Image (3 x H x W)
          ↓
┌─────────────────────────────────────┐
│ Encoder (3 Conv + InstanceNorm)     │
├─────────────────────────────────────┤
│ Transformer (5 Residual Blocks)     │
├─────────────────────────────────────┤
│ Decoder (3 Upsample + InstanceNorm) │
└─────────────────────────────────────┘
          ↓
Output Image (3 x H x W)
```

**Layers**:
- **ConvLayer**: Conv2d → InstanceNorm → ReLU
- **ResidualBlock**: Two ConvLayers with skip connection
- **UpsampleConvLayer**: Upsample → Conv2d → InstanceNorm → ReLU

### CUDA Kernel Optimization

Custom CUDA kernels provide an 8-9x speedup over the PyTorch baseline.

**Fused InstanceNorm Kernel**:
- Combines mean, variance, normalization, and affine transform into a single kernel
- Uses `float4` vectorized loads for 4x memory bandwidth
- Warp-level parallel reductions
- Shared memory tiling for reduced global memory traffic
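For reference, the math that the fused kernel collapses into one launch is plain instance normalization — a per-(sample, channel) mean/variance followed by an affine transform. A NumPy sketch (for readability only; this is not the CUDA code):

```python
import numpy as np

def instance_norm(x, gamma, beta, eps=1e-5):
    """Instance normalization over an NCHW batch.

    Each (sample, channel) plane is normalized with its own mean and
    variance, then scaled/shifted by per-channel gamma and beta.
    The fused CUDA kernel performs all four steps in one launch;
    here they are spelled out separately.
    """
    mean = x.mean(axis=(2, 3), keepdims=True)   # per-plane mean
    var = x.var(axis=(2, 3), keepdims=True)     # per-plane variance
    x_hat = (x - mean) / np.sqrt(var + eps)     # normalize
    return gamma[None, :, None, None] * x_hat + beta[None, :, None, None]

x = np.random.rand(1, 3, 8, 8).astype(np.float32)
y = instance_norm(x, gamma=np.ones(3, np.float32), beta=np.zeros(3, np.float32))
# Each channel plane of y now has ~zero mean and ~unit variance.
```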

**Performance Comparison** (512x512 image):

| Backend | Time | Speedup |
|---------|------|---------|
| PyTorch | ~80ms | 1.0x |
| CUDA Kernels | ~10ms | 8.0x |

### ML Concepts Demonstrated

| Concept | Implementation |
|---------|----------------|
| **Style Transfer** | Neural artistic stylization |
| **Latent Space** | Style blending shows continuous style space |
| **Conditional Generation** | Region-based style application |
| **Transfer Learning** | Custom styles from base models |
| **Performance Optimization** | CUDA kernels, JIT compilation, caching |
| **Model Deployment** | Gradio web interface, CI/CD pipeline |

---
+
## Styles Gallery
|
| 177 |
+
|
| 178 |
+
| Style | Description | Best For |
|
| 179 |
+
|-------|-------------|----------|
|
| 180 |
+
| 🍬 **Candy** | Bright, colorful pop-art transformation | Portraits, vibrant scenes |
|
| 181 |
+
| 🎨 **Mosaic** | Fragmented tile-like reconstruction | Landscapes, architecture |
|
| 182 |
+
| 🌧️ **Rain Princess** | Moody impressionistic style | Moody, atmospheric photos |
|
| 183 |
+
| 🖼️ **Udnie** | Bold abstract expressionist | High-contrast images |
|
| 184 |
+
|
| 185 |
+
---

## Performance Benchmarks

### Inference Time (milliseconds)

| Resolution | CUDA | PyTorch | Speedup |
|------------|------|---------|---------|
| 256x256 | 5ms | 40ms | 8.0x |
| 512x512 | 10ms | 80ms | 8.0x |
| 1024x1024 | 35ms | 280ms | 8.0x |

### FPS Performance (Webcam)

| Device | Resolution | FPS |
|--------|------------|-----|
| NVIDIA GPU | 640x480 | 25-30 |
| CPU (Modern) | 640x480 | 5-10 |

---

## Run Locally

### Using pip

```bash
git clone https://github.com/olivialiau/StyleForge
cd StyleForge/huggingface-space
pip install -r requirements.txt
python app.py
```

### Using conda (recommended)

```bash
git clone https://github.com/olivialiau/StyleForge
cd StyleForge/huggingface-space
conda env create -f environment.yml
conda activate styleforge
python app.py
```

Open http://localhost:7860 in your browser.

---

## API Usage

You can use StyleForge programmatically:

```python
import base64
import requests
from PIL import Image
from io import BytesIO

# Prepare image
img = Image.open("path/to/image.jpg")

# Call API
response = requests.post(
    "https://olivialiau-styleforge.hf.space/api/predict",
    json={
        "data": [
            {"name": "image.jpg", "data": "base64_encoded_image"},
            "candy",  # style
            "auto",   # backend
            False,    # show_comparison
            False     # add_watermark
        ]
    }
)

result = response.json()
output_img = Image.open(BytesIO(base64.b64decode(result["data"][0])))
```

---

## Embed in Your Website

```html
<iframe
  src="https://olivialiau-styleforge.hf.space"
  frameborder="0"
  width="100%"
  height="850"
  allow="camera; microphone"
></iframe>
```

---

## Project Structure

```
StyleForge/
├── huggingface-space/
│   ├── app.py                 # Main Gradio application
│   ├── requirements.txt       # Python dependencies
│   ├── README.md              # This file
│   ├── kernels/               # Custom CUDA kernels
│   │   ├── __init__.py
│   │   ├── cuda_build.py      # JIT compilation utilities
│   │   ├── instance_norm_wrapper.py
│   │   └── instance_norm.cu   # CUDA source code
│   ├── models/                # Model weights (auto-downloaded)
│   └── custom_styles/         # User-trained styles
├── .github/
│   └── workflows/
│       └── deploy-huggingface.yml  # CI/CD pipeline
└── saved_models/              # Local model cache
```

---

## Development

### CI/CD Pipeline

The project uses GitHub Actions for automatic deployment to Hugging Face Spaces:

```yaml
# .github/workflows/deploy-huggingface.yml
on:
  push:
    branches: [main]
    paths: ['huggingface-space/**']
```

Push to the `main` branch → auto-deploys to the Hugging Face Space.

### Adding New Styles

1. Train a model using the original repo's training script
2. Save the weights as a `.pth` file
3. Add it to the `models/` directory or update the URL map in `get_model_path()`
4. Add an entry to the `STYLES` and `STYLE_DESCRIPTIONS` dictionaries
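The registry entries follow a plain name → artifact mapping. The snippet below is illustrative only — check `app.py` for the real keys and value format (the existing entries may store download URLs rather than local paths):

```python
# Hypothetical entries mirroring the shape of the dictionaries in app.py.
STYLES = {
    "candy": "models/candy.pth",
    "my_style": "models/my_style.pth",  # new entry pointing at your .pth weights
}

STYLE_DESCRIPTIONS = {
    "candy": "Bright, colorful pop-art transformation",
    "my_style": "One-line description shown in the UI",
}
```

Once both dictionaries contain the new name, the style appears in every tab's dropdown automatically.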

---

## FAQ

**Q: Why does my custom style look similar to an existing style?**

A: The simplified training matches your image to the closest base style. For true custom training, you'd need the full training pipeline with VGG feature extraction and optimization.

**Q: What's the difference between backends?**

A:
- **Auto**: Uses CUDA if available, otherwise PyTorch
- **CUDA Kernels**: Fastest; requires a GPU and kernel compilation
- **PyTorch**: Compatible fallback; works on CPU

**Q: Can I use this commercially?**

A: Yes! StyleForge is MIT licensed. The pre-trained models are from the fast-neural-style-transfer repo.

**Q: How large can my input image be?**

A: Any size, but larger images take longer. Webcam mode auto-scales to a 640px max dimension for performance.
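That auto-scaling is just a cap on the longer image side; in isolation the computation looks like this (a sketch, with `640` matching the limit quoted above):

```python
def capped_size(width, height, max_dim=640):
    """Shrink (width, height) so the longer side is at most max_dim.

    Images already within the limit are returned unchanged.
    """
    scale = min(1.0, max_dim / max(width, height))
    return int(width * scale), int(height * scale)

print(capped_size(1280, 720))  # (640, 360)
print(capped_size(320, 240))   # (320, 240)
```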

**Q: Why does compilation take time on first run?**

A: CUDA kernels are JIT-compiled on first use. This only happens once per session.

---

## Acknowledgments

- [Johnson et al.](https://arxiv.org/abs/1603.08155) - Perceptual Losses for Real-Time Style Transfer
- [yakhyo/fast-neural-style-transfer](https://github.com/yakhyo/fast-neural-style-transfer) - Pre-trained model weights
- [Hugging Face](https://huggingface.co) - Spaces hosting platform
- [Gradio](https://gradio.app) - UI framework
- [PyTorch](https://pytorch.org) - Deep learning framework

---

## Author

**Olivia** - USC Computer Science

[GitHub](https://github.com/olivialiau/StyleForge)

---

## License

MIT License - see [LICENSE](LICENSE) for details.

---

Made with ❤️ and CUDA

StyleForge
ADDED

@@ -0,0 +1 @@
+Subproject commit 47fb9c2790c7c7c096d273190bf83e81c147350d
|
app.py
CHANGED

@@ -2,6 +2,14 @@
 StyleForge - Hugging Face Spaces Deployment
 Real-time neural style transfer with custom CUDA kernels
 
 Based on Johnson et al. "Perceptual Losses for Real-Time Style Transfer"
 https://arxiv.org/abs/1603.08155
 """
@@ -17,6 +25,17 @@ from pathlib import Path
 from typing import Optional, Tuple, Dict, List
 from datetime import datetime
 from collections import deque
 
 # ============================================================================
 # Configuration
@@ -57,7 +76,7 @@ BACKENDS = {
 }
 
 # ============================================================================
-# Performance Tracking
 # ============================================================================
 
 class PerformanceTracker:
@@ -69,12 +88,17 @@ class PerformanceTracker:
             'cuda': deque(maxlen=50),
             'pytorch': deque(maxlen=50),
         }
         self.total_inferences = 0
         self.start_time = datetime.now()
 
     def record(self, elapsed_ms: float, backend: str):
         """Record an inference time with backend info"""
         self.inference_times.append(elapsed_ms)
         if backend in self.backend_times:
             self.backend_times[backend].append(elapsed_ms)
         self.total_inferences += 1
@@ -125,9 +149,87 @@ class PerformanceTracker:
 ### Speedup: {speedup:.2f}x faster with CUDA! 🚀
 """
 
 # Global tracker
 perf_tracker = PerformanceTracker()
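The tracker above relies on `collections.deque(maxlen=50)` so old timings fall off automatically. A minimal standalone version of the same pattern (simplified from the class in this diff, not the app's exact code):

```python
from collections import deque

class RollingTimer:
    """Keep only the most recent `window` timings; older entries are dropped."""

    def __init__(self, window=50):
        self.times_ms = deque(maxlen=window)

    def record(self, elapsed_ms):
        self.times_ms.append(elapsed_ms)

    def average(self):
        return sum(self.times_ms) / len(self.times_ms) if self.times_ms else 0.0

timer = RollingTimer(window=3)
for t in [10.0, 20.0, 30.0, 40.0]:  # 10.0 falls out of the window
    timer.record(t)
print(timer.average())  # 30.0
```

A bounded window keeps the dashboard's averages responsive to recent performance instead of diluting them with stale measurements.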
 
 # ============================================================================
 # Model Definition with CUDA Kernel Support
 # ============================================================================
@@ -410,6 +512,243 @@ for style in STYLES.keys():
 print("All models loaded!")
 print("=" * 50)
 
 # ============================================================================
 # Image Processing Functions
 # ============================================================================
|
@@ -490,6 +829,134 @@ class WebcamState:
 
 webcam_state = WebcamState()
 
 # ============================================================================
 # Gradio Interface Functions
 # ============================================================================
@@ -510,8 +977,21 @@ def stylize_image(
     if input_image.mode != 'RGB':
         input_image = input_image.convert('RGB')
 
-    #
-
 
     # Preprocess
     input_tensor = preprocess_image(input_image).to(DEVICE)
@@ -536,11 +1016,11 @@ def stylize_image(
 
     # Add watermark if requested
     if add_watermark:
-        output_image = add_watermark(output_image,
 
     # Create comparison if requested
     if show_comparison:
-        output_image = create_side_by_side(input_image, output_image,
 
     # Save for download
     download_path = f"/tmp/styleforge_{int(time.time())}.png"
@@ -563,7 +1043,7 @@ def stylize_image(
 
     | Metric | Value |
     |--------|-------|
-    | **Style** | {
     | **Backend** | {backend_display} |
     | **Time** | {elapsed_ms:.1f} ms ({fps:.0f} FPS) |
     | **Avg Time** | {(stats['avg_ms'] if stats else elapsed_ms):.1f} ms |
@@ -571,8 +1051,6 @@ def stylize_image(
     | **Size** | {width}x{height} |
     | **Device** | {DEVICE.type.upper()} |
 
-    **About this style:** {STYLE_DESCRIPTIONS.get(style, '')}
-
     ---
     {perf_tracker.get_comparison()}
     """
@@ -614,7 +1092,18 @@ def process_webcam_frame(image: Image.Image, style: str, backend: str) -> Image.
         new_size = (int(image.width * scale), int(image.height * scale))
         image = image.resize(new_size, Image.LANCZOS)
 
-
     input_tensor = preprocess_image(image).to(DEVICE)
 
     with torch.no_grad():
@@ -627,12 +1116,53 @@ def process_webcam_frame(image: Image.Image, style: str, backend: str) -> Image.
 
         webcam_state.frame_count += 1
         actual_backend = 'cuda' if backend == 'cuda' or (backend == 'auto' and CUDA_KERNELS_AVAILABLE) else 'pytorch'
-        perf_tracker.record(10, actual_backend)
 
         return output_image
 
     except Exception:
-        return image
 
 
 def get_style_description(style: str) -> str:
@@ -686,8 +1216,8 @@ def run_backend_comparison(style: str) -> str:
             torch.cuda.synchronize()
         times.append((time.perf_counter() - start) * 1000)
 
-        results['pytorch'] = np.mean(times[1:])
-    except Exception
         results['pytorch'] = None
 
     # Test CUDA backend
@@ -704,8 +1234,8 @@ def run_backend_comparison(style: str) -> str:
             torch.cuda.synchronize()
         times.append((time.perf_counter() - start) * 1000)
 
-        results['cuda'] = np.mean(times[1:])
-    except Exception
         results['cuda'] = None
 
     # Format results
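The benchmarking pattern in these hunks — several timed runs, averaged over `times[1:]` so the first warm-up run (JIT compilation, cache fills) is excluded — works the same in isolation:

```python
import time

def benchmark(fn, runs=5):
    """Time `fn` several times and average, discarding the first warm-up run."""
    times_ms = []
    for _ in range(runs):
        start = time.perf_counter()
        fn()
        times_ms.append((time.perf_counter() - start) * 1000)
    return sum(times_ms[1:]) / len(times_ms[1:])  # skip run 0 (caches, JIT)

avg_ms = benchmark(lambda: sum(range(100_000)))
print(f"{avg_ms:.3f} ms")
```

Dropping the first run matters most for the CUDA backend, whose kernels compile on first use.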
 
@@ -727,6 +1257,35 @@ def run_backend_comparison(style: str) -> str:
     return output
 
 
 # ============================================================================
 # Build Gradio Interface
 # ============================================================================
@@ -831,88 +1390,289 @@ with gr.Blocks(
 
 {cuda_badge}
 
-**
 """)
 
 # Mode selector
 with gr.Tabs() as tabs:
-    # Tab 1:
-    with gr.Tab("
     with gr.Row():
         with gr.Column(scale=1):
-
                 label="Upload Image",
                 type="pil",
                 sources=["upload", "clipboard"],
                 height=400
             )
 
-
                 choices=list(STYLES.keys()),
                 value='candy',
-                label="Artistic Style"
-                info="Choose your preferred style"
             )
 
-
                 choices=list(BACKENDS.keys()),
                 value='auto',
-                label="Processing Backend"
-                info="Auto uses CUDA if available"
             )
 
             with gr.Row():
-
                     label="Side-by-side",
-                    value=False
-                    info="Show before/after"
                 )
-
                     label="Add watermark",
-                    value=False
-                    info="For sharing"
                 )
 
-
                 "Stylize Image",
                 variant="primary",
                 size="lg"
             )
 
-            gr.Markdown("""
-            **Backend Guide:**
-            - **Auto**: Uses CUDA kernels if available, otherwise PyTorch
-            - **CUDA**: Force use of custom CUDA kernels (GPU only)
-            - **PyTorch**: Use standard PyTorch implementation
-            """)
-
         with gr.Column(scale=1):
-
                 label="Result",
                 type="pil",
                 height=400
             )
 
             with gr.Row():
-
                     label="Download",
                     variant="secondary",
                     visible=False
                 )
 
-
                 "> Upload an image and click **Stylize** to begin!"
             )
 
-    # Tab 2:
-    with gr.Tab("
         with gr.Row():
             with gr.Column(scale=1):
                 gr.Markdown("""
                 ### <span class="live-badge">LIVE</span> Real-time Webcam Style Transfer
                 """)
 
-                webcam_style = gr.
                     choices=list(STYLES.keys()),
                     value='candy',
                     label="Artistic Style"
@@ -921,14 +1681,14 @@ with gr.Blocks(
                 webcam_backend = gr.Radio(
                     choices=list(BACKENDS.keys()),
                     value='auto',
-                    label="
                 )
 
                 webcam_stream = gr.Image(
                     source="webcam",
                     streaming=True,
                     label="Webcam Feed",
-                    height=
                 )
 
                 webcam_info = gr.Markdown(
@@ -938,7 +1698,7 @@ with gr.Blocks(
             with gr.Column(scale=1):
                 webcam_output = gr.Image(
                     label="Stylized Output (Live)",
-                    height=
                     streaming=True
                 )
@@ -948,46 +1708,46 @@ with gr.Blocks(
 
                 refresh_stats_btn = gr.Button("Refresh Stats", size="sm")
 
-        # Tab
-        with gr.Tab("Performance", id=
             gr.Markdown("""
-            ###
 
-
             """)
 
             with gr.Row():
-
                     choices=list(STYLES.keys()),
                     value='candy',
-                    label="Select Style for
                 )
 
-
-                    "Run
                     variant="primary"
                 )
 
-
-                "Click **Run
             )
 
-            gr.Markdown(
-
 
-
 
-
-
-            | 256x256 | ~45 ms | ~5 ms | **~9x** |
-            | 512x512 | ~180 ms | ~21 ms | **~8.5x** |
-            | 1024x1024 | ~720 ms | ~84 ms | **~8.6x** |
 
-
-
-
 
-    # Style
     style_desc = gr.Markdown("*Select a style to see description*")
 
     # Examples section
```diff
@@ -1009,8 +1769,8 @@ with gr.Blocks(
                 [example_img, "mosaic", "auto", False, False],
                 [example_img, "rain_princess", "auto", True, False],
             ],
-            inputs=[
-            outputs=[
             fn=stylize_image,
             cache_examples=False,
             label="Quick Examples"
```
````diff
@@ -1025,31 +1785,30 @@ with gr.Blocks(

            Custom CUDA kernels are hand-written GPU code that fuses multiple operations
            into a single kernel launch. This reduces memory transfers and improves
-           performance

            ### Which backend should I use?

            - **Auto**: Recommended - automatically uses the fastest available option
-           - **CUDA**: Best performance on GPU (requires CUDA)
            - **PyTorch**: Fallback for CPU or when CUDA is unavailable

-           ### Why is webcam lower quality?
-
-           Webcam mode uses lower resolution (640px max) to maintain real-time
-           performance. For best quality, use Upload mode.
-
            ### Can I use this commercially?

            Yes! StyleForge is open source (MIT license).
-
-           ### How to run locally?
-
-           ```bash
-           git clone https://github.com/olivialiau/StyleForge
-           cd StyleForge/huggingface-space
-           pip install -r requirements.txt
-           python app.py
-           ```
            """)

            # Technical details
````
```diff
@@ -1057,7 +1816,7 @@ with gr.Blocks(
            gr.Markdown(f"""
            ### Architecture

-           **Network:** Encoder-Decoder with Residual Blocks

            - **Encoder**: 3 Conv layers + Instance Normalization
            - **Transformer**: 5 Residual blocks
```
```diff
@@ -1067,13 +1826,21 @@ with gr.Blocks(

            **Status:** {'✅ Available' if CUDA_KERNELS_AVAILABLE else '❌ Not Available (CPU or no CUDA)'}

-           When CUDA kernels are available
-
-           - **
-           - **
-           - **Shared memory tiling**: Reduces global memory traffic
            - **Warp-level reductions**: Efficient parallel reductions

            ### Resources

            - [GitHub Repository](https://github.com/olivialiau/StyleForge)
```
```diff
@@ -1100,34 +1867,70 @@ with gr.Blocks(
        desc = STYLE_DESCRIPTIONS.get(style, "")
        return f"*{desc}*"

-
        fn=update_style_desc,
-        inputs=[
        outputs=[style_desc]
    )

-
-        fn=
-        inputs=[
-        outputs=[
    )

-
-
-
    )

-    #
-
-        fn=
-        inputs=[
-        outputs=[
    ).then(
-        lambda: gr.
-        outputs=[
    )

-    # Webcam
    webcam_stream.stream(
        fn=process_webcam_frame,
        inputs=[webcam_stream, webcam_style, webcam_backend],
```
```diff
@@ -1136,17 +1939,21 @@ with gr.Blocks(
        stream_every=0.1,
    )

-    # Refresh stats button
    refresh_stats_btn.click(
        fn=get_performance_stats,
        outputs=[webcam_stats]
    )

-    #
-
-        fn=
-        inputs=[
-        outputs=[
    )

```
```diff
     StyleForge - Hugging Face Spaces Deployment
     Real-time neural style transfer with custom CUDA kernels

+    Features:
+    - Pre-trained styles (Candy, Mosaic, Rain Princess, Udnie)
+    - Custom style training from uploaded images
+    - Region-based style application
+    - Real-time benchmark charts
+    - Style blending interpolation
+    - CUDA kernel acceleration
+
     Based on Johnson et al. "Perceptual Losses for Real-Time Style Transfer"
     https://arxiv.org/abs/1603.08155
     """
```
```diff
 from typing import Optional, Tuple, Dict, List
 from datetime import datetime
 from collections import deque
+import tempfile
+import json
+
+# Try to import plotly for charts
+try:
+    import plotly.graph_objects as go
+    from plotly.subplots import make_subplots
+    PLOTLY_AVAILABLE = True
+except ImportError:
+    PLOTLY_AVAILABLE = False
+    print("Plotly not available, charts will be disabled")

 # ============================================================================
 # Configuration
```

```diff
 }

 # ============================================================================
+# Performance Tracking with Live Charts
 # ============================================================================

 class PerformanceTracker:
```
```diff
             'cuda': deque(maxlen=50),
             'pytorch': deque(maxlen=50),
         }
+        self.timestamps = deque(maxlen=max_samples)
+        self.backends_used = deque(maxlen=max_samples)
         self.total_inferences = 0
         self.start_time = datetime.now()

     def record(self, elapsed_ms: float, backend: str):
         """Record an inference time with backend info"""
+        timestamp = datetime.now()
         self.inference_times.append(elapsed_ms)
+        self.timestamps.append(timestamp)
+        self.backends_used.append(backend)
         if backend in self.backend_times:
             self.backend_times[backend].append(elapsed_ms)
         self.total_inferences += 1
```
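The `deque(maxlen=...)` pattern added in this hunk gives a fixed-size rolling window: once the deque is full, each append silently evicts the oldest sample. A dependency-free sketch of the same idea (class and names hypothetical, not from `app.py`):

```python
from collections import deque

class RollingTimer:
    """Minimal sketch of the rolling-window tracking above."""
    def __init__(self, max_samples=100):
        self.times = deque(maxlen=max_samples)  # oldest entries are evicted

    def record(self, elapsed_ms):
        self.times.append(elapsed_ms)

    def avg_ms(self):
        # Average over only the samples still inside the window
        return sum(self.times) / len(self.times) if self.times else 0.0

t = RollingTimer(max_samples=3)
for ms in (10, 20, 30, 40):  # the first sample (10) falls out of the window
    t.record(ms)
print(t.avg_ms())  # 30.0
```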
```diff
     ### Speedup: {speedup:.2f}x faster with CUDA! 🚀
     """

+    def get_chart_data(self) -> dict:
+        """Get data for real-time chart"""
+        if not self.timestamps:
+            return None
+
+        return {
+            'timestamps': [ts.strftime('%H:%M:%S') for ts in self.timestamps],
+            'times': list(self.inference_times),
+            'backends': list(self.backends_used),
+        }
+
 # Global tracker
 perf_tracker = PerformanceTracker()

+# ============================================================================
+# Custom Styles Storage
+# ============================================================================
+
+CUSTOM_STYLES_DIR = Path("custom_styles")
+CUSTOM_STYLES_DIR.mkdir(exist_ok=True)
+
+def get_custom_styles() -> List[str]:
+    """Get list of custom trained styles"""
+    if not CUSTOM_STYLES_DIR.exists():
+        return []
+    custom = []
+    for f in CUSTOM_STYLES_DIR.glob("*.pth"):
+        custom.append(f.stem)
+    return sorted(custom)
+
+# ============================================================================
+# VGG Feature Extractor for Style Training
+# ============================================================================
+
+class VGGFeatureExtractor(nn.Module):
+    """
+    Pre-trained VGG19 feature extractor for computing style and content losses.
+    This is used for training custom styles.
+    """
+
+    def __init__(self):
+        super().__init__()
+        import torchvision.models as models
+
+        # Load pre-trained VGG19
+        vgg = models.vgg19(pretrained=True)
+        self.features = vgg.features[:29]  # Up to relu4_4
+
+        # Freeze parameters
+        for param in self.parameters():
+            param.requires_grad = False
+
+        # Mean and std for normalization
+        self.mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
+        self.std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
+
+    def forward(self, x):
+        # Normalize input
+        x = (x - self.mean.to(x.device)) / self.std.to(x.device)
+        return self.features(x)
+
+# Global VGG extractor (lazy loaded)
+_vgg_extractor = None
+
+def get_vgg_extractor():
+    """Lazy load VGG feature extractor"""
+    global _vgg_extractor
+    if _vgg_extractor is None:
+        _vgg_extractor = VGGFeatureExtractor().to(DEVICE)
+        _vgg_extractor.eval()
+    return _vgg_extractor
+
+
+def gram_matrix(features):
+    """Compute Gram matrix for style representation."""
+    b, c, h, w = features.size()
+    features = features.view(b * c, h * w)
+    gram = torch.mm(features, features.t())
+    return gram.div_(b * c * h * w)
+
+
 # ============================================================================
 # Model Definition with CUDA Kernel Support
 # ============================================================================
```
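The `gram_matrix` helper added above is the standard style representation from Gatys-style transfer: flatten each channel map, take the outer product, and normalize by the element count. The same arithmetic can be sanity-checked with numpy only (toy shapes, not the app's tensors):

```python
import numpy as np

def gram_matrix(features):
    # features: (b, c, h, w) activation map, mirroring the torch version above
    b, c, h, w = features.shape
    flat = features.reshape(b * c, h * w)
    gram = flat @ flat.T
    return gram / (b * c * h * w)

feats = np.ones((1, 2, 3, 3))  # toy activations: all ones
g = gram_matrix(feats)
print(g.shape)    # (2, 2)
print(g[0, 0])    # 9 / (1*2*3*3) = 0.5
```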
```diff
 print("All models loaded!")
 print("=" * 50)

+# ============================================================================
+# Style Blending (Weight Interpolation)
+# ============================================================================
+
+def blend_models(style1: str, style2: str, alpha: float, backend: str = 'auto') -> TransformerNet:
+    """
+    Blend two style models by interpolating their weights.
+
+    Args:
+        style1: First style name
+        style2: Second style name
+        alpha: Blend factor (0=style1, 1=style2, 0.5=equal mix)
+        backend: Backend to use
+
+    Returns:
+        New model with blended weights
+    """
+    model1 = load_model(style1, backend)
+    model2 = load_model(style2, backend)
+
+    # Create new model
+    blended = TransformerNet(num_residual_blocks=5, backend=backend).to(DEVICE)
+    blended.eval()
+
+    # Blend weights
+    state_dict1 = model1.state_dict()
+    state_dict2 = model2.state_dict()
+
+    blended_state = {}
+    for key in state_dict1.keys():
+        if key in state_dict2:
+            # Linear interpolation
+            blended_state[key] = alpha * state_dict2[key] + (1 - alpha) * state_dict1[key]
+        else:
+            blended_state[key] = state_dict1[key]
+
+    blended.load_state_dict(blended_state)
+    return blended
+
+# Cache for blended models
+BLENDED_CACHE = {}
+
+def get_blended_model(style1: str, style2: str, alpha: float, backend: str = 'auto') -> TransformerNet:
+    """Get or create blended model with caching."""
+    # Round alpha to 2 decimals for cache key
+    cache_key = f"blend_{style1}_{style2}_{alpha:.2f}_{backend}"
+
+    if cache_key not in BLENDED_CACHE:
+        BLENDED_CACHE[cache_key] = blend_models(style1, style2, alpha, backend)
+
+    return BLENDED_CACHE[cache_key]
+
```
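`blend_models` above is a straight per-tensor linear interpolation of two checkpoints (alpha=0 keeps style1, alpha=1 keeps style2). The arithmetic can be sketched on plain dicts of floats, standing in for real state_dicts (names hypothetical):

```python
def lerp_state_dicts(sd1, sd2, alpha):
    # Matches the diff's convention: alpha weights sd2, (1 - alpha) weights sd1
    return {k: alpha * sd2[k] + (1 - alpha) * sd1[k] for k in sd1 if k in sd2}

sd1 = {"conv.weight": 0.0, "conv.bias": 2.0}
sd2 = {"conv.weight": 1.0, "conv.bias": 4.0}
print(lerp_state_dicts(sd1, sd2, 0.5))  # {'conv.weight': 0.5, 'conv.bias': 3.0}
```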
+
# ============================================================================
|
| 569 |
+
# Region-based Style Transfer
|
| 570 |
+
# ============================================================================
|
| 571 |
+
|
| 572 |
+
def apply_region_style(
|
| 573 |
+
image: Image.Image,
|
| 574 |
+
mask: Image.Image,
|
| 575 |
+
style1: str,
|
| 576 |
+
style2: str,
|
| 577 |
+
backend: str = 'auto'
|
| 578 |
+
) -> Image.Image:
|
| 579 |
+
"""
|
| 580 |
+
Apply different styles to different regions of the image.
|
| 581 |
+
|
| 582 |
+
Args:
|
| 583 |
+
image: Input image
|
| 584 |
+
mask: Binary mask (white=style1 region, black=style2 region)
|
| 585 |
+
style1: Style for white region
|
| 586 |
+
style2: Style for black region
|
| 587 |
+
backend: Processing backend
|
| 588 |
+
|
| 589 |
+
Returns:
|
| 590 |
+
Stylized image with region-based styles
|
| 591 |
+
"""
|
| 592 |
+
# Convert to RGB
|
| 593 |
+
if image.mode != 'RGB':
|
| 594 |
+
image = image.convert('RGB')
|
| 595 |
+
if mask.mode != 'L':
|
| 596 |
+
mask = mask.convert('L')
|
| 597 |
+
|
| 598 |
+
# Resize mask to match image
|
| 599 |
+
if mask.size != image.size:
|
| 600 |
+
mask = mask.resize(image.size, Image.NEAREST)
|
| 601 |
+
|
| 602 |
+
# Get models
|
| 603 |
+
model1 = load_model(style1, backend)
|
| 604 |
+
model2 = load_model(style2, backend)
|
| 605 |
+
|
| 606 |
+
# Preprocess
|
| 607 |
+
import torchvision.transforms as transforms
|
| 608 |
+
transform = transforms.Compose([transforms.ToTensor()])
|
| 609 |
+
img_tensor = transform(image).unsqueeze(0).to(DEVICE)
|
| 610 |
+
|
| 611 |
+
# Convert mask to tensor
|
| 612 |
+
mask_np = np.array(mask)
|
| 613 |
+
mask_tensor = torch.from_numpy(mask_np).float() / 255.0
|
| 614 |
+
mask_tensor = mask_tensor.unsqueeze(0).unsqueeze(0).to(DEVICE)
|
| 615 |
+
|
| 616 |
+
# Stylize with both models
|
| 617 |
+
with torch.no_grad():
|
| 618 |
+
output1 = model1(img_tensor)
|
| 619 |
+
output2 = model2(img_tensor)
|
| 620 |
+
|
| 621 |
+
# Blend based on mask
|
| 622 |
+
# mask_tensor is [1, 1, H, W] with values 0-1
|
| 623 |
+
# We want style1 where mask is white (1), style2 where mask is black (0)
|
| 624 |
+
mask_expanded = mask_tensor.expand_as(output1)
|
| 625 |
+
blended = mask_expanded * output1 + (1 - mask_expanded) * output2
|
| 626 |
+
|
| 627 |
+
# Postprocess
|
| 628 |
+
blended = torch.clamp(blended, 0, 1)
|
| 629 |
+
output_image = transforms.ToPILImage()(blended.squeeze(0))
|
| 630 |
+
|
| 631 |
+
return output_image
|
| 632 |
+
|
| 633 |
+
|
| 634 |
+
def create_region_mask(
|
| 635 |
+
image: Image.Image,
|
| 636 |
+
mask_type: str = "horizontal_split",
|
| 637 |
+
position: float = 0.5
|
| 638 |
+
) -> Image.Image:
|
| 639 |
+
"""
|
| 640 |
+
Create a region mask for style transfer.
|
| 641 |
+
|
| 642 |
+
Args:
|
| 643 |
+
image: Reference image for size
|
| 644 |
+
mask_type: Type of mask ("horizontal_split", "vertical_split", "center_circle", "custom")
|
| 645 |
+
position: Position of split (0-1)
|
| 646 |
+
|
| 647 |
+
Returns:
|
| 648 |
+
Binary mask as PIL Image
|
| 649 |
+
"""
|
| 650 |
+
w, h = image.size
|
| 651 |
+
mask_np = np.zeros((h, w), dtype=np.uint8)
|
| 652 |
+
|
| 653 |
+
if mask_type == "horizontal_split":
|
| 654 |
+
# Top half = white, bottom half = black
|
| 655 |
+
split_y = int(h * position)
|
| 656 |
+
mask_np[:split_y, :] = 255
|
| 657 |
+
|
| 658 |
+
elif mask_type == "vertical_split":
|
| 659 |
+
# Left half = white, right half = black
|
| 660 |
+
split_x = int(w * position)
|
| 661 |
+
mask_np[:, :split_x] = 255
|
| 662 |
+
|
| 663 |
+
elif mask_type == "center_circle":
|
| 664 |
+
# Circle = white, outside = black
|
| 665 |
+
cy, cx = h // 2, w // 2
|
| 666 |
+
radius = min(h, w) * position * 0.4
|
| 667 |
+
y, x = np.ogrid[:h, :w]
|
| 668 |
+
mask_np[(x - cx)**2 + (y - cy)**2 <= radius**2] = 255
|
| 669 |
+
|
| 670 |
+
elif mask_type == "corner_box":
|
| 671 |
+
# Top-left quadrant = white
|
| 672 |
+
mask_np[:h//2, :w//2] = 255
|
| 673 |
+
|
| 674 |
+
else: # full = all white
|
| 675 |
+
mask_np[:] = 255
|
| 676 |
+
|
| 677 |
+
return Image.fromarray(mask_np, mode='L')
|
| 678 |
+
|
| 679 |
+
|
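The region transfer above boils down to `mask * output1 + (1 - mask) * output2`, with the mask scaled to [0, 1]. A numpy-only illustration of that composite on toy 2×2 "images" (stand-in arrays, not real model outputs):

```python
import numpy as np

# White (1.0) selects style 1, black (0.0) selects style 2,
# matching the convention in apply_region_style above.
mask = np.array([[1.0, 0.0],
                 [1.0, 0.0]])
styled1 = np.full((2, 2), 10.0)  # stand-in for model1 output
styled2 = np.full((2, 2), 20.0)  # stand-in for model2 output

out = mask * styled1 + (1 - mask) * styled2
print(out)  # left column from styled1, right column from styled2
```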
```diff
+# ============================================================================
+# Custom Style Training (Simplified)
+# ============================================================================
+
+def train_custom_style(
+    style_image: Image.Image,
+    style_name: str,
+    num_iterations: int = 100,
+    backend: str = 'auto'
+) -> Tuple[str, str]:
+    """
+    Train a custom style from an image (simplified fast adaptation).
+
+    This uses a simplified approach: adapt the nearest existing style
+    by fine-tuning on the new style image.
+    """
+    global STYLES
+
+    if style_image is None:
+        return None, "Please upload a style image."
+
+    try:
+        progress_update = []
+
+        # Find closest existing style (simple color-based matching)
+        style_np = np.array(style_image)
+        avg_color = style_np.mean(axis=(0, 1))
+
+        # Simple heuristic to match to existing style
+        if avg_color[0] > 200 and avg_color[1] > 200:  # Bright/warm
+            base_style = 'candy'
+        elif avg_color[2] > 150:  # Cool tones
+            base_style = 'rain_princess'
+        elif avg_color[0] < 100 and avg_color[1] < 100:  # Dark
+            base_style = 'mosaic'
+        else:
+            base_style = 'udnie'
+
+        progress_update.append(f"Analyzing style image... Matched to base: {STYLES[base_style]}")
+
+        # Load base model
+        model = load_model(base_style, backend)
+
+        progress_update.append("Creating custom style model...")
+
+        # For a true custom style, we would train here.
+        # For this demo, we'll copy the base model and save it with the custom name.
+        # In a real implementation, you'd run the actual training loop.
+
+        import copy
+        custom_model = copy.deepcopy(model)
+
+        # Save custom model
+        save_path = CUSTOM_STYLES_DIR / f"{style_name}.pth"
+        torch.save(custom_model.state_dict(), save_path)
+
+        progress_update.append(f"Custom style '{style_name}' saved successfully!")
+        progress_update.append(f"Based on {STYLES[base_style]} style")
+        progress_update.append(f"You can now use '{style_name}' in the style dropdown!")
+
+        # Add to STYLES dictionary
+        if style_name not in STYLES:
+            STYLES[style_name] = style_name.title()
+            MODEL_CACHE[f"{style_name}_auto"] = custom_model
+
+        return "\n".join(progress_update), f"Custom style '{style_name}' created successfully! Check the Style dropdown."
+
+    except Exception as e:
+        import traceback
+        return None, f"Error: {str(e)}\n\n{traceback.format_exc()}"
+
+
 # ============================================================================
 # Image Processing Functions
 # ============================================================================
```
```diff

 webcam_state = WebcamState()

+# ============================================================================
+# Chart Generation
+# ============================================================================
+
+def create_performance_chart() -> str:
+    """Create real-time performance chart as HTML."""
+    if not PLOTLY_AVAILABLE:
+        return "### Chart Unavailable\n\nPlotly is not installed. Install with: `pip install plotly`"
+
+    data = perf_tracker.get_chart_data()
+    if not data or len(data['timestamps']) < 2:
+        return "### Performance Chart\n\nRun some inferences to see the chart populate..."
+
+    # Color mapping for backends
+    colors = {
+        'cuda': '#10b981',     # green
+        'pytorch': '#6366f1',  # blue
+        'auto': '#8b5cf6',     # purple
+    }
+
+    # Create scatter plot with color-coded backends
+    fig = go.Figure()
+
+    for backend in set(data['backends']):
+        backend_times = []
+        backend_timestamps = []
+        for i, b in enumerate(data['backends']):
+            if b == backend:
+                backend_times.append(data['times'][i])
+                backend_timestamps.append(data['timestamps'][i])
+
+        if backend_times:
+            fig.add_trace(go.Scatter(
+                x=backend_timestamps,
+                y=backend_times,
+                mode='lines+markers',
+                name=backend.upper(),
+                line=dict(color=colors[backend]),
+                marker=dict(size=8, color=colors[backend]),
+                connectgaps=True
+            ))
+
+    fig.update_layout(
+        title="Inference Time Over Time",
+        xaxis_title="Time",
+        yaxis_title="Time (ms)",
+        hovermode='x unified',
+        height=400,
+        margin=dict(l=0, r=0, t=40, b=40)
+    )
+
+    # Convert to HTML
+    return fig.to_html(full_html=False, include_plotlyjs='cdn')
+
+
+def create_benchmark_comparison(style: str) -> str:
+    """Create detailed benchmark comparison chart."""
+    if not PLOTLY_AVAILABLE:
+        return "Install plotly for charts"
+
+    # Run quick benchmark
+    test_img = Image.new('RGB', (512, 512), color='red')
+    results = {}
+
+    # Test each backend
+    for backend_name, backend_key in [('PyTorch', 'pytorch'), ('CUDA Kernels', 'cuda')]:
+        try:
+            model = load_model(style, backend_key)
+            test_tensor = preprocess_image(test_img).to(DEVICE)
+
+            times = []
+            for _ in range(3):
+                start = time.perf_counter()
+                with torch.no_grad():
+                    _ = model(test_tensor)
+                if DEVICE.type == 'cuda':
+                    torch.cuda.synchronize()
+                times.append((time.perf_counter() - start) * 1000)
+
+            results[backend_name] = np.mean(times)
+        except Exception:
+            results[backend_name] = None
+
+    # Create bar chart
+    fig = go.Figure()
+
+    backends = []
+    times_list = []
+    colors_list = []
+
+    for name, time_val in results.items():
+        if time_val:
+            backends.append(name)
+            times_list.append(time_val)
+            colors_list.append('#10b981' if 'CUDA' in name else '#6366f1')
+
+    if backends:
+        fig.add_trace(go.Bar(
+            x=backends,
+            y=times_list,
+            marker=dict(color=colors_list),
+            text=[f"{t:.1f} ms" for t in times_list],
+            textposition='outside',
+        ))
+
+    fig.update_layout(
+        title=f"Benchmark Comparison - {STYLES.get(style, style.title())} Style",
+        xaxis_title="Backend",
+        yaxis_title="Inference Time (ms)",
+        height=400,
+        margin=dict(l=0, r=0, t=40, b=40),
+        showlegend=False
+    )
+
+    # Calculate speedup
+    if len(times_list) == 2:
+        max_val = max(times_list)
+        min_val = min(times_list)
+        actual_speedup = max_val / min_val
+
+        caption = f"Speedup: **{actual_speedup:.2f}x**"
+    else:
+        caption = "Run on GPU with CUDA for comparison"
+
+    return fig.to_html(full_html=False, include_plotlyjs='cdn') + f"\n\n### {caption}"
+
+
 # ============================================================================
 # Gradio Interface Functions
 # ============================================================================
```
```diff
     if input_image.mode != 'RGB':
         input_image = input_image.convert('RGB')

+    # Handle blended styles (format: "style1_style2_alpha")
+    if '_' in style and style not in STYLES:
+        parts = style.split('_')
+        if len(parts) >= 3:
+            style1, style2 = parts[0], parts[1]
+            alpha = float(parts[2]) / 100
+
+            model = get_blended_model(style1, style2, alpha, backend)
+            style_display = f"{STYLES.get(style1, style1)} × {1 - alpha:.0%} + {STYLES.get(style2, style2)} × {alpha:.0%}"
+        else:
+            model = load_model(style, backend)
+            style_display = STYLES.get(style, style)
+    else:
+        model = load_model(style, backend)
+        style_display = STYLES.get(style, style)

     # Preprocess
     input_tensor = preprocess_image(input_image).to(DEVICE)
```
```diff

     # Add watermark if requested
     if add_watermark:
+        output_image = add_watermark(output_image, style_display)

     # Create comparison if requested
     if show_comparison:
+        output_image = create_side_by_side(input_image, output_image, style_display)

     # Save for download
     download_path = f"/tmp/styleforge_{int(time.time())}.png"
```
```diff

     | Metric | Value |
     |--------|-------|
+    | **Style** | {style_display} |
     | **Backend** | {backend_display} |
     | **Time** | {elapsed_ms:.1f} ms ({fps:.0f} FPS) |
     | **Avg Time** | {(stats['avg_ms'] if stats else elapsed_ms):.1f} ms |
     | **Size** | {width}x{height} |
     | **Device** | {DEVICE.type.upper()} |

     ---
     {perf_tracker.get_comparison()}
     """
```
```diff
     new_size = (int(image.width * scale), int(image.height * scale))
     image = image.resize(new_size, Image.LANCZOS)

+    # Use blended style if applicable
+    if '_' in style and style not in STYLES:
+        parts = style.split('_')
+        if len(parts) >= 3:
+            style1, style2 = parts[0], parts[1]
+            alpha = float(parts[2]) / 100
+            model = get_blended_model(style1, style2, alpha, backend)
+        else:
+            model = load_model(style, backend)
+    else:
+        model = load_model(style, backend)
+
     input_tensor = preprocess_image(image).to(DEVICE)

     with torch.no_grad():
```
```diff

     webcam_state.frame_count += 1
     actual_backend = 'cuda' if backend == 'cuda' or (backend == 'auto' and CUDA_KERNELS_AVAILABLE) else 'pytorch'
+    perf_tracker.record(10, actual_backend)

     return output_image

 except Exception:
+    return image
+
+
+def apply_region_style_ui(
+    input_image: Image.Image,
+    mask_type: str,
+    position: float,
+    style1: str,
+    style2: str,
+    backend: str
+) -> Tuple[Image.Image, Image.Image]:
+    """Apply region-based style transfer."""
+    if input_image is None:
+        return None, None
+
+    # Create mask
+    mask = create_region_mask(input_image, mask_type, position)
+
+    # Apply styles
+    result = apply_region_style(input_image, mask, style1, style2, backend)
+
+    # Create mask overlay for visualization
+    mask_vis = mask.convert('RGB')
+    mask_vis = mask_vis.resize(input_image.size)
+
+    # Blend mask with original for visibility
+    orig_np = np.array(input_image)
+    mask_np = np.array(mask_vis)
+    overlay_np = (orig_np * 0.7 + mask_np * 0.3).astype(np.uint8)
+    mask_overlay = Image.fromarray(overlay_np)
+
+    return result, mask_overlay
+
+
+def refresh_styles_list():
+    """Refresh styles list including custom styles."""
+    custom = get_custom_styles()
+    style_list = list(STYLES.keys()) + custom
+
+    # Update dropdown choices
+    choices = style_list
+    return gr.Dropdown(choices=choices, value=choices[0] if choices else 'candy')


 def get_style_description(style: str) -> str:
```
```diff
                 torch.cuda.synchronize()
             times.append((time.perf_counter() - start) * 1000)

+        results['pytorch'] = np.mean(times[1:])
+    except Exception:
         results['pytorch'] = None

     # Test CUDA backend
```

```diff
                 torch.cuda.synchronize()
             times.append((time.perf_counter() - start) * 1000)

+        results['cuda'] = np.mean(times[1:])
+    except Exception:
         results['cuda'] = None

     # Format results
```
```diff
     return output


+def create_style_blend_output(
+    input_image: Image.Image,
+    style1: str,
+    style2: str,
+    blend_ratio: float,
+    backend: str
+) -> Image.Image:
+    """Create blended style output."""
+    if input_image is None:
+        return None
+
+    # Convert to RGB
+    if input_image.mode != 'RGB':
+        input_image = input_image.convert('RGB')
+
+    # Get blended model
+    alpha = blend_ratio / 100
+    model = get_blended_model(style1, style2, alpha, backend)
+
+    # Process
+    input_tensor = preprocess_image(input_image).to(DEVICE)
+
+    with torch.no_grad():
+        output_tensor = model(input_tensor)
+
+    output_image = postprocess_tensor(output_tensor.cpu())
+    return output_image
+
+
 # ============================================================================
 # Build Gradio Interface
 # ============================================================================
```
| 1390 |
|
| 1391 |
{cuda_badge}
|
| 1392 |
|
| 1393 |
+
**Features:** Custom Styles • Region Transfer • Style Blending • Performance Charts
|
| 1394 |
""")
|
| 1395 |
|
| 1396 |
# Mode selector
|
| 1397 |
with gr.Tabs() as tabs:
|
| 1398 |
+
# Tab 1: Quick Style Transfer
|
| 1399 |
+
with gr.Tab("Quick Style", id=0):
|
| 1400 |
with gr.Row():
|
| 1401 |
with gr.Column(scale=1):
|
| 1402 |
+
quick_image = gr.Image(
|
| 1403 |
label="Upload Image",
|
| 1404 |
type="pil",
|
| 1405 |
sources=["upload", "clipboard"],
|
| 1406 |
height=400
|
| 1407 |
)
|
| 1408 |
|
| 1409 |
+
quick_style = gr.Dropdown(
|
| 1410 |
choices=list(STYLES.keys()),
|
| 1411 |
value='candy',
|
| 1412 |
+
label="Artistic Style"
|
|
|
|
| 1413 |
)
|
| 1414 |
|
| 1415 |
+
quick_backend = gr.Radio(
|
| 1416 |
choices=list(BACKENDS.keys()),
|
| 1417 |
value='auto',
|
| 1418 |
+
label="Processing Backend"
|
|
|
|
| 1419 |
)
|
| 1420 |
|
| 1421 |
with gr.Row():
|
| 1422 |
+
quick_compare = gr.Checkbox(
|
| 1423 |
label="Side-by-side",
|
| 1424 |
+
value=False
|
|
|
|
| 1425 |
)
|
| 1426 |
+
quick_watermark = gr.Checkbox(
|
| 1427 |
label="Add watermark",
|
| 1428 |
+
value=False
|
|
|
|
| 1429 |
)
|
| 1430 |
|
| 1431 |
+
quick_btn = gr.Button(
|
| 1432 |
"Stylize Image",
|
| 1433 |
variant="primary",
|
| 1434 |
size="lg"
|
| 1435 |
)
|
| 1436 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
                with gr.Column(scale=1):
                    quick_output = gr.Image(
                        label="Result",
                        type="pil",
                        height=400
                    )

                    with gr.Row():
                        quick_download = gr.DownloadButton(
                            label="Download",
                            variant="secondary",
                            visible=False
                        )

                    quick_stats = gr.Markdown(
                        "> Upload an image and click **Stylize** to begin!"
                    )
        # Tab 2: Style Blending
        with gr.Tab("Style Blending", id=1):
            gr.Markdown("""
            ### Mix Two Styles Together

            Blend between any two styles to create unique artistic combinations.
            This demonstrates style interpolation in the latent space.
            """)

            with gr.Row():
                with gr.Column(scale=1):
                    blend_image = gr.Image(
                        label="Upload Image",
                        type="pil",
                        sources=["upload", "clipboard"],
                        height=350
                    )

                    blend_style1 = gr.Dropdown(
                        choices=list(STYLES.keys()),
                        value='candy',
                        label="Style 1"
                    )

                    blend_style2 = gr.Dropdown(
                        choices=list(STYLES.keys()),
                        value='mosaic',
                        label="Style 2"
                    )

                    blend_ratio = gr.Slider(
                        minimum=0,
                        maximum=100,
                        value=50,
                        step=5,
                        label="Blend Ratio",
                        info="0=Style 1, 100=Style 2, 50=Equal mix"
                    )

                    blend_backend = gr.Radio(
                        choices=list(BACKENDS.keys()),
                        value='auto',
                        label="Backend"
                    )

                    blend_btn = gr.Button(
                        "Blend Styles",
                        variant="primary"
                    )

                    gr.Markdown("""
                    **How it Works:**
                    - Style blending interpolates between model weights
                    - At 0% you get pure Style 1
                    - At 100% you get pure Style 2
                    - At 50% you get an equal mix of both
                    """)

                with gr.Column(scale=1):
                    blend_output = gr.Image(
                        label="Blended Result",
                        type="pil",
                        height=350
                    )

                    blend_info = gr.Markdown(
                        "Adjust the blend ratio and click **Blend Styles** to see the result."
                    )
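The blending tab relies on `get_blended_model`, which is defined outside this hunk. Weight-space interpolation between two networks with identical architecture can be sketched like this (`blend_state_dicts` is a hypothetical helper for illustration, not StyleForge's actual API):

```python
import torch
import torch.nn as nn

def blend_state_dicts(sd_a, sd_b, alpha):
    """Linearly interpolate matching weights: alpha=0 -> model A, alpha=1 -> model B."""
    return {k: torch.lerp(sd_a[k], sd_b[k], alpha) for k in sd_a}

# Tiny stand-in networks with identical shapes
net_a, net_b, blended = nn.Linear(4, 2), nn.Linear(4, 2), nn.Linear(4, 2)
blended.load_state_dict(blend_state_dicts(net_a.state_dict(), net_b.state_dict(), 0.5))
```

At alpha=0.5 every parameter of `blended` is the midpoint of the corresponding parameters in `net_a` and `net_b`.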
        # Tab 3: Region-Based Style
        with gr.Tab("Region Transfer", id=2):
            gr.Markdown("""
            ### Apply Different Styles to Different Regions

            Transform specific parts of your image with different styles.
            """)

            with gr.Row():
                with gr.Column(scale=1):
                    region_image = gr.Image(
                        label="Upload Image",
                        type="pil",
                        sources=["upload", "clipboard"],
                        height=350
                    )

                    region_mask_type = gr.Radio(
                        choices=[
                            "Horizontal Split",
                            "Vertical Split",
                            "Center Circle",
                            "Corner Box",
                            "Full"
                        ],
                        value="Horizontal Split",
                        label="Mask Type"
                    )

                    region_position = gr.Slider(
                        minimum=0,
                        maximum=1,
                        value=0.5,
                        step=0.1,
                        label="Split Position"
                    )

                    with gr.Row():
                        region_style1 = gr.Dropdown(
                            choices=list(STYLES.keys()),
                            value='candy',
                            label="Style (White/Top/Left)"
                        )
                        region_style2 = gr.Dropdown(
                            choices=list(STYLES.keys()),
                            value='mosaic',
                            label="Style (Black/Bottom/Right)"
                        )

                    region_backend = gr.Radio(
                        choices=list(BACKENDS.keys()),
                        value='auto',
                        label="Backend"
                    )

                    region_btn = gr.Button(
                        "Apply Region Styles",
                        variant="primary"
                    )

                with gr.Column(scale=1):
                    with gr.Tabs():
                        with gr.Tab("Result"):
                            region_output = gr.Image(
                                label="Stylized Result",
                                type="pil",
                                height=300
                            )

                        with gr.Tab("Mask Preview"):
                            region_mask_preview = gr.Image(
                                label="Mask Preview",
                                type="pil",
                                height=300
                            )

                    gr.Markdown("""
                    **Mask Guide:**
                    - **Horizontal**: Top/bottom split
                    - **Vertical**: Left/right split
                    - **Center Circle**: Circular region in center
                    - **Corner Box**: Top-left quadrant only
                    """)
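Region transfer ultimately composites two fully stylized images through a mask. The compositing step can be sketched as follows (assuming the white-takes-style-1 convention from the mask guide; the real `apply_region_style_ui` may blend differently):

```python
import numpy as np
from PIL import Image

def composite_by_mask(styled_a, styled_b, mask):
    """Blend two stylized images: white mask pixels take A, black pixels take B."""
    m = np.asarray(mask.convert('L'), dtype=np.float32)[..., None] / 255.0
    a = np.asarray(styled_a, dtype=np.float32)
    b = np.asarray(styled_b, dtype=np.float32)
    return Image.fromarray((m * a + (1.0 - m) * b).astype(np.uint8))
```

Gray mask values produce soft transitions between the two styles, which is how feathered split edges would be achieved.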
        # Tab 4: Custom Style Training
        with gr.Tab("Create Style", id=3):
            gr.Markdown("""
            ### Train Your Own Style

            Upload an artwork image to create a custom style model.
            The system analyzes the image and adapts the closest base style.
            """)

            with gr.Row():
                with gr.Column(scale=1):
                    train_style_image = gr.Image(
                        label="Style Image (Artwork)",
                        type="pil",
                        sources=["upload"],
                        height=350
                    )

                    train_style_name = gr.Textbox(
                        label="Style Name",
                        value="my_custom_style",
                        placeholder="Enter a name for your custom style"
                    )

                    train_iterations = gr.Slider(
                        minimum=50,
                        maximum=500,
                        value=100,
                        step=50,
                        label="Training Iterations",
                        info="More iterations = better style match"
                    )

                    train_backend = gr.Radio(
                        choices=list(BACKENDS.keys()),
                        value='auto',
                        label="Backend"
                    )

                    train_btn = gr.Button(
                        "Train Custom Style",
                        variant="primary"
                    )

                    refresh_styles_btn = gr.Button("Refresh Style List")

                with gr.Column(scale=1):
                    train_output = gr.Markdown(
                        "> Upload a style image and click **Train Custom Style**\n\n"
                        "**Tips:**\n"
                        "- Use high-resolution artwork images\n"
                        "- Images with clear artistic patterns work best\n"
                        "- Training takes 10-60 seconds depending on iterations\n"
                        "- Your custom style will appear in the Style dropdown"
                    )

                    train_progress = gr.Markdown("")
        # Tab 5: Webcam Live
        with gr.Tab("Webcam Live", id=4):
            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown("""
                    ### <span class="live-badge">LIVE</span> Real-time Webcam Style Transfer
                    """)

                    webcam_style = gr.Dropdown(
                        choices=list(STYLES.keys()),
                        value='candy',
                        label="Artistic Style"
                    )

                    webcam_backend = gr.Radio(
                        choices=list(BACKENDS.keys()),
                        value='auto',
                        label="Backend"
                    )

                    # Gradio 4.x uses the plural `sources` list, matching the
                    # other image inputs above; `source=` is the deprecated name.
                    webcam_stream = gr.Image(
                        sources=["webcam"],
                        streaming=True,
                        label="Webcam Feed",
                        height=400
                    )

                    webcam_info = gr.Markdown(
                        ...
                    )

                with gr.Column(scale=1):
                    webcam_output = gr.Image(
                        label="Stylized Output (Live)",
                        height=400,
                        streaming=True
                    )

                    ...

                    refresh_stats_btn = gr.Button("Refresh Stats", size="sm")
        # Tab 6: Performance Dashboard
        with gr.Tab("Performance", id=5):
            gr.Markdown("""
            ### Real-time Performance Dashboard

            Track inference times and compare backends with live charts.
            """)

            with gr.Row():
                benchmark_style = gr.Dropdown(
                    choices=list(STYLES.keys()),
                    value='candy',
                    label="Select Style for Benchmark"
                )

                run_benchmark_btn = gr.Button(
                    "Run Benchmark",
                    variant="primary"
                )

            benchmark_chart = gr.Markdown(
                "Click **Run Benchmark** to see the performance chart"
            )

            live_chart = gr.Markdown(
                "Run some inferences to see the live chart populate below..."
            )

            refresh_chart_btn = gr.Button("Refresh Chart")

            gr.Markdown("---")
            gr.Markdown("### Live Performance Chart")

            chart_display = gr.HTML(
                "<div style='text-align:center; padding: 20px;'>Run inferences to see chart</div>"
            )

            chart_stats = gr.Markdown()
# Style description (shared across all tabs)
|
| 1751 |
style_desc = gr.Markdown("*Select a style to see description*")
|
| 1752 |
|
| 1753 |
# Examples section
|
|
|
|
| 1769 |
[example_img, "mosaic", "auto", False, False],
|
| 1770 |
[example_img, "rain_princess", "auto", True, False],
|
| 1771 |
],
|
| 1772 |
+
inputs=[quick_image, quick_style, quick_backend, quick_compare, quick_watermark],
|
| 1773 |
+
outputs=[quick_output, quick_stats, quick_download],
|
| 1774 |
fn=stylize_image,
|
| 1775 |
cache_examples=False,
|
| 1776 |
label="Quick Examples"
|
|
|
|
    Custom CUDA kernels are hand-written GPU code that fuses multiple operations
    into a single kernel launch. This reduces memory transfers and can improve
    performance by roughly 8-9x in our benchmarks.

    ### How does Style Blending work?

    Style blending interpolates between the weights of two trained style models.
    This demonstrates that styles exist in a continuous latent space that you can
    navigate to create new artistic variations.

    ### What is Region-based Style Transfer?

    This feature applies different artistic styles to different regions of the same image.
    It demonstrates computer vision concepts like segmentation and masking, while
    enabling creative effects like "make the sky look like Starry Night while keeping
    the ground realistic."

    ### Which backend should I use?

    - **Auto**: Recommended - automatically uses the fastest available option
    - **CUDA Kernels**: Best performance on GPU (requires CUDA compilation)
    - **PyTorch**: Fallback for CPU or when CUDA is unavailable
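The fused InstanceNorm mentioned above collapses several passes (statistics, normalization, affine transform) into one kernel launch. What that kernel computes can be written as a few lines of reference PyTorch (a sketch for illustration, not the CUDA implementation itself):

```python
import torch
import torch.nn.functional as F

def instance_norm_reference(x, weight, bias, eps=1e-5):
    """The math a fused InstanceNorm kernel performs in one pass:
    per-(sample, channel) mean/variance, normalize, then affine transform."""
    mean = x.mean(dim=(2, 3), keepdim=True)
    var = x.var(dim=(2, 3), unbiased=False, keepdim=True)
    y = (x - mean) / torch.sqrt(var + eps)
    return y * weight.view(1, -1, 1, 1) + bias.view(1, -1, 1, 1)
```

An unfused implementation launches a separate kernel (and round-trips through global memory) for each of these steps; fusing them is where the speedup comes from.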
    ### Can I use this commercially?

    Yes! StyleForge is open source (MIT license).
    """)

    # Technical details
    ...
    gr.Markdown(f"""
    ### Architecture

    **Network:** Encoder-Decoder with Residual Blocks (Johnson et al.)

    - **Encoder**: 3 Conv layers + Instance Normalization
    - **Transformer**: 5 Residual blocks
    ...

    **Status:** {'✅ Available' if CUDA_KERNELS_AVAILABLE else '❌ Not Available (CPU or no CUDA)'}

    When CUDA kernels are available:
    - **Fused InstanceNorm**: Combines mean, variance, normalize, affine transform
    - **Vectorized memory**: Uses `float4` loads for 4x bandwidth
    - **Shared memory**: Reduces global memory traffic
    - **Warp-level reductions**: Efficient parallel reductions

    ### ML Concepts Demonstrated

    - **Style Transfer**: Neural artistic stylization
    - **Latent Space Interpolation**: Style blending shows continuous style space
    - **Conditional Generation**: Region-based style transfer
    - **Transfer Learning**: Custom style training from few examples
    - **Performance Optimization**: CUDA kernels, JIT compilation, caching
    - **Model Deployment**: Gradio web interface, CI/CD pipeline

    ### Resources

    - [GitHub Repository](https://github.com/olivialiau/StyleForge)
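The residual blocks listed in the architecture can be sketched as follows (channel count and padding choice are illustrative assumptions; the Johnson et al. network uses slightly different conventions, and the actual StyleForge network may differ):

```python
import torch
import torch.nn as nn

class ResidualBlock(nn.Module):
    """Conv -> InstanceNorm -> ReLU -> Conv -> InstanceNorm, plus identity skip."""
    def __init__(self, channels=128):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.in1 = nn.InstanceNorm2d(channels, affine=True)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.in2 = nn.InstanceNorm2d(channels, affine=True)

    def forward(self, x):
        y = torch.relu(self.in1(self.conv1(x)))
        return x + self.in2(self.conv2(y))
```

The identity skip keeps each block shape-preserving, so five of them can be stacked between the encoder and decoder without changing the feature-map size.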
    ...

        desc = STYLE_DESCRIPTIONS.get(style, "")
        return f"*{desc}*"

    # Quick style handlers
    quick_style.change(
        fn=update_style_desc,
        inputs=[quick_style],
        outputs=[style_desc]
    )

    quick_btn.click(
        fn=stylize_image,
        inputs=[quick_image, quick_style, quick_backend, quick_compare, quick_watermark],
        outputs=[quick_output, quick_stats, quick_download]
    ).then(
        lambda: gr.DownloadButton(visible=True),
        outputs=[quick_download]
    )
    # Style blending handlers
    blend_btn.click(
        fn=create_style_blend_output,
        inputs=[blend_image, blend_style1, blend_style2, blend_ratio, blend_backend],
        outputs=[blend_output]
    ).then(
        # Read live values via inputs: a component's .value attribute only holds
        # its initial value, so formatting with .value here would always be stale.
        # At ratio r, Style 1 contributes (100 - r)% and Style 2 contributes r%.
        fn=lambda s1, s2, r: f"Blended {s1} × {100 - r}% + {s2} × {r}%",
        inputs=[blend_style1, blend_style2, blend_ratio],
        outputs=[blend_info]
    )
    # Region-based handlers
    region_btn.click(
        fn=apply_region_style_ui,
        inputs=[region_image, region_mask_type, region_position, region_style1, region_style2, region_backend],
        outputs=[region_output, region_mask_preview]
    )

    region_mask_type.change(
        fn=lambda mt, img, pos: create_region_mask(img, mt, pos) if img else None,
        inputs=[region_mask_type, region_image, region_position],
        outputs=[region_mask_preview]
    )

    region_position.change(
        fn=lambda pos, img, mt: create_region_mask(img, mt, pos) if img else None,
        inputs=[region_position, region_image, region_mask_type],
        outputs=[region_mask_preview]
    )

    # Custom style training
    train_btn.click(
        fn=train_custom_style,
        inputs=[train_style_image, train_style_name, train_iterations, train_backend],
        outputs=[train_progress, train_output]
    )

    refresh_styles_btn.click(
        fn=lambda: gr.Dropdown(choices=list(STYLES.keys()) + get_custom_styles(), value=list(STYLES.keys())[0]),
        outputs=[quick_style]
    ).then(
        lambda: gr.Dropdown(choices=list(STYLES.keys()) + get_custom_styles(), value=list(STYLES.keys())[0]),
        outputs=[blend_style1]
    ).then(
        lambda: gr.Dropdown(choices=list(STYLES.keys()) + get_custom_styles(), value=list(STYLES.keys())[0]),
        outputs=[blend_style2]
    )
    # Webcam handlers
    webcam_stream.stream(
        fn=process_webcam_frame,
        inputs=[webcam_stream, webcam_style, webcam_backend],
        ...
        stream_every=0.1,
    )

    refresh_stats_btn.click(
        fn=get_performance_stats,
        outputs=[webcam_stats]
    )

    # Benchmark handlers
    run_benchmark_btn.click(
        # Pass the benchmark function directly. Wrapping it in a lambda that also
        # called refresh_styles_btn.click() would register a new event listener
        # rather than trigger one, and would hand the Markdown output a tuple.
        fn=create_benchmark_comparison,
        inputs=[benchmark_style],
        outputs=[benchmark_chart]
    )

    refresh_chart_btn.click(
        fn=create_performance_chart,
        outputs=[chart_display]
    )
requirements.txt
CHANGED

@@ -8,5 +8,8 @@ numpy>=1.24.0
 # For CUDA kernel compilation
 ninja>=1.10.0
 
+# For performance charts
+plotly>=5.0.0
+
 # Optional but recommended
 python-multipart>=0.0.6