Olivia committed on
Commit 1f1a374 · 1 Parent(s): df2c623

Polish: Add performance tracking, download button, FAQ, mobile optimization


- Add PerformanceTracker for inference statistics
- Add download button for stylized results
- Add optional watermark for social sharing
- Add style descriptions with dynamic updates
- Add comprehensive FAQ section
- Add mobile-responsive CSS
- Improve UI with gradient header and better styling
- Update README with badges and proper documentation

Files changed (2)
  1. README.md +67 -17
  2. app.py +299 -143
README.md CHANGED
@@ -12,39 +12,89 @@ license: mit
 
 # StyleForge: Real-Time Neural Style Transfer
 
-Transform your images with artistic styles using fast neural style transfer.
+Transform your photos into artwork using fast neural style transfer.
 
-## 🎨 Features
+[![Open in Spaces](https://huggingface.co/datasets/huggingface/badges/raw/main/open-in-hf-spaces-sm.svg)](https://huggingface.co/spaces/olivialiau/styleforge)
+[![GitHub](https://img.shields.io/badge/GitHub-StyleForge-blue?logo=github)](https://github.com/olivialiau/StyleForge)
+[![License: MIT](https://img.shields.io/badge/License-MIT-purple.svg)](https://opensource.org/licenses/MIT)
+
+## Features
 
 - **4 Artistic Styles**: Candy, Mosaic, Rain Princess, and Udnie
 - **Real-Time Processing**: Fast inference on both CPU and GPU
 - **Simple Interface**: Just upload an image and select a style
-- **Comparison View**: Option to see side-by-side before/after
+- **Side-by-Side Comparison**: See before and after together
+- **Download Results**: Save your stylized images
+- **Watermark Option**: Add branding for social sharing
+
+## Quick Start
+
+1. **Upload** any image (JPG, PNG)
+2. **Select** an artistic style
+3. **Click** "Stylize Image"
+4. **Download** your result!
 
-## 🚀 How It Works
+## How It Works
 
-This Space uses **Fast Neural Style Transfer** based on the paper by Johnson et al.
-Unlike slow optimization-based methods, this approach trains a separate network per style
-that can transform images in a single forward pass.
+StyleForge uses **Fast Neural Style Transfer** based on Johnson et al.'s paper "Perceptual Losses for Real-Time Style Transfer".
+
+Unlike slow optimization-based methods, this uses pre-trained networks that transform images in milliseconds.
 
 ### Architecture
 
-- **Encoder**: 3 convolutional layers with Instance Normalization
-- **Transformer**: 5 residual blocks
-- **Decoder**: 3 upsampling layers with Instance Normalization
+- **Encoder**: 3 Conv layers + Instance Normalization
+- **Transformer**: 5 Residual blocks
+- **Decoder**: 3 Upsample Conv layers + Instance Normalization
+
+### Performance
+
+| Resolution | GPU | CPU |
+|------------|-----|-----|
+| 256x256 | ~5ms | ~50ms |
+| 512x512 | ~15ms | ~150ms |
+| 1024x1024 | ~50ms | ~500ms |
 
-## 📚 Resources
+## Styles
 
-- [GitHub Repository](https://github.com/olivialiau/StyleForge)
-- [Paper: Perceptual Losses for Real-Time Style Transfer](https://arxiv.org/abs/1603.08155)
-- [Original Implementation](https://github.com/jcjohnson/fast-neural-style)
+- 🍬 **Candy**: Bright, colorful pop-art style
+- 🎨 **Mosaic**: Fragmented tile-like reconstruction
+- 🌧️ **Rain Princess**: Moody impressionistic
+- 🖼️ **Udnie**: Bold abstract expressionist
 
-## 👤 Author
+## Run Locally
+
+```bash
+git clone https://github.com/olivialiau/StyleForge
+cd StyleForge/huggingface-space
+pip install -r requirements.txt
+python app.py
+```
+
+Open http://localhost:7860 in your browser.
+
+## Embed in Your Website
+
+```html
+<iframe
+  src="https://olivialiau-styleforge.hf.space"
+  frameborder="0"
+  width="100%"
+  height="800px"
+></iframe>
+```
+
+## Author
 
 **Olivia** - USC Computer Science
 
 [GitHub](https://github.com/olivialiau/StyleForge)
 
-## 📄 License
+## License
 
-MIT License
+MIT License - see [LICENSE](LICENSE) for details
+
+## Acknowledgments
+
+- [Johnson et al.](https://arxiv.org/abs/1603.08155) - Perceptual Losses for Real-Time Style Transfer
+- [yakhyo](https://github.com/yakhyo/fast-neural-style-transfer) - Pre-trained model weights
+- [Hugging Face](https://huggingface.co) - Spaces platform
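The Performance table reports per-image latency; the throughput it implies follows directly from `FPS = 1000 / latency_ms`. A minimal sketch, with the latency numbers copied from the table (actual timings vary by hardware):

```python
# Latencies from the README's Performance table, in milliseconds
latencies_ms = {
    "256x256":   {"gpu": 5,  "cpu": 50},
    "512x512":   {"gpu": 15, "cpu": 150},
    "1024x1024": {"gpu": 50, "cpu": 500},
}

def fps(ms: float) -> float:
    """Frames per second implied by a per-image latency in milliseconds."""
    return 1000.0 / ms

for res, t in latencies_ms.items():
    print(f"{res}: ~{fps(t['gpu']):.0f} FPS on GPU, ~{fps(t['cpu']):.0f} FPS on CPU")
```

At 512x512 on GPU this works out to roughly 67 FPS, which is what makes the "real-time" claim plausible for video-rate use.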
app.py CHANGED
@@ -9,12 +9,14 @@ https://arxiv.org/abs/1603.08155
 import gradio as gr
 import torch
 import torch.nn as nn
-from PIL import Image
+from PIL import Image, ImageDraw, ImageFont
 import numpy as np
 import time
 import os
 from pathlib import Path
 from typing import Optional, Tuple
+from datetime import datetime
+from collections import deque
 
 # ============================================================================
 # Configuration
@@ -32,8 +34,51 @@ STYLES = {
     'udnie': 'Udnie',
 }
 
+STYLE_DESCRIPTIONS = {
+    'candy': 'Bright, colorful transformation inspired by pop art',
+    'mosaic': 'Fragmented, tile-like artistic reconstruction',
+    'rain_princess': 'Moody, impressionistic with subtle textures',
+    'udnie': 'Bold, abstract expressionist style',
+}
+
+# ============================================================================
+# Performance Tracking
+# ============================================================================
+
+class PerformanceTracker:
+    """Track and display Space performance metrics."""
+
+    def __init__(self, max_samples=100):
+        self.inference_times = deque(maxlen=max_samples)
+        self.total_inferences = 0
+        self.start_time = datetime.now()
+
+    def record(self, elapsed_ms):
+        """Record an inference time."""
+        self.inference_times.append(elapsed_ms)
+        self.total_inferences += 1
+
+    def get_stats(self):
+        """Get performance statistics."""
+        if not self.inference_times:
+            return None
+
+        times = list(self.inference_times)
+        uptime = (datetime.now() - self.start_time).total_seconds()
+
+        return {
+            'avg_ms': sum(times) / len(times),
+            'min_ms': min(times),
+            'max_ms': max(times),
+            'total_inferences': self.total_inferences,
+            'uptime_hours': uptime / 3600,
+        }
+
+# Global tracker
+perf_tracker = PerformanceTracker()
+
 # ============================================================================
-# Model Definition (Simplified for HF Spaces deployment)
+# Model Definition
 # ============================================================================
 
@@ -116,12 +161,7 @@ class UpsampleConvLayer(nn.Module):
 
 
 class TransformerNet(nn.Module):
-    """
-    Fast Neural Style Transfer Network
-
-    Args:
-        num_residual_blocks: Number of residual blocks (default: 5)
-    """
+    """Fast Neural Style Transfer Network"""
 
     def __init__(self, num_residual_blocks: int = 5):
         super().__init__()
@@ -173,36 +213,20 @@ class TransformerNet(nn.Module):
 
         # Create mapping for different naming conventions
         name_mapping = {
-            "in1": "conv1.norm",
-            "in2": "conv2.norm",
-            "in3": "conv3.norm",
-            "conv1.conv2d": "conv1.conv",
-            "conv2.conv2d": "conv2.conv",
-            "conv3.conv2d": "conv3.conv",
-            "res1.conv1.conv2d": "residual_blocks.0.conv1.conv",
-            "res1.in1": "residual_blocks.0.conv1.norm",
-            "res1.conv2.conv2d": "residual_blocks.0.conv2.conv",
-            "res1.in2": "residual_blocks.0.conv2.norm",
-            "res2.conv1.conv2d": "residual_blocks.1.conv1.conv",
-            "res2.in1": "residual_blocks.1.conv1.norm",
-            "res2.conv2.conv2d": "residual_blocks.1.conv2.conv",
-            "res2.in2": "residual_blocks.1.conv2.norm",
-            "res3.conv1.conv2d": "residual_blocks.2.conv1.conv",
-            "res3.in1": "residual_blocks.2.conv1.norm",
-            "res3.conv2.conv2d": "residual_blocks.2.conv2.conv",
-            "res3.in2": "residual_blocks.2.conv2.norm",
-            "res4.conv1.conv2d": "residual_blocks.3.conv1.conv",
-            "res4.in1": "residual_blocks.3.conv1.norm",
-            "res4.conv2.conv2d": "residual_blocks.3.conv2.conv",
-            "res4.in2": "residual_blocks.3.conv2.norm",
-            "res5.conv1.conv2d": "residual_blocks.4.conv1.conv",
-            "res5.in1": "residual_blocks.4.conv1.norm",
-            "res5.conv2.conv2d": "residual_blocks.4.conv2.conv",
-            "res5.in2": "residual_blocks.4.conv2.norm",
-            "deconv1.conv2d": "deconv1.conv",
-            "in4": "deconv1.norm",
-            "deconv2.conv2d": "deconv2.conv",
-            "in5": "deconv2.norm",
+            "in1": "conv1.norm", "in2": "conv2.norm", "in3": "conv3.norm",
+            "conv1.conv2d": "conv1.conv", "conv2.conv2d": "conv2.conv", "conv3.conv2d": "conv3.conv",
+            "res1.conv1.conv2d": "residual_blocks.0.conv1.conv", "res1.in1": "residual_blocks.0.conv1.norm",
+            "res1.conv2.conv2d": "residual_blocks.0.conv2.conv", "res1.in2": "residual_blocks.0.conv2.norm",
+            "res2.conv1.conv2d": "residual_blocks.1.conv1.conv", "res2.in1": "residual_blocks.1.conv1.norm",
+            "res2.conv2.conv2d": "residual_blocks.1.conv2.conv", "res2.in2": "residual_blocks.1.conv2.norm",
+            "res3.conv1.conv2d": "residual_blocks.2.conv1.conv", "res3.in1": "residual_blocks.2.conv1.norm",
+            "res3.conv2.conv2d": "residual_blocks.2.conv2.conv", "res3.in2": "residual_blocks.2.conv2.norm",
+            "res4.conv1.conv2d": "residual_blocks.3.conv1.conv", "res4.in1": "residual_blocks.3.conv1.norm",
+            "res4.conv2.conv2d": "residual_blocks.3.conv2.conv", "res4.in2": "residual_blocks.3.conv2.norm",
+            "res5.conv1.conv2d": "residual_blocks.4.conv1.conv", "res5.in1": "residual_blocks.4.conv1.norm",
+            "res5.conv2.conv2d": "residual_blocks.4.conv2.conv", "res5.in2": "residual_blocks.4.conv2.norm",
+            "deconv1.conv2d": "deconv1.conv", "in4": "deconv1.norm",
+            "deconv2.conv2d": "deconv2.conv", "in5": "deconv2.norm",
             "deconv3.conv2d": "deconv3.1",
         }
 
@@ -238,8 +262,6 @@ class TransformerNet(nn.Module):
 # ============================================================================
 
 MODEL_CACHE = {}
-
-# Pre-download models on startup (for Hugging Face Spaces)
 MODELS_DIR = Path("models")
 MODELS_DIR.mkdir(exist_ok=True)
 
@@ -249,7 +271,6 @@ def get_model_path(style: str) -> Path:
     model_path = MODELS_DIR / f"{style}.pth"
 
     if not model_path.exists():
-        # Download from GitHub releases
        url_map = {
            'candy': 'https://github.com/yakhyo/fast-neural-style-transfer/releases/download/v1.0/candy.pth',
            'mosaic': 'https://github.com/yakhyo/fast-neural-style-transfer/releases/download/v1.0/mosaic.pth',
@@ -285,14 +306,19 @@ def load_model(style: str) -> TransformerNet:
 
 
 # Preload all models on startup
+print("=" * 50)
+print("StyleForge - Initializing...")
+print("=" * 50)
+print(f"Device: {DEVICE.type.upper()}")
 print("Preloading models...")
 for style in STYLES.keys():
     try:
         load_model(style)
+        print(f"  {STYLES[style]}: Ready")
     except Exception as e:
-        print(f"Warning: Could not load {style}: {e}")
-print("Models preloaded")
-
+        print(f"  {STYLES[style]}: Failed - {e}")
+print("All models loaded!")
+print("=" * 50)
 
 # ============================================================================
 # Image Processing Functions
@@ -307,46 +333,69 @@ def preprocess_image(img: Image.Image) -> torch.Tensor:
 
 def postprocess_tensor(tensor: torch.Tensor) -> Image.Image:
     """Convert tensor to PIL Image."""
-    # Remove batch dimension
     if tensor.dim() == 4:
         tensor = tensor.squeeze(0)
-
-    # Clamp to valid range
     tensor = torch.clamp(tensor, 0, 1)
-
-    # Convert to PIL
     transform = transforms.ToPILImage()
     return transform(tensor)
 
 
-def create_side_by_side(img1: Image.Image, img2: Image.Image) -> Image.Image:
+def create_side_by_side(img1: Image.Image, img2: Image.Image, style_name: str) -> Image.Image:
     """Create side-by-side comparison."""
-    from PIL import ImageDraw, ImageFont
-
-    # Resize to same height
     if img1.size != img2.size:
         img2 = img2.resize(img1.size, Image.LANCZOS)
 
     w, h = img1.size
-    combined = Image.new('RGB', (w * 2 + 20, h + 60), 'white')
+    combined = Image.new('RGB', (w * 2 + 20, h + 70), 'white')
 
-    # Paste images
-    combined.paste(img1, (0, 60))
-    combined.paste(img2, (w + 20, 60))
+    combined.paste(img1, (0, 70))
+    combined.paste(img2, (w + 20, 70))
 
-    # Add labels
     draw = ImageDraw.Draw(combined)
     try:
-        font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 24)
+        font_title = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 28)
+        font_label = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 20)
     except:
-        font = ImageFont.load_default()
+        font_title = ImageFont.load_default()
+        font_label = ImageFont.load_default()
+
+    # Style title
+    draw.text((w + 10, 20), f"Style: {style_name}", fill='#667eea', font=font_title)
 
-    draw.text((w // 2, 20), "Original", fill='black', font=font, anchor='mm')
-    draw.text((w * 1.5 + 10, 20), "Stylized", fill='black', font=font, anchor='mm')
+    # Labels
+    draw.text((w // 2, 50), "Original", fill='#555', font=font_label, anchor='mm')
+    draw.text((w * 1.5 + 10, 50), "Stylized", fill='#555', font=font_label, anchor='mm')
 
     return combined
 
 
+def apply_watermark(img: Image.Image, style_name: str) -> Image.Image:
+    """Add subtle watermark for social sharing."""
+    result = img.copy()
+    draw = ImageDraw.Draw(result)
+    w, h = result.size
+
+    text = f"StyleForge • {style_name}"
+    try:
+        font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", int(w / 40))
+    except:
+        font = ImageFont.load_default()
+
+    # Get text size
+    bbox = draw.textbbox((0, 0), text, font=font)
+    text_w = bbox[2] - bbox[0]
+    text_h = bbox[3] - bbox[1]
+
+    # Semi-transparent background
+    overlay = Image.new('RGBA', (text_w + 20, text_h + 10), (0, 0, 0, 100))
+    result.paste(overlay, (w - text_w - 25, h - text_h - 15), overlay)
+
+    # Text
+    draw.text((w - text_w - 15, h - text_h - 10), text, fill=(255, 255, 255, 200), font=font)
+
+    return result
+
+
 # ============================================================================
 # Gradio Interface Functions
 # ============================================================================
@@ -354,13 +403,12 @@ def create_side_by_side(img1: Image.Image, img2: Image.Image) -> Image.Image:
 def stylize_image(
     input_image: Optional[Image.Image],
     style: str,
-    show_comparison: bool
-) -> Tuple[Optional[Image.Image], str]:
-    """
-    Main stylization function for Gradio.
-    """
+    show_comparison: bool,
+    add_watermark: bool
+) -> Tuple[Optional[Image.Image], str, Optional[str]]:
+    """Main stylization function for Gradio."""
     if input_image is None:
-        return None, "Please upload an image first."
+        return None, "Please upload an image first.", None
 
     try:
         # Convert to RGB if needed
@@ -384,30 +432,45 @@ def stylize_image(
 
         elapsed_ms = (time.perf_counter() - start) * 1000
 
+        # Record performance
+        perf_tracker.record(elapsed_ms)
+
         # Postprocess
         output_image = postprocess_tensor(output_tensor.cpu())
 
+        # Add watermark if requested
+        if add_watermark:
+            output_image = apply_watermark(output_image, STYLES[style])
+
         # Create comparison if requested
         if show_comparison:
-            output_image = create_side_by_side(input_image, output_image)
+            output_image = create_side_by_side(input_image, output_image, STYLES[style])
+
+        # Save for download
+        download_path = f"/tmp/styleforge_{int(time.time())}.png"
+        output_image.save(download_path)
 
         # Generate stats
+        stats = perf_tracker.get_stats()
         fps = 1000 / elapsed_ms if elapsed_ms > 0 else 0
         width, height = input_image.size
 
-        stats = f"""
-### Performance Stats
+        stats_text = f"""
+### Performance
 
 | Metric | Value |
 |--------|-------|
 | **Style** | {STYLES[style]} |
-| **Inference Time** | {elapsed_ms:.2f} ms |
-| **FPS** | {fps:.1f} |
+| **This Image** | {elapsed_ms:.1f} ms ({fps:.0f} FPS) |
+| **Average** | {(stats['avg_ms'] if stats else elapsed_ms):.1f} ms |
+| **Total Processed** | {stats['total_inferences'] if stats else 1} images |
 | **Image Size** | {width}x{height} |
 | **Device** | {DEVICE.type.upper()} |
+
+**About this style:** {STYLE_DESCRIPTIONS.get(style, '')}
 """
 
-        return output_image, stats
+        return output_image, stats_text, download_path
 
     except Exception as e:
         import traceback
@@ -418,7 +481,7 @@ def stylize_image(
 **{str(e)}**
 
 <details>
-<summary>Error Details</summary>
+<summary>Show details</summary>
 
 ```
 {error_details}
@@ -426,7 +489,12 @@ def stylize_image(
 
 </details>
 """
-        return None, error_msg
+        return None, error_msg, None
+
+
+def get_style_description(style: str) -> str:
+    """Get description for selected style."""
+    return STYLE_DESCRIPTIONS.get(style, "")
 
 
 # ============================================================================
@@ -435,8 +503,8 @@ def stylize_image(
 
 custom_css = """
 .gradio-container {
-    font-family: 'Inter', -apple-system, BlinkMacSystemFont, sans-serif;
-    max-width: 1200px;
+    font-family: 'Inter', -apple-system, BlinkMacSystemFont, 'Segoe UI', sans-serif;
+    max-width: 1280px;
     margin: auto;
 }
 
@@ -444,17 +512,40 @@ custom_css = """
     background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important;
     border: none !important;
     color: white !important;
+    font-weight: 600 !important;
+    transition: all 0.3s ease !important;
 }
 
 .gr-button-primary:hover {
     transform: translateY(-2px);
-    box-shadow: 0 4px 12px rgba(102, 126, 234, 0.4);
-    transition: all 0.2s;
+    box-shadow: 0 8px 20px rgba(102, 126, 234, 0.4) !important;
+}
+
+.gr-button-secondary {
+    background: #f3f4f6 !important;
+    color: #374151 !important;
+    border: 1px solid #e5e7eb !important;
 }
 
 h1 {
     text-align: center;
-    color: #2C3E50;
+    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
+    -webkit-background-clip: text;
+    -webkit-text-fill-color: transparent;
+    background-clip: text;
+}
+
+.style-card {
+    border: 2px solid #e5e7eb;
+    border-radius: 12px;
+    padding: 16px;
+    margin: 8px 0;
+    transition: all 0.2s;
+}
+
+.style-card:hover {
+    border-color: #667eea;
+    box-shadow: 0 4px 12px rgba(102, 126, 234, 0.15);
 }
 
 .footer {
@@ -464,6 +555,19 @@ h1 {
     border-top: 1px solid #eee;
     color: #666;
 }
+
+/* Mobile optimization */
+@media (max-width: 768px) {
+    .gradio-container {
+        padding: 1rem 0.5rem !important;
+    }
+    .gr-row {
+        flex-direction: column !important;
+    }
+    .gr-column {
+        width: 100% !important;
+    }
+}
 """
 
 with gr.Blocks(
@@ -477,36 +581,45 @@ with gr.Blocks(
 
     # Header
     gr.Markdown("""
-    # StyleForge: Real-Time Neural Style Transfer
+    # StyleForge
 
-    Transform your images with artistic styles using fast neural style transfer.
+    ### Real-time neural style transfer. Transform your photos into artwork.
 
-    **Based on:** Johnson et al. "Perceptual Losses for Real-Time Style Transfer" ([arXiv:1603.08155](https://arxiv.org/abs/1603.08155))
+    **Fast. Free. No sign-up required.**
     """)
 
+    # Style description box
+    style_desc_box = gr.Markdown("*Select a style to see description*")
+
     # Main interface
     with gr.Row():
         with gr.Column(scale=1):
             # Input controls
             input_image = gr.Image(
-                label="Upload Your Image",
+                label="Upload Image",
                 type="pil",
                 sources=["upload", "webcam", "clipboard"],
-                height=400
+                height=350
             )
 
-            style = gr.Dropdown(
+            style = gr.Radio(
                 choices=list(STYLES.keys()),
                 value='candy',
-                label="Select Artistic Style",
-                type="value"
+                label="Artistic Style",
+                info="Choose your preferred style"
            )
 
-            show_comparison = gr.Checkbox(
-                label="Show side-by-side comparison",
-                value=False,
-                info="Display original and stylized images together"
-            )
+            with gr.Row():
+                show_comparison = gr.Checkbox(
+                    label="Side-by-side",
+                    value=False,
+                    info="Show before/after"
+                )
+                add_watermark = gr.Checkbox(
+                    label="Add watermark",
+                    value=False,
+                    info="For sharing"
+                )
 
             submit_btn = gr.Button(
                 "Stylize Image",
@@ -514,112 +627,155 @@ with gr.Blocks(
                 size="lg"
             )
 
+            # Style preview hints
             gr.Markdown("""
-            ### Tips
-            - Works best with images 256-1024px
-            - Try different styles to find your favorite
-            - GPU acceleration is available when supported
+            **Style Guide:**
+            - 🍬 **Candy**: Bright, colorful pop-art style
+            - 🎨 **Mosaic**: Fragmented tile-like reconstruction
+            - 🌧️ **Rain Princess**: Moody impressionistic
+            - 🖼️ **Udnie**: Bold abstract expressionist
            """)
 
         with gr.Column(scale=1):
             # Output
             output_image = gr.Image(
-                label="Stylized Result",
+                label="Result",
                 type="pil",
-                height=400
+                height=350
            )
 
+            with gr.Row():
+                download_btn = gr.DownloadButton(
+                    label="Download",
+                    variant="secondary",
+                    visible=False
+                )
+
             stats_text = gr.Markdown(
-                "Upload an image and click **'Stylize Image'** to begin!"
+                "> Upload an image and click **Stylize** to begin!"
            )
 
     # Examples section
     gr.Markdown("---")
-    gr.Markdown("### Try These Examples")
 
-    # Create a simple example image programmatically
     def create_example_image():
-        """Create a simple example image for testing."""
-        import numpy as np
-        # Create a gradient image
+        """Create example image for testing."""
         arr = np.zeros((256, 256, 3), dtype=np.uint8)
         for i in range(256):
-            arr[:, i, 0] = i  # Red gradient
-            arr[:, i, 1] = 255 - i  # Blue gradient
-            arr[:, i, 2] = 128  # Constant green
+            arr[:, i, 0] = i
+            arr[:, i, 1] = 255 - i
+            arr[:, i, 2] = 128
        return Image.fromarray(arr)
 
    example_img = create_example_image()
 
    gr.Examples(
        examples=[
-            [example_img, "candy", False],
-            [example_img, "mosaic", False],
-            [example_img, "rain_princess", True],
+            [example_img, "candy", False, False],
+            [example_img, "mosaic", False, False],
+            [example_img, "rain_princess", True, False],
        ],
-        inputs=[input_image, style, show_comparison],
-        outputs=[output_image, stats_text],
+        inputs=[input_image, style, show_comparison, add_watermark],
+        outputs=[output_image, stats_text, download_btn],
        fn=stylize_image,
        cache_examples=False,
+        label="Quick Examples"
    )
 
-    # Technical details
+    # FAQ Section
    gr.Markdown("---")
 
-    with gr.Accordion("Technical Details", open=False):
+    with gr.Accordion("FAQ & Help", open=False):
        gr.Markdown("""
-        ### Architecture
+        ### How does this work?
 
-        Fast Neural Style Transfer uses a feed-forward network trained per style:
+        StyleForge uses **Fast Neural Style Transfer** based on Johnson et al.'s research.
+        Unlike slow optimization methods, this uses pre-trained networks that transform
+        images in milliseconds.
 
-        **Network Architecture:**
-        - **Encoder:** 3 convolutional layers with Instance Normalization
-        - **Transformer:** 5 residual blocks
-        - **Decoder:** 3 upsampling layers with Instance Normalization
+        ### Which image sizes work best?
 
-        ### How It Works
+        - **Optimal**: 512-1024 pixels
+        - **Works with**: Any size (auto-resized)
+        - **Note**: Larger images take longer but produce better results
 
-        Unlike optimization-based style transfer (slow, ~seconds per image),
-        this approach trains a separate network per style that can transform
-        images in real-time (~milliseconds per image).
+        ### Why is the first request slow?
 
-        1. The network is trained on style images (e.g., Starry Night)
-        2. It learns a direct mapping from content photos to stylized outputs
-        3. At inference, it applies this transformation in a single forward pass
+        Hugging Face Spaces "sleep" after inactivity. The first request wakes it up
+        (~30 seconds). Subsequent requests are instant.
 
-        ### Performance
+        ### Can I use this commercially?
 
-        This model processes images significantly faster than traditional
-        optimization-based style transfer while maintaining quality.
+        Yes! StyleForge is open source (MIT license). The pre-trained models are from
+        the [fast-neural-style-transfer](https://github.com/yakhyo/fast-neural-style-transfer) project.
 
-        | Resolution | Time (GPU) | Time (CPU) |
-        |------------|------------|------------|
-        | 256x256 | ~5ms | ~50ms |
-        | 512x512 | ~15ms | ~150ms |
-        | 1024x1024 | ~50ms | ~500ms |
+        ### How to run locally?
+
+        ```bash
+        git clone https://github.com/olivialiau/StyleForge
+        cd StyleForge/huggingface-space
+        pip install -r requirements.txt
+        python app.py
+        ```
+        """)
+
+    # Technical details
+    with gr.Accordion("Technical Details", open=False):
+        gr.Markdown("""
+        ### Architecture
+
+        **Network:** Encoder-Decoder with Residual Blocks
+
+        - **Encoder**: 3 Conv layers + Instance Normalization
+        - **Transformer**: 5 Residual blocks
+        - **Decoder**: 3 Upsample Conv layers + Instance Normalization
+
+        ### Performance Benchmarks
+
+        | Resolution | GPU | CPU |
+        |------------|-----|-----|
+        | 256x256 | ~5ms | ~50ms |
+        | 512x512 | ~15ms | ~150ms |
+        | 1024x1024 | ~50ms | ~500ms |
 
         ### Resources
 
         - [GitHub Repository](https://github.com/olivialiau/StyleForge)
        - [Paper: Perceptual Losses for Real-Time Style Transfer](https://arxiv.org/abs/1603.08155)
-        - [Original Implementation](https://github.com/jcjohnson/fast-neural-style)
        """)
 
    # Footer
    gr.Markdown("""
    <div class="footer">
        <p>
-            <strong>StyleForge</strong> | USC Computer Science<br>
-            Built with Hugging Face Spaces 🤗
+            <strong>StyleForge</strong> Created by Olivia • USC Computer Science<br>
+            <a href="https://github.com/olivialiau/StyleForge">GitHub</a> •
+            Built with <a href="https://huggingface.co/spaces">Hugging Face Spaces</a> 🤗
        </p>
    </div>
    """)
 
-    # Event handlers
+    # Style description updater
+    style.change(
+        fn=get_style_description,
+        inputs=[style],
+        outputs=[style_desc_box]
+    )
+
+    # Also update description on load
+    demo.load(
+        fn=lambda: STYLE_DESCRIPTIONS['candy'],
+        outputs=[style_desc_box]
+    )
+
+    # Main event handler
    submit_btn.click(
        fn=stylize_image,
-        inputs=[input_image, style, show_comparison],
-        outputs=[output_image, stats_text]
+        inputs=[input_image, style, show_comparison, add_watermark],
+        outputs=[output_image, stats_text, download_btn]
+    ).then(
+        lambda: gr.DownloadButton(visible=True),
+        outputs=[download_btn]
    )
 
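The `PerformanceTracker` this commit adds to app.py depends only on the standard library, so its rolling-window behavior can be exercised in isolation. A standalone sketch (class body mirrors the one in the diff; the driver loop at the end is illustrative):

```python
from collections import deque
from datetime import datetime

class PerformanceTracker:
    """Rolling-window latency tracker, as added to app.py in this commit."""

    def __init__(self, max_samples=100):
        # deque(maxlen=N) silently drops the oldest sample once full
        self.inference_times = deque(maxlen=max_samples)
        self.total_inferences = 0
        self.start_time = datetime.now()

    def record(self, elapsed_ms):
        self.inference_times.append(elapsed_ms)
        self.total_inferences += 1

    def get_stats(self):
        if not self.inference_times:
            return None
        times = list(self.inference_times)
        uptime = (datetime.now() - self.start_time).total_seconds()
        return {
            'avg_ms': sum(times) / len(times),
            'min_ms': min(times),
            'max_ms': max(times),
            'total_inferences': self.total_inferences,
            'uptime_hours': uptime / 3600,
        }

# Window of 3: after four samples, only the last three count toward avg/min/max,
# but total_inferences still counts every call to record().
tracker = PerformanceTracker(max_samples=3)
for ms in (10, 20, 30, 40):
    tracker.record(ms)
stats = tracker.get_stats()
print(stats['avg_ms'], stats['min_ms'], stats['total_inferences'])  # → 30.0 20 4
```

Note the asymmetry this design bakes in: the averages reported in the stats table reflect only recent requests, while `total_inferences` is a lifetime counter.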