import gradio as gr
from random import randint
from all_models import models

from externalmod import gr_Interface_load, randomize_seed

import asyncio
import os
from threading import RLock

# Reentrant lock available for guarding shared resources across worker threads (not currently acquired anywhere)
lock = RLock()

# Load Hugging Face token from environment variable
HF_TOKEN = os.environ.get("HF_TOKEN")
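# (None when the variable is unset; model loads then presumably fall back to unauthenticated access)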

# Populate the global models_load dict, falling back to a stub interface on failure
def load_fn(models):
    global models_load
    models_load = {}
    for model in models:
        if model not in models_load:
            try:
                print(f"Attempting to load model: {model}")
                m = gr_Interface_load(f'models/{model}', hf_token=HF_TOKEN)
                print(f"Successfully loaded model: {model}")
            except Exception as error:
                print(f"Error loading model {model}: {error}")
                m = gr.Interface(lambda *args, **kwargs: None, ['text'], ['image'])
            models_load[model] = m

# Load the models
print("Loading models...")
load_fn(models)
print("Models loaded successfully.")

num_models = 6

default_models = models[:num_models]
inference_timeout = 600
MAX_SEED = 3999999999
starting_seed = randint(1941, 2024)

print(f"Starting seed: {starting_seed}")

def extend_choices(choices):
    return choices[:num_models] + ['NA'] * (num_models - len(choices))
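# e.g. with num_models == 6, extend_choices(['m1', 'm2']) returns
# ['m1', 'm2', 'NA', 'NA', 'NA', 'NA'], so every hidden textbox gets a value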

# Run a single model's inference in a worker thread, bounded by the timeout
async def infer(model_str, prompt, seed=1, timeout=inference_timeout):
    if model_str == 'NA':
        return None

    print(f"Starting inference for model: {model_str} with prompt: '{prompt}' and seed: {seed}")

    try:
        # Enforce the timeout around the blocking model call
        result = await asyncio.wait_for(
            asyncio.to_thread(models_load[model_str].fn, prompt, seed=seed, token=HF_TOKEN),
            timeout=timeout,
        )
        if result:
            return result
    except (Exception, asyncio.TimeoutError) as e:
        print(f"Error during inference for model {model_str}: {e}")
    return None

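# Synchronous wrapper used as the Gradio click handler: Gradio calls it from a
# worker thread with no running asyncio loop, so a fresh loop is created per call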
def gen_fnseed(model_str, prompt, seed=1):
    if model_str == 'NA':
        return None
    
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        result = loop.run_until_complete(infer(model_str, prompt, seed))
    except Exception as e:
        print(f"Error during generation for model {model_str}: {e}")
        result = None
    finally:
        loop.close()
    return result

# Create the Gradio UI
print("Creating Gradio interface...")
with gr.Blocks(theme="Nymbo/Nymbo_Theme") as demo:
    gr.HTML("<center><h1>Compare-6</h1></center>")
    
    with gr.Tab('Compare-6'):
        txt_input = gr.Textbox(label='Your prompt:', lines=4)
        gen_button = gr.Button('Generate up to 6 images')
        
        with gr.Row():
            seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=starting_seed)
            seed_rand = gr.Button("Randomize Seed 🎲")
        
        seed_rand.click(randomize_seed, None, [seed])
        
        with gr.Row():
            output = [gr.Image(label=m, min_width=480) for m in default_models]
            current_models = [gr.Textbox(m, visible=False) for m in default_models]
        
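        # One click event per (model, image) pair, so the six generations run as
        # independent queued jobs rather than one blocking handler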
        for m, o in zip(current_models, output):
            gen_button.click(fn=gen_fnseed, inputs=[m, txt_input, seed], outputs=[o])
        
        with gr.Accordion('Model selection'):
            model_choice = gr.CheckboxGroup(models, label=f'Choose up to {num_models} models', value=default_models)
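            # Writing into the hidden textboxes re-targets which model each image
            # slot queries on the next generate click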
            model_choice.change(extend_choices, model_choice, current_models)

print("Launching Gradio interface...")
demo.queue(default_concurrency_limit=50, max_size=100)
demo.launch(share=True, max_threads=50)