chrisjcc committed on
Commit
d01c487
·
verified ·
1 Parent(s): 34ffcd0

Try a different model provider

Browse files
Files changed (1) hide show
  1. app.py +33 -33
app.py CHANGED
@@ -4,7 +4,7 @@ from PIL import Image
4
  import base64
5
 
6
  import torch
7
- from diffusers import StableDiffusionPipeline
8
 
9
  #from transformers import pipeline
10
  import gradio as gr
@@ -13,48 +13,48 @@ import gradio as gr
13
  hf_api_key = os.environ.get('HF_API_KEY')
14
 
15
  # Load the Stable Diffusion pipeline
16
- model_id = "runwayml/stable-diffusion-v1-5"
17
- pipe = StableDiffusionPipeline.from_pretrained(
18
- model_id,
19
- torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32, # Use float16 on GPU, float32 on CPU
20
- use_auth_token=hf_api_key # Required for gated model
21
- )
22
 
23
  # Move pipeline to GPU if available
24
- device = "cuda" if torch.cuda.is_available() else "cpu"
25
- pipe = pipe.to(device)
26
 
27
  # Text-to-image endpoint
28
- #get_completion = pipeline("text-to-image", model="stable-diffusion-v1-5/stable-diffusion-v1-5")
29
 
30
  # A helper function to convert the PIL image to base64,
31
  # so you can send it to the API
32
- #def base64_to_pil(img_base64):
33
- # base64_decoded = base64.b64decode(img_base64)
34
- # byte_stream = io.BytesIO(base64_decoded)
35
- # pil_image = Image.open(byte_stream)
36
- # return pil_image
37
 
38
- #def generate(prompt):
39
- # output = get_completion(prompt)
40
- # result_image = base64_to_pil(output)
41
- # return result_image
42
 
43
- def generate(prompt, negative_prompt, steps, guidance, width, height):
44
- # Ensure width and height are multiples of 8 (required by Stable Diffusion)
45
- width = int(width) - (int(width) % 8)
46
- height = int(height) - (int(height) % 8)
47
 
48
- # Generate image with Stable Diffusion
49
- output = pipe(
50
- prompt,
51
- negative_prompt=negative_prompt or None, # Handle empty negative prompt
52
- num_inference_steps=int(steps),
53
- guidance_scale=float(guidance),
54
- width=width,
55
- height=height
56
- )
57
- return output.images[0] # Return the first generated image (PIL format)
58
 
59
  # Create Gradio interface
60
  with gr.Blocks() as demo:
 
4
  import base64
5
 
6
  import torch
7
+ #from diffusers import StableDiffusionPipeline
8
 
9
  #from transformers import pipeline
10
  import gradio as gr
 
13
  hf_api_key = os.environ.get('HF_API_KEY')
14
 
15
  # Load the Stable Diffusion pipeline
16
+ #model_id = "runwayml/stable-diffusion-v1-5"
17
+ #pipe = StableDiffusionPipeline.from_pretrained(
18
+ # model_id,
19
+ # torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32, # Use float16 on GPU, float32 on CPU
20
+ # use_auth_token=hf_api_key # Required for gated model
21
+ #)
22
 
23
  # Move pipeline to GPU if available
24
+ #device = "cuda" if torch.cuda.is_available() else "cpu"
25
+ #pipe = pipe.to(device)
26
 
27
  # Text-to-image endpoint
28
+ get_completion = pipeline("text-to-image", model="stabilityai/stable-diffusion-xl-base-1.0")
29
 
30
  # A helper function to convert the PIL image to base64,
31
  # so you can send it to the API
32
+ def base64_to_pil(img_base64):
33
+ base64_decoded = base64.b64decode(img_base64)
34
+ byte_stream = io.BytesIO(base64_decoded)
35
+ pil_image = Image.open(byte_stream)
36
+ return pil_image
37
 
38
+ def generate(prompt):
39
+ output = get_completion(prompt)
40
+ result_image = base64_to_pil(output)
41
+ return result_image
42
 
43
+ #def generate(prompt, negative_prompt, steps, guidance, width, height):
44
+ # # Ensure width and height are multiples of 8 (required by Stable Diffusion)
45
+ # width = int(width) - (int(width) % 8)
46
+ # height = int(height) - (int(height) % 8)
47
 
48
+ # # Generate image with Stable Diffusion
49
+ # output = pipe(
50
+ # prompt,
51
+ # negative_prompt=negative_prompt or None, # Handle empty negative prompt
52
+ # num_inference_steps=int(steps),
53
+ # guidance_scale=float(guidance),
54
+ # width=width,
55
+ # height=height
56
+ # )
57
+ # return output.images[0] # Return the first generated image (PIL format)
58
 
59
  # Create Gradio interface
60
  with gr.Blocks() as demo: