"""Gradio front end for a Stable Diffusion text-to-image HTTP endpoint.

Configuration is read from the environment (optionally via a ``.env`` file):

* ``HF_API_KEY``      -- bearer token sent in the ``Authorization`` header.
* ``HF_API_TTI_BASE`` -- URL the generation request is POSTed to.
"""

# Import necessary libraries
import base64
import io
import json
import os

import gradio as gr
import IPython.display
import requests
from dotenv import find_dotenv, load_dotenv
# NOTE: IPython's ``Image`` is deliberately not imported. It used to be
# imported here and was immediately shadowed by PIL's ``Image`` below, which
# is the one ``base64_to_pil`` actually needs.
from IPython.display import HTML, display
from PIL import Image

# Load environment variables
load_dotenv(find_dotenv())
hf_api_key = os.getenv('HF_API_KEY')
endpoint_url = os.getenv('HF_API_TTI_BASE')

# Seconds to wait on the endpoint before giving up. Generation can be slow,
# but with no timeout at all a stalled connection would block forever.
DEFAULT_TIMEOUT = 120


# Function to get image completion from the API
def get_completion(inputs, parameters=None, endpoint_url=endpoint_url,
                   timeout=DEFAULT_TIMEOUT):
    """POST a generation request and return the raw response body.

    Args:
        inputs: Value sent as the ``"inputs"`` field of the JSON payload
            (the prompt).
        parameters: Optional dict sent as the ``"parameters"`` field.
        endpoint_url: Target URL; defaults to ``HF_API_TTI_BASE`` as read
            when this module was imported.
        timeout: Seconds passed to ``requests.post``.

    Returns:
        The response body as ``bytes``.

    Raises:
        ValueError: If no endpoint URL is configured.
        RuntimeError: If the endpoint answers with a status other than 200.
        requests.RequestException: On network errors, including timeouts.
    """
    if not endpoint_url:
        raise ValueError(
            "No endpoint URL configured; set HF_API_TTI_BASE or pass "
            "endpoint_url explicitly."
        )

    headers = {"Content-Type": "application/json"}
    # Only authenticate when a key is present; otherwise the header would
    # literally read "Bearer None".
    if hf_api_key:
        headers["Authorization"] = f"Bearer {hf_api_key}"

    data = {"inputs": inputs}
    if parameters is not None:
        data["parameters"] = parameters

    response = requests.post(
        endpoint_url,
        headers=headers,
        data=json.dumps(data),
        timeout=timeout,
    )
    if response.status_code != 200:
        raise RuntimeError(
            f"Request failed: {response.status_code} - {response.text}"
        )
    return response.content


# Function to convert base64 or binary data to PIL image
def base64_to_pil(img_data):
    """Open image data as a PIL image.

    Args:
        img_data: Either raw encoded image bytes (``bytes``/``bytearray``,
            e.g. the body returned by ``get_completion``) or base64 text,
            which is decoded first.

    Returns:
        The ``PIL.Image.Image`` produced by ``Image.open``.
    """
    if isinstance(img_data, (bytes, bytearray)):
        byte_stream = io.BytesIO(img_data)
    else:
        byte_stream = io.BytesIO(base64.b64decode(img_data))
    return Image.open(byte_stream)


# Gradio interface function
def generate(prompt):
    """Gradio callback: send *prompt* to the endpoint and return the image."""
    output = get_completion(prompt)
    return base64_to_pil(output)


# Ensure all Gradio interfaces are closed before launching a new one
gr.close_all()

# Create the Gradio interface
demo = gr.Interface(
    fn=generate,
    inputs=[gr.Textbox(label="Your prompt")],
    outputs=[gr.Image(label="Result")],
    title="Image Generation with Stable Diffusion",
    description="Generate any image with Stable Diffusion.",
    # NOTE(review): newer Gradio releases rename this option to
    # ``flagging_mode``; confirm against the installed Gradio version.
    allow_flagging="never",
    examples=[
        ["a dog in a park"],
        ["Astronaut riding a horse"],
    ],
)

if __name__ == "__main__":
    demo.launch()