|
|
|
|
|
import os |
|
|
import io |
|
|
import IPython.display |
|
|
from IPython.display import Image, display, HTML |
|
|
from PIL import Image |
|
|
import base64 |
|
|
import requests |
|
|
import json |
|
|
from dotenv import load_dotenv, find_dotenv |
|
|
|
|
|
|
|
|
# Locate a .env file with find_dotenv() and load its variables into the
# process environment, so the os.environ lookups below can see them.
_dotenv_path = find_dotenv()
load_dotenv(_dotenv_path)

# Credentials / endpoint for the inference API. Either value is None when the
# corresponding variable is not set; nothing here validates them.
hf_api_key = os.environ.get('HF_API_KEY')
endpoint_url = os.environ.get('HF_API_TTI_BASE')
|
|
|
|
|
|
|
|
def get_completion(inputs, parameters=None, endpoint_url=endpoint_url, *, timeout=120):
    """POST a JSON request to the inference endpoint and return the raw body.

    Args:
        inputs: Value sent as the ``"inputs"`` field of the JSON body
            (the text prompt in this app).
        parameters: Optional value sent as the ``"parameters"`` field; the
            field is omitted entirely when this is ``None``.
        endpoint_url: URL to POST to. Defaults to the module-level
            ``endpoint_url`` (read from ``HF_API_TTI_BASE``) as it was when
            this function was defined.
        timeout: Seconds passed to ``requests.post`` as its timeout. Without
            a timeout a stalled endpoint would block the caller indefinitely.

    Returns:
        The response body as ``bytes`` (``response.content``).

    Raises:
        RuntimeError: If the server responds with any status other than 200.
        requests.RequestException: On connection failure or timeout.
    """
    headers = {
        "Authorization": f"Bearer {hf_api_key}",
        "Content-Type": "application/json",
    }

    payload = {"inputs": inputs}
    if parameters is not None:
        payload["parameters"] = parameters

    response = requests.post(
        endpoint_url,
        headers=headers,
        data=json.dumps(payload),
        timeout=timeout,
    )

    if response.status_code != 200:
        # RuntimeError subclasses Exception, so existing `except Exception`
        # handlers keep working.
        raise RuntimeError(f"Request failed: {response.status_code} - {response.text}")

    return response.content
|
|
|
|
|
|
|
|
def base64_to_pil(img_data):
    """Open image data as a PIL image.

    ``bytes`` input is treated as the already-decoded image file contents.
    Any other input is first run through ``base64.b64decode``.
    The resulting bytes are wrapped in an in-memory stream for ``Image.open``.
    """
    raw_bytes = img_data if isinstance(img_data, bytes) else base64.b64decode(img_data)
    return Image.open(io.BytesIO(raw_bytes))
|
|
|
|
|
import gradio as gr |
|
|
|
|
|
|
|
|
def generate(prompt):
    """Gradio callback: produce an image for ``prompt``.

    Sends the prompt to the endpoint with ``get_completion`` and decodes the
    returned body into a PIL image with ``base64_to_pil``.
    """
    return base64_to_pil(get_completion(prompt))
|
|
|
|
|
|
|
|
# Close any Gradio apps still running from an earlier execution before
# building a new one.
gr.close_all()

# UI pieces, created in the same order the original inline call evaluated them.
_prompt_input = gr.Textbox(label="Your prompt")
_result_output = gr.Image(label="Result")
_example_prompts = [
    ["a dog in a park"],
    ["Astronaut riding a horse"],
]

demo = gr.Interface(
    fn=generate,
    inputs=[_prompt_input],
    outputs=[_result_output],
    title="Image Generation with Stable Diffusion",
    description="Generate any image with Stable Diffusion.",
    allow_flagging="never",
    examples=_example_prompts,
)

# Start the web server only when run as a script, not on import.
if __name__ == "__main__":
    demo.launch()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|