pradipraut737 commited on
Commit
f335656
Β·
verified Β·
1 Parent(s): 10b1b0c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -8
app.py CHANGED
@@ -1,18 +1,39 @@
 
1
  import gradio as gr
2
  from diffusers import StableDiffusionPipeline
3
- import torch
4
- import accelerate
5
 
 
 
 
6
 
 
7
  pipe = StableDiffusionPipeline.from_pretrained(
8
  "stabilityai/sd-turbo",
9
- torch_dtype=torch.float16,
10
- ).to("cuda" if torch.cuda.is_available() else "cpu")
 
 
11
 
 
 
 
 
 
12
 
 
 
 
 
13
 
14
- def generate(prompt):
15
- image = pipe(prompt).images[0]
16
- return image
17
 
18
- gr.Interface(fn=generate, inputs="text", outputs="image", title="Text to Image").launch()
 
 
 
 
 
 
 
 
1
+ import torch
2
  import gradio as gr
3
  from diffusers import StableDiffusionPipeline
4
+ from PIL import Image
5
+ import io
6
 
7
+ # πŸ”§ Set device and dtype
8
+ device = "cuda" if torch.cuda.is_available() else "cpu"
9
+ dtype = torch.float16 if device == "cuda" else torch.float32
10
 
11
+ # 🧠 Load SD Turbo pipeline
12
  pipe = StableDiffusionPipeline.from_pretrained(
13
  "stabilityai/sd-turbo",
14
+ torch_dtype=dtype,
15
+ )
16
+ pipe.to(device)
17
+ pipe.enable_attention_slicing()
18
 
19
+ # 🎯 Optimized image generation
20
+ def generate(prompt):
21
+ with torch.autocast(device_type=device):
22
+ # πŸ”» Generate smaller image directly
23
+ image = pipe(prompt, height=384, width=384).images[0]
24
 
25
+ # πŸ–ΌοΈ Optional: compress image for faster Gradio display
26
+ buffer = io.BytesIO()
27
+ image.save(buffer, format="JPEG", quality=70) # πŸ”» reduce quality to 70
28
+ buffer.seek(0)
29
 
30
+ return Image.open(buffer)
 
 
31
 
32
+ # πŸš€ Gradio UI
33
+ gr.Interface(
34
+ fn=generate,
35
+ inputs=gr.Textbox(label="Enter Prompt", placeholder="A futuristic cyberpunk city at night"),
36
+ outputs=gr.Image(label="Generated Image"),
37
+ title="🎨 Fast Text-to-Image Generator (SD Turbo)",
38
+ description="Optimized for speed and light resource use. Generates lower-res compressed images from text.",
39
+ ).launch()