chrisjcc committed on
Commit
e363ea3
·
verified ·
1 Parent(s): 4bfde1f

Use library manually

Browse files
Files changed (1) hide show
  1. app.py +41 -12
app.py CHANGED
@@ -3,29 +3,57 @@ import io
3
  from PIL import Image
4
  import base64
5
 
 
 
 
6
  from transformers import pipeline
7
  import gradio as gr
8
 
9
- hf_api_key = os.environ['HF_API_KEY']
 
10
 
11
- # Text-to-image endpoint
12
- get_completion = pipeline("text-to-image", model="stable-diffusion-v1-5/stable-diffusion-v1-5")
 
 
 
 
 
13
 
 
 
 
 
 
 
14
 
15
  # A helper function to convert the PIL image to base64,
16
  # so you can send it to the API
17
- def base64_to_pil(img_base64):
18
- base64_decoded = base64.b64decode(img_base64)
19
- byte_stream = io.BytesIO(base64_decoded)
20
- pil_image = Image.open(byte_stream)
21
- return pil_image
 
 
 
 
 
22
 
23
- def generate(prompt):
24
- output = get_completion(prompt)
25
- result_image = base64_to_pil(output)
26
- return result_image
 
 
 
 
 
 
 
27
 
28
 
 
29
  with gr.Blocks() as demo:
30
  gr.Markdown("# Image Generation with Stable Diffusion")
31
  prompt = gr.Textbox(label="Your prompt")
@@ -45,6 +73,7 @@ with gr.Blocks() as demo:
45
 
46
  btn.click(fn=generate, inputs=[prompt,negative_prompt,steps,guidance,width,height], outputs=[output])
47
 
 
48
  demo.launch(
49
  share=True,
50
  #server_port=int(os.environ['PORT3'])
 
3
  from PIL import Image
4
  import base64
5
 
6
+ import torch
7
+ from diffusers import StableDiffusionPipeline
8
+
9
  from transformers import pipeline
10
  import gradio as gr
11
 
12
# Hugging Face access token (needed only for gated model downloads).
# .get() avoids a KeyError when the variable is unset — the token is then
# simply None and public models still load.
hf_api_key = os.environ.get('HF_API_KEY')

# Load the Stable Diffusion text-to-image pipeline once at startup.
# NOTE: the original `runwayml/stable-diffusion-v1-5` repo was removed from
# the Hub; this org path is the maintained mirror of the same weights.
model_id = "stable-diffusion-v1-5/stable-diffusion-v1-5"

# Pick the device first so the dtype can follow it: float16 halves memory on
# GPU, but many CPU ops do not support it, so fall back to float32 on CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

pipe = StableDiffusionPipeline.from_pretrained(
    model_id,
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    token=hf_api_key,  # `use_auth_token` is deprecated in diffusers; `token` replaces it
)
pipe = pipe.to(device)
26
+
27
+ # Text-to-image endpoint
28
+ #get_completion = pipeline("text-to-image", model="stable-diffusion-v1-5/stable-diffusion-v1-5")
29
 
30
  # A helper that decoded a base64 string back into a PIL image
31
  # (kept commented out below, from the old API-based workflow)
32
+ #def base64_to_pil(img_base64):
33
+ # base64_decoded = base64.b64decode(img_base64)
34
+ # byte_stream = io.BytesIO(base64_decoded)
35
+ # pil_image = Image.open(byte_stream)
36
+ # return pil_image
37
+
38
+ #def generate(prompt):
39
+ # output = get_completion(prompt)
40
+ # result_image = base64_to_pil(output)
41
+ # return result_image
42
 
43
def generate(prompt, negative_prompt, steps, guidance, width, height):
    """Run the Stable Diffusion pipeline and return the first generated image.

    Gradio delivers the slider/number widget values as plain numbers, so they
    are coerced to the integer/float types the pipeline expects before the call.
    Returns a PIL.Image instance suitable for a gr.Image output.
    """
    result = pipe(
        prompt,
        negative_prompt=negative_prompt,
        num_inference_steps=int(steps),
        guidance_scale=float(guidance),
        width=int(width),
        height=int(height),
    )
    # `result.images` is a list of PIL images; a single image was requested.
    return result.images[0]
54
 
55
 
56
+ # Create Gradio interface
57
  with gr.Blocks() as demo:
58
  gr.Markdown("# Image Generation with Stable Diffusion")
59
  prompt = gr.Textbox(label="Your prompt")
 
73
 
74
  btn.click(fn=generate, inputs=[prompt,negative_prompt,steps,guidance,width,height], outputs=[output])
75
 
76
+ # Launch the app
77
  demo.launch(
78
  share=True,
79
  #server_port=int(os.environ['PORT3'])