Spaces:
Runtime error
Runtime error
| import websocket # websocket-client | |
| import uuid | |
| import json | |
| import urllib.request | |
| import urllib.parse | |
| import random | |
| from PIL import Image | |
| import io | |
| import base64 | |
| import io | |
| import os | |
| import gradio as gr | |
# Runtime configuration read from environment variables (each is None when unset).
# URL_API: host (optionally host:port) interpolated into the http:// and ws:// URLs
#   built by the functions below.
# JSON_API: workflow graph as a JSON string; parsed afresh on every request.
server_address = os.getenv("URL_API")
json_data = os.getenv("JSON_API")
# Random per-process id: sent with every queued prompt and used as the websocket
# clientId (presumably so the server routes status messages back to this client).
client_id = str(uuid.uuid4())
def queue_prompt(prompt):
    """Submit a workflow graph to the server's ``/prompt`` endpoint.

    Args:
        prompt: JSON-serialisable workflow graph (a dict of node id -> node).

    Returns:
        The decoded JSON response body. Callers in this file read its
        ``'prompt_id'`` key.
    """
    # Compact JSON: the payload goes over the wire and is never printed, so the
    # former ``indent=4`` "prettify" only inflated the request.
    payload = json.dumps({"prompt": prompt, "client_id": client_id}).encode("utf-8")
    req = urllib.request.Request(
        f"http://{server_address}/prompt",
        data=payload,
        # urllib defaults POST bodies to form-encoding; declare JSON explicitly.
        headers={"Content-Type": "application/json"},
    )
    # Use a context manager so the HTTP response is always closed
    # (the original left it open).
    with urllib.request.urlopen(req) as response:
        return json.loads(response.read())
def get_image(filename, subfolder, folder_type):
    """Download the raw bytes of one output file via the ``/view`` endpoint.

    Args:
        filename: ``filename`` query parameter.
        subfolder: ``subfolder`` query parameter.
        folder_type: sent as the ``type`` query parameter.

    Returns:
        The response body as ``bytes``.
    """
    query = urllib.parse.urlencode(
        {"filename": filename, "subfolder": subfolder, "type": folder_type}
    )
    url = f"http://{server_address}/view?{query}"
    with urllib.request.urlopen(url) as resp:
        body = resp.read()
    return body
def get_history(prompt_id):
    """Fetch the server's ``/history/<prompt_id>`` record and decode it as JSON.

    Returns:
        The decoded JSON object. ``get_images`` indexes it by ``prompt_id``.
    """
    endpoint = f"http://{server_address}/history/{prompt_id}"
    with urllib.request.urlopen(endpoint) as resp:
        raw = resp.read()
    return json.loads(raw)
def get_images(ws, prompt, progress):
    """Queue *prompt*, wait on *ws* until it finishes, then download its images.

    Args:
        ws: An already-connected websocket (opened with this module's
            ``client_id``); status messages are read from it with ``recv()``.
        prompt: Workflow graph dict, submitted via ``queue_prompt``.
        progress: Callable taking a fraction in [0, 1] (e.g. a Gradio
            ``gr.Progress``). If it is not callable (e.g. ``None``), a
            ``gr.Progress(track_tqdm=True)`` is created instead, matching
            the previous behaviour.

    Returns:
        Dict mapping output node id -> list of raw image ``bytes``, for every
        output node in the prompt's history that has an ``images`` list.

    Raises:
        RuntimeError: if the server sends an ``execution_error`` message for
            this prompt. Previously the loop waited forever in that case.
    """
    # The original always replaced the caller's tracker; honour it when given.
    if not callable(progress):
        progress = gr.Progress(track_tqdm=True)

    prompt_id = queue_prompt(prompt)['prompt_id']
    last_reported_percentage = 0

    while True:
        out = ws.recv()
        if not isinstance(out, str):
            continue  # Binary frames are previews; ignore them.
        message = json.loads(out)
        msg_type = message.get('type')
        data = message.get('data') or {}

        if msg_type == 'progress':
            max_progress = data.get('max') or 0
            if max_progress <= 0:
                continue  # Avoid ZeroDivisionError on a degenerate message.
            percentage = int((data['value'] / max_progress) * 100)
            # Throttle UI updates to steps of at least 10 percentage points.
            if percentage >= last_reported_percentage + 10:
                last_reported_percentage = percentage
                progress(percentage / 100)
        elif msg_type == 'executing':
            # An 'executing' message with node None for our prompt means done.
            if data.get('node') is None and data.get('prompt_id') == prompt_id:
                break
        elif msg_type == 'execution_error' and data.get('prompt_id') == prompt_id:
            raise RuntimeError(
                f"Prompt {prompt_id} failed: {data.get('exception_message', data)}"
            )

    history = get_history(prompt_id)[prompt_id]
    output_images = {}
    # Single pass over the outputs. The original nested this inside a
    # redundant outer loop, re-downloading every image once per output node.
    for node_id, node_output in history['outputs'].items():
        if 'images' in node_output:
            output_images[node_id] = [
                get_image(img['filename'], img['subfolder'], img['type'])
                for img in node_output['images']
            ]
    return output_images
def pil_to_base64(image):
    """Serialise *image* to PNG and return it as a ``data:`` URI string.

    *image* only needs a ``save(fp, format=...)`` method (as PIL images have).
    """
    with io.BytesIO() as png_buffer:
        image.save(png_buffer, format="PNG")
        encoded = base64.b64encode(png_buffer.getvalue()).decode("utf-8")
    return "data:image/png;base64," + encoded
def generate_images(positive_prompt, image, progress):
    """Run the configured workflow for *positive_prompt* and return one image.

    The workflow graph comes from the module-level ``json_data`` string. Node
    ids below ("49", "90", "68", "62", "61", "47") are specific to that
    workflow.

    Args:
        positive_prompt: Written to node "49"'s ``text`` input.
        image: Optional input image, encoded with ``pil_to_base64`` (so it
            must support ``.save(fp, format="PNG")``). When ``None``, nodes
            "90" and "68" are removed and node "62"'s ``images`` input is
            rewired to output 0 of node "61".
        progress: Forwarded unchanged to ``get_images``.

    Returns:
        The first output image, decoded with ``PIL.Image.open``, or ``None``
        if the workflow produced no images.
    """
    # Build the request before opening the socket, so a bad JSON_API value
    # cannot leave a connected websocket behind.
    data = json.loads(json_data)
    data["49"]["inputs"]["text"] = positive_prompt
    # Explicit None check: truthiness is not a reliable "image supplied" test
    # (array-like inputs raise on bool()).
    if image is not None:
        data["90"]["inputs"]["images"]["base64"] = [pil_to_base64(image)]
    else:
        data.pop("90", None)
        data.pop("68", None)
        data["62"]["inputs"]["images"] = ["61", 0]
    data["47"]["inputs"]["noise_seed"] = random.randint(1, 1000000000)

    ws = websocket.WebSocket()
    ws.connect(f"ws://{server_address}/ws?clientId={client_id}")
    try:
        images = get_images(ws, data, progress)
    finally:
        # Always release the connection, even if generation fails.
        ws.close()

    for node_images in images.values():
        for image_data in node_images:
            return Image.open(io.BytesIO(image_data))
    return None