# MMSD / app.py
# Last change: "Update app.py" by DavisFlockaFlame (commit 632bce3).
import banana_dev as banana
import base64
from io import BytesIO
from PIL import Image
import gradio as gr
import os
# import boto3
# model_key = os.environ.get("model_key")
# api_key = os.environ.get("api_key")
# aws_access_key_id = os.environ.get("aws_access_key_id")
# aws_secret_access_key = os.environ.get("aws_secret_access_key")
# #Create a session using AWS credentials
# session = boto3.Session(aws_access_key_id, aws_secret_access_key)
# #Create an S3 resource object using the session
# s3 = session.resource('s3')
# #Select your bucket
# bucket = s3.Bucket('bwlmonet')
# Request template sent to Banana by stable_diffusion_txt2img.
# "endpoint" selects the txt2img route; everything under "params" is passed
# through unchanged, except "prompt", which is filled in per request.
model_inputs = dict(
    endpoint="txt2img",
    params=dict(
        prompt="",
        negative_prompt="",
        steps=25,
        sampler_name="Euler a",
        cfg_scale=7.5,
        seed=42,  # fixed seed; identical params presumably give identical output
        batch_size=1,
        n_iter=1,
        width=768,  # presumably pixels — confirm against the backend
        height=768,
        tiling=False,
    ),
)
# for obj in bucket.objects.all():
# print(obj.key)
def stable_diffusion_txt2img(prompt, api_key, model_key, model_inputs):
    """Run one txt2img request on Banana and return the first generated image.

    Args:
        prompt: Text prompt to render; it replaces ``params["prompt"]``.
        api_key: Banana API key.
        model_key: Banana model key.
        model_inputs: Request template shaped like
            ``{"endpoint": ..., "params": {...}}``. It is NOT modified; the
            prompt is applied to a per-request copy.

    Returns:
        PIL.Image.Image decoded from the first base64-encoded image in the
        response.

    Raises:
        KeyError / IndexError if the response lacks the expected
        ``modelOutputs[0]["images"][0]`` entry.
    """
    # Copy instead of writing into the template: callers pass the shared
    # module-level dict, and mutating it would leak this prompt into every
    # other (possibly concurrent) request.
    request = {
        **model_inputs,
        "params": {**model_inputs["params"], "prompt": prompt},
    }
    # Run the model
    out = banana.run(api_key, model_key, request)
    # Response layout as indexed here: out["modelOutputs"][0]["images"] is a
    # list of base64 strings; only the first one is used.
    encoded_images = out["modelOutputs"][0]["images"]
    image_bytes = BytesIO(base64.b64decode(encoded_images[0].encode("utf-8")))
    image = Image.open(image_bytes)
    # Save image to S3
    # key = f"{prompt}.png"
    # image.save(key)
    # with open(key, "rb") as data:
    # bucket.put_object(Key=key, Body=data)
    # for obj in bucket.objects.all():
    # print(obj.key)
    return image
# Gradio Interface
def generator(prompt):
    """Gradio callback: render *prompt* twice and return both images.

    The two requests use the template's seed and seed + 1. With a single
    fixed seed, the two calls would send byte-identical requests to fill the
    two image slots.

    Banana credentials are read from the ``api_key`` and ``model_key``
    environment variables. These are the same names the commented-out config
    near the top of the file reads.

    Returns:
        A 2-tuple of images, matching the two ``gr.Image`` outputs.

    Raises:
        RuntimeError: if either credential variable is unset or empty.
    """
    api_key = os.environ.get("api_key")
    model_key = os.environ.get("model_key")
    if not api_key or not model_key:
        raise RuntimeError(
            "Banana credentials missing: set the 'api_key' and 'model_key' "
            "environment variables."
        )
    base_params = model_inputs["params"]
    images = []
    for seed_offset in (0, 1):
        # Per-call copy with its own seed; the shared template is left as-is.
        request = {
            **model_inputs,
            "params": {**base_params, "seed": base_params["seed"] + seed_offset},
        }
        images.append(stable_diffusion_txt2img(prompt, api_key, model_key, request))
    return tuple(images)
# Two-image txt2img UI; the click handler is also exposed as the "mmsd" API
# endpoint.
with gr.Blocks() as demo:
    prompt = gr.Textbox(label="Prompt")
    # A Button's caption is its `value`; `label` is not a Button parameter,
    # so passing label="Generate" did not set the caption.
    submit = gr.Button(value="Generate")
    image1 = gr.Image()
    image2 = gr.Image()
    submit.click(generator, inputs=[prompt], outputs=[image1, image2], api_name="mmsd")

# Launch only when run as a script (as Spaces does with `python app.py`), so
# importing the module does not start a server.
if __name__ == "__main__":
    demo.launch()