File size: 2,056 Bytes
d27a72c
 
 
 
 
ef1a1c7
632bce3
d27a72c
632bce3
 
 
 
212910b
632bce3
 
212910b
632bce3
 
212910b
632bce3
 
d27a72c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
be76e4a
 
212910b
d27a72c
 
 
 
 
 
 
 
 
 
 
 
212910b
 
be76e4a
 
 
 
212910b
be76e4a
 
212910b
d27a72c
 
 
 
ff5d80b
d27a72c
04a4e2a
 
 
 
 
 
ff5d80b
04a4e2a
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import banana_dev as banana
import base64
from io import BytesIO
from PIL import Image
import gradio as gr
import os
# import boto3

# model_key = os.environ.get("model_key")
# api_key = os.environ.get("api_key")
# aws_access_key_id = os.environ.get("aws_access_key_id")
# aws_secret_access_key = os.environ.get("aws_secret_access_key")

# #Create a session using AWS credentials
# session = boto3.Session(aws_access_key_id, aws_secret_access_key)

# #Create an S3 resource object using the session
# s3 = session.resource('s3')

# #Select your bucket
# bucket = s3.Bucket('bwlmonet')

model_inputs = dict(
  # Banana request template for the txt2img endpoint. "prompt" is a
  # placeholder that callers fill in per request.
  endpoint="txt2img",
  params=dict(
    prompt="",
    negative_prompt="",
    steps=25,
    sampler_name="Euler a",
    cfg_scale=7.5,
    seed=42,
    batch_size=1,
    n_iter=1,
    width=768,
    height=768,
    tiling=False,
  ),
)

# for obj in bucket.objects.all():
#   print(obj.key)

def stable_diffusion_txt2img(prompt, api_key, model_key, model_inputs):
  """Run a txt2img request through Banana and return the first image.

  Args:
    prompt: Text prompt placed into the request's ``params["prompt"]``.
    api_key: Banana API key, passed through to ``banana.run``.
    model_key: Banana model key, passed through to ``banana.run``.
    model_inputs: Request template containing a ``"params"`` dict. It is
      not modified; a copy carrying *prompt* is sent instead.

  Returns:
    A PIL image decoded from the first base64 string in
    ``out["modelOutputs"][0]["images"]``.
  """
  # Build a per-call request instead of writing into the caller's dict.
  # The module-level template is shared by every Gradio request, and
  # handlers may run concurrently. Mutating the shared dict would let one
  # request's prompt overwrite another's before banana.run sends it.
  request = {
    **model_inputs,
    "params": {**model_inputs["params"], "prompt": prompt},
  }

  # Run the model
  out = banana.run(api_key, model_key, request)

  # The response carries a list of base64-encoded images; decode the first.
  # b64decode accepts an ASCII str directly, so no explicit encode is needed.
  encoded_images = out["modelOutputs"][0]["images"]
  image_bytes = BytesIO(base64.b64decode(encoded_images[0]))
  image = Image.open(image_bytes)

  # Save image to S3
  # key = f"{prompt}.png"
  # image.save(key)
  # with open(key, "rb") as data:
  #   bucket.put_object(Key=key, Body=data)

  # for obj in bucket.objects.all():
  #   print(obj.key)

  return image

# Gradio Interface
def generator(prompt):
  """Gradio handler: generate two candidate images for *prompt*.

  Credentials are read from the ``api_key`` and ``model_key`` environment
  variables on each call.

  Returns:
    A 2-tuple of images, one for each Gradio ``gr.Image`` output.

  Raises:
    RuntimeError: if either credential environment variable is unset.
  """
  # The module-level reads of these variables are commented out, so the
  # bare names were undefined and every click raised NameError.
  api_key = os.environ.get("api_key")
  model_key = os.environ.get("model_key")
  if not api_key or not model_key:
    raise RuntimeError("Set the 'api_key' and 'model_key' environment variables.")

  base_seed = model_inputs["params"]["seed"]
  images = []
  for offset in range(2):
    # Both requests were previously identical (same prompt, fixed seed).
    # Presumably that yields the same picture twice, so give each image
    # its own seed. The shared template itself is left untouched.
    request = {
      **model_inputs,
      "params": {**model_inputs["params"], "seed": base_seed + offset},
    }
    images.append(stable_diffusion_txt2img(prompt, api_key, model_key, request))
  return images[0], images[1]

with gr.Blocks() as demo:
  # Layout: one prompt box, a generate button, and two image outputs that
  # receive the pair returned by `generator`.
  prompt = gr.Textbox(label="Prompt")
  # A Button's displayed text is its `value` argument. Passing `label`
  # leaves the default "Run" caption (and newer Gradio rejects it).
  submit = gr.Button(value="Generate")
  image1 = gr.Image()
  image2 = gr.Image()

  # Also exposed as the named API endpoint "mmsd".
  submit.click(generator, inputs=[prompt], outputs=[image1, image2], api_name="mmsd")

demo.launch()