dlaima commited on
Commit
296d205
·
verified ·
1 Parent(s): 2d661d3

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +76 -0
app.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Import necessary libraries
2
+ import os
3
+ import io
4
+ import IPython.display
5
+ from IPython.display import Image, display, HTML
6
+ from PIL import Image
7
+ import base64
8
+ import requests
9
+ import json
10
+ from dotenv import load_dotenv, find_dotenv
11
+
12
+ # Load environment variables
13
+ load_dotenv(find_dotenv())
14
+ hf_api_key = os.getenv('HF_API_KEY')
15
+ endpoint_url = os.getenv('HF_API_TTI_BASE')
16
+
17
+ # Function to get image completion from the API
18
+ def get_completion(inputs, parameters=None, endpoint_url=endpoint_url):
19
+ headers = {
20
+ "Authorization": f"Bearer {hf_api_key}",
21
+ "Content-Type": "application/json"
22
+ }
23
+ data = {"inputs": inputs}
24
+ if parameters is not None:
25
+ data.update({"parameters": parameters})
26
+ response = requests.post(endpoint_url, headers=headers, data=json.dumps(data))
27
+ if response.status_code != 200:
28
+ raise Exception(f"Request failed: {response.status_code} - {response.text}")
29
+ return response.content
30
+
31
+ # Function to convert base64 or binary data to PIL image
32
+ def base64_to_pil(img_data):
33
+ if isinstance(img_data, bytes):
34
+ byte_stream = io.BytesIO(img_data)
35
+ else:
36
+ base64_decoded = base64.b64decode(img_data)
37
+ byte_stream = io.BytesIO(base64_decoded)
38
+ pil_image = Image.open(byte_stream)
39
+ return pil_image
40
+
41
+ import gradio as gr
42
+
43
+ # Gradio interface function
44
+ def generate(prompt):
45
+ output = get_completion(prompt)
46
+ result_image = base64_to_pil(output)
47
+ return result_image
48
+
49
+ # Ensure all Gradio interfaces are closed before launching a new one
50
+ gr.close_all()
51
+
52
+ # Create the Gradio interface
53
+ demo = gr.Interface(
54
+ fn=generate,
55
+ inputs=[gr.Textbox(label="Your prompt")],
56
+ outputs=[gr.Image(label="Result")],
57
+ title="Image Generation with Stable Diffusion",
58
+ description="Generate any image with Stable Diffusion.",
59
+ allow_flagging="never",
60
+ examples=[
61
+ ["a dog in a park"],
62
+ ["Astronaut riding a horse"]
63
+ ]
64
+ )
65
+
66
+ if __name__ == "__main__":
67
+ demo.launch()
68
+
69
+
70
+
71
+
72
+
73
+
74
+
75
+
76
+