| | import gradio as gr |
| | import numpy as np |
| | import random |
| | import multiprocessing |
| | import subprocess |
| | import sys |
| | import time |
| | import signal |
| | import json |
| | import os |
| | import requests |
| |
|
| | from loguru import logger |
| | from decouple import config |
| |
|
| | from pathlib import Path |
| | from PIL import Image |
| | import io |
| |
|
# Base address of the locally spawned ComfyUI server; request code appends ":<port>".
URL = "http://127.0.0.1"
# Paths read via python-decouple's config(): where ComfyUI writes its outputs,
# where request input images are saved, and the ComfyUI entry script to launch.
OUTPUT_DIR = config("OUTPUT_DIR")
INPUT_DIR = config("INPUT_DIR")
COMF_PATH = config("COMF_PATH")
| |
|
| | import torch |
| |
|
| | import spaces |
| |
|
# Report the CUDA environment at startup. ``device`` holds the active GPU's
# name, or None when CUDA is unavailable, so the module can still be imported
# on a CPU-only machine instead of crashing inside get_device_name().
print(f"Is CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    device = torch.cuda.get_device_name(torch.cuda.current_device())
    print(f"CUDA device: {device}")
else:
    device = None
print(torch.version.cuda)
print(device)
| |
|
| |
|
def wait_for_image_with_prefix(folder, prefix, *, stable_interval=1, settle_delay=3):
    """Return the newest finished image in ``folder`` whose name starts with ``prefix``.

    Despite the name, this performs a single check rather than blocking until
    an image appears; callers are expected to poll it. Matching is
    case-insensitive, and only ``.png``/``.jpg``/``.jpeg`` names are considered.

    The newest candidate (by modification time) counts as finished when it is
    non-empty and its size is unchanged after ``stable_interval`` seconds. The
    function then sleeps a further ``settle_delay`` seconds as extra margin for
    the writer before returning the path.

    Args:
        folder: Directory to scan.
        prefix: Filename prefix to match (case-insensitive).
        stable_interval: Seconds between the two size measurements.
        settle_delay: Extra seconds to wait once the file looks stable.

    Returns:
        Full path of the newest matching image, or ``None`` if there is no
        candidate, it is still being written, or it vanished during the check.
    """
    wanted = prefix.lower()
    extensions = ('.png', '.jpg', '.jpeg')
    candidates = [
        os.path.join(folder, name)
        for name in os.listdir(folder)
        if name.lower().startswith(wanted) and name.lower().endswith(extensions)
    ]
    if not candidates:
        return None

    try:
        latest_image = max(candidates, key=os.path.getmtime)
        initial_size = os.path.getsize(latest_image)
        time.sleep(stable_interval)
        # A zero-byte file was just created and not written yet; an unchanged
        # size alone would wrongly report it as ready.
        if initial_size == 0 or initial_size != os.path.getsize(latest_image):
            return None
    except FileNotFoundError:
        # Another request may delete its output between listdir() and stat();
        # report "nothing ready yet" so the caller simply polls again.
        return None

    time.sleep(settle_delay)
    return latest_image
| |
|
| |
|
def delete_image_file(file_path):
    """Best-effort removal of ``file_path``; never raises for filesystem errors.

    The removal is attempted directly (no exists() pre-check, which would race
    with other deleters). A missing file is logged at debug level. Any other
    ``OSError`` (e.g. a permission problem or a directory path) is logged as a
    warning, so genuine cleanup failures are visible.
    """
    try:
        os.remove(file_path)
    except FileNotFoundError:
        logger.debug(f"file {file_path} is not exist")
    except OSError as e:
        logger.warning(f"error {file_path}: {e}")
    else:
        logger.debug(f"file {file_path} deleted")
| |
|
| |
|
def start_queue(prompt_workflow, port, timeout=30):
    """Submit ``prompt_workflow`` to the ComfyUI server on ``port``.

    Args:
        prompt_workflow: Workflow object, sent as the JSON body ``{"prompt": ...}``.
        port: Port of the local server (joined onto the module-level ``URL``).
        timeout: Seconds to wait for the HTTP response. Without a timeout,
            ``requests`` can block forever.

    Returns:
        The ``requests.Response`` of the POST to ``/prompt``.

    Raises:
        requests.HTTPError: The server answered with a non-2xx status, i.e. it
            refused the workflow. Checking here means the caller fails at once
            instead of waiting for an image that will never be produced.
        requests.RequestException: The request could not be completed.
    """
    data = json.dumps({"prompt": prompt_workflow}).encode('utf-8')
    response = requests.post(f"{URL}:{port}/prompt", data=data, timeout=timeout)
    response.raise_for_status()
    return response
| |
|
| |
|
def check_server_ready(port):
    """Return True if the server on ``port`` answers the probe with HTTP 200.

    The probe is a GET of ``/history/123`` with a 5-second timeout. Any
    ``requests`` failure (connection refused, timeout, ...) is reported as
    "not ready" instead of being raised.
    """
    probe_url = f"{URL}:{port}/history/123"
    try:
        reply = requests.get(probe_url, timeout=5)
    except requests.RequestException:
        return False
    return reply.status_code == 200
| |
|
| |
|
| |
|
def _pick_free_port():
    """Return, as a string, a TCP port on 127.0.0.1 that is unused right now.

    Binding to port 0 lets the OS choose an unused port. Two concurrent requests
    therefore cannot draw the same port, which would otherwise let one request
    probe and submit its workflow to the other's server. A small race remains
    between closing this socket and ComfyUI binding the port.
    """
    import socket

    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
        probe.bind(("127.0.0.1", 0))
        return str(probe.getsockname()[1])


def _wait_for_server(port, process, attempts=30):
    """Poll ``check_server_ready`` about once per second until ComfyUI answers.

    Raises:
        RuntimeError: The ComfyUI process exited before it became ready.
        TimeoutError: No successful answer after ``attempts`` polls.
    """
    for _ in range(attempts):
        if check_server_ready(port):
            return
        # Fail fast instead of polling a server process that already died.
        if process.poll() is not None:
            raise RuntimeError(f"ComfyUI exited early with code {process.returncode}")
        time.sleep(1)
    raise TimeoutError("Server did not start in time")


def _stop_process(process):
    """Terminate ``process`` if it is still running; kill it after 5 s of grace."""
    if process is None or process.poll() is not None:
        return
    process.terminate()
    logger.debug("process.terminate()")
    try:
        logger.debug("process.wait(timeout=5)")
        process.wait(timeout=5)
    except subprocess.TimeoutExpired:
        logger.debug("process.kill()")
        process.kill()
        process.wait()  # reap the killed child so it does not linger as a zombie


def _load_and_remove(path):
    """Read the image at ``path`` fully into memory, close it, then delete the file.

    ``Image.open`` is lazy. Returning its result directly would leave the pixel
    data backed by a file that is deleted right away, plus an open handle that
    is never closed. ``copy()`` inside ``with`` reads the data eagerly.
    """
    try:
        with Image.open(path) as img:
            return img.copy()
    finally:
        delete_image_file(path)


@spaces.GPU(duration=240)
def generate_image(prompt, image, image2):
    """Run one workflow on a throw-away ComfyUI server and return the image it saves.

    Args:
        prompt: Workflow JSON text, posted to ComfyUI's ``/prompt`` endpoint
            (presumably API format; TODO confirm). Every occurrence of
            ``ComfyUI`` is replaced by a unique per-request prefix. This
            assumes the output node's filename prefix is ``ComfyUI``, so that
            the result can be found in ``OUTPUT_DIR``.
        image: First input image (numpy array), saved as ``INPUT_DIR/input.png``.
        image2: Optional second input image, saved as ``INPUT_DIR/input2.png``.

    Returns:
        The generated image as a fully loaded ``PIL.Image.Image``. Its output
        file is deleted after loading.

    Raises:
        gr.Error: Any failure (bad JSON, server start-up, rejected workflow,
            generation timeout). The message is shown in the Gradio UI instead
            of a silently blank output.

    NOTE(review): the fixed input filenames are shared by all requests, so
    concurrent requests can overwrite each other's inputs. Fixing that needs
    per-request names in the workflow itself.
    NOTE(review): up to 30 s start-up plus 240 s generation can exceed the
    240 s ``spaces.GPU`` duration.
    """
    import uuid

    process = None
    input_paths = []
    try:
        # A fixed-length random hex prefix can never be a prefix of another
        # request's prefix (unlike str(randint(...)), where "12" matches
        # "123456_..."), so the output scan cannot pick up a foreign file.
        prefix_filename = uuid.uuid4().hex
        workflow = json.loads(prompt.replace('ComfyUI', prefix_filename))

        input_paths.append(INPUT_DIR + '/input.png')
        Image.fromarray(image).save(input_paths[-1], format='PNG')
        if image2 is not None:
            input_paths.append(INPUT_DIR + '/input2.png')
            Image.fromarray(image2).save(input_paths[-1], format='PNG')

        new_port = _pick_free_port()
        process = subprocess.Popen(
            [sys.executable, COMF_PATH, "--listen", "127.0.0.1", "--port", new_port]
        )
        logger.debug(f'Subprocess started with PID: {process.pid}')
        _wait_for_server(new_port, process)

        start_queue(workflow, new_port)

        deadline = time.monotonic() + 240
        while time.monotonic() < deadline:
            latest_image = wait_for_image_with_prefix(OUTPUT_DIR, prefix_filename)
            if latest_image:
                logger.debug(f"file is: {latest_image}")
                return _load_and_remove(latest_image)
            time.sleep(1)
        raise TimeoutError("New image was not generated in time")

    except Exception as e:
        logger.exception(f"Error in generate_image: {e}")
        raise gr.Error(f"Image generation failed: {e}") from e

    finally:
        _stop_process(process)
        # Remove inputs on every path, not only on success, so failed runs
        # do not leave stale files for later requests.
        for path in input_paths:
            delete_image_file(path)
| |
|
| |
|
| |
|
if __name__ == "__main__":
    # UI inputs: the workflow JSON as text, then two RGBA images delivered to
    # generate_image as numpy arrays; the single output is an RGBA image.
    input_components = [
        "text",
        gr.Image(image_mode='RGBA', type="numpy"),
        gr.Image(image_mode='RGBA', type="numpy"),
    ]
    output_components = [gr.Image(type="numpy", image_mode='RGBA')]

    demo = gr.Interface(
        fn=generate_image,
        inputs=input_components,
        outputs=output_components,
    )
    # With debug=True, Gradio's launch() blocks the main thread until shutdown.
    demo.launch(debug=True)
    logger.debug('demo.launch()')

# Final log line; the message is Russian for "Main script finished."
logger.info("Основной скрипт завершил работу.")