ar08 commited on
Commit
31de3ea
·
verified ·
1 Parent(s): e3d1556

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +95 -35
app.py CHANGED
@@ -2,7 +2,7 @@ import gradio as gr
2
  import torch
3
  from diffusers import StableDiffusionImg2ImgPipeline
4
  from PIL import Image
5
- import numpy as np
6
 
7
  device = "cuda" if torch.cuda.is_available() else "cpu"
8
  model_id = "nitrosocke/Ghibli-Diffusion"
@@ -10,44 +10,104 @@ model_id = "nitrosocke/Ghibli-Diffusion"
10
  pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
11
  model_id,
12
  torch_dtype=torch.float16 if device == "cuda" else torch.float32,
13
- )
14
- pipe.to(device)
15
  pipe.enable_attention_slicing()
16
 
17
- def generate_ghibli_style_realtime(image, steps=25):
18
- prompt = "ghibli style portrait"
 
 
 
 
19
 
20
- def stream():
21
- def callback(step: int, timestep: int, latents):
22
- with torch.no_grad():
23
- img = pipe.decode_latents(latents)
24
- img = pipe.numpy_to_pil(img)[0]
25
- yield img # Stream each image!
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
- # Capture all images and stream via yield
28
- with torch.inference_mode():
29
- pipe(
30
- prompt=prompt,
31
- image=image,
32
- strength=0.6,
33
- guidance_scale=6.0,
34
- num_inference_steps=steps,
35
- callback=callback,
36
- callback_steps=1,
 
 
37
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
- return stream()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
 
41
- iface = gr.Interface(
42
- fn=generate_ghibli_style_realtime,
43
- inputs=[
44
- gr.Image(type="pil", label="Upload a photo"),
45
- gr.Slider(minimum=10, maximum=50, value=25, step=1, label="Inference Steps")
46
- ],
47
- outputs=gr.Image(label="Live Ghibli Transformation"),
48
- live=True,
49
- title="✨ Real-Time Ghibli Portrait ✨",
50
- description="Upload a photo and see the Ghibli transformation in real-time!"
51
- )
52
-
53
- iface.launch()
 
2
  import torch
3
  from diffusers import StableDiffusionImg2ImgPipeline
4
  from PIL import Image
5
+ from typing import List
6
 
7
  device = "cuda" if torch.cuda.is_available() else "cpu"
8
  model_id = "nitrosocke/Ghibli-Diffusion"
 
10
  pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
11
  model_id,
12
  torch_dtype=torch.float16 if device == "cuda" else torch.float32,
13
+ safety_checker=None
14
+ ).to(device)
15
  pipe.enable_attention_slicing()
16
 
17
+ styles = {
18
+ "Classic Ghibli": "ghibli style portrait",
19
+ "Spirited Forest": "studio ghibli mystical forest portrait, soft lighting",
20
+ "Windy Valley": "ghibli style sky valley portrait, dreamy atmosphere",
21
+ "Cozy Home": "ghibli style cozy cottage scene, warm tones"
22
+ }
23
 
24
+ def generate_final_image(
25
+ image: Image.Image,
26
+ style_choice: str,
27
+ steps: int,
28
+ history: List[Image.Image],
29
+ progress: gr.Progress = gr.Progress(track_tqdm=True)
30
+ ) -> tuple:
31
+ prompt = styles.get(style_choice, "ghibli style portrait")
32
+
33
+ with torch.inference_mode():
34
+ output = pipe(
35
+ prompt=prompt,
36
+ image=image,
37
+ strength=0.6,
38
+ guidance_scale=6.5,
39
+ num_inference_steps=steps
40
+ )
41
+
42
+ final_image = output.images[0]
43
+ return final_image, history + [final_image]
44
 
45
+ # Rest of the Gradio interface code remains the same as previous version
46
+ with gr.Blocks(theme=gr.themes.Soft()) as iface:
47
+ gr.Markdown("## 🌸 **Ghibli Portrait Generator — Spicy Edition** 🌸")
48
+ gr.Markdown("Upload a photo and transform it into an anime scene in your favorite Ghibli style! 🎬✨")
49
+
50
+ with gr.Row():
51
+ with gr.Column(scale=1):
52
+ image_input = gr.Image(type="pil", label="📸 Upload your photo", height=300)
53
+ style_dropdown = gr.Dropdown(
54
+ list(styles.keys()),
55
+ label="🎨 Choose a Ghibli Style",
56
+ value="Classic Ghibli"
57
  )
58
+ steps_slider = gr.Slider(10, 100, value=100, step=1, label="✨ Inference Steps")
59
+ generate_btn = gr.Button("💫 Start Magic!", variant="primary")
60
+
61
+ with gr.Column(scale=2):
62
+ output_image = gr.Image(label="🌟 Final Ghibli Portrait", height=500)
63
+ with gr.Row():
64
+ download_btn = gr.Button("💾 Download Final Image")
65
+ clear_btn = gr.Button("🧹 Clear History")
66
+
67
+ with gr.Accordion("📜 Previous Generations", open=False):
68
+ history_gallery = gr.Gallery(
69
+ label="Your Ghibli Journey",
70
+ columns=4,
71
+ height="auto",
72
+ object_fit="contain"
73
+ )
74
+
75
+ history_state = gr.State([])
76
 
77
+ # Generation workflow
78
+ generate_btn.click(
79
+ generate_final_image,
80
+ [image_input, style_dropdown, steps_slider, history_state],
81
+ [output_image, history_state]
82
+ ).then(
83
+ lambda x: x[-4:], # Update gallery after generation
84
+ history_state,
85
+ history_gallery
86
+ )
87
+
88
+ # Image selection from history
89
+ history_gallery.select(
90
+ lambda evt: evt,
91
+ None,
92
+ image_input
93
+ )
94
+
95
+ # Download handler
96
+ download_btn.click(
97
+ lambda img: img,
98
+ output_image,
99
+ gr.File(label="⬇️ Your Ghibli Portrait")
100
+ )
101
+
102
+ # Clear history
103
+ clear_btn.click(
104
+ lambda: [],
105
+ None,
106
+ history_state
107
+ ).then(
108
+ lambda: None,
109
+ None,
110
+ history_gallery
111
+ )
112
 
113
+ iface.launch(share=True, debug=True