| import gradio as gr |
| import time |
| from typing import Dict, List, Optional, Callable |
|
|
class MultiModelImageGenerator:
    """
    ## Multi-Model Stable Diffusion Image Generation Framework

    Loads one Gradio interface per configured model and exposes all of them
    behind a single Blocks UI in which one prompt is sent to every model.

    ### Core Design Principles
    - Flexible model loading and management
    - One generation event per model, so models run independently
    - Load failures are logged and replaced by a no-op model
    - Unknown model indices fall back to the first model

    ### Technical Components
    - `model_functions`: 1-based ``int`` index -> callable model
    - Fallback mechanism for model loading
    - Hidden timestamp/status components used for task tracking

    NOTE(review): the code uses `gr.Interface.load`, `gr.Box` and
    `queue(concurrency_count=...)`, which appear to be Gradio 3.x APIs --
    confirm the pinned Gradio version before upgrading.
    """

    # Seconds after which check_task_status flags a started task as timed out.
    _TASK_TIMEOUT_SECONDS = 60

    def __init__(
        self,
        models: List[str],
        default_model_path: str = 'models/'
    ):
        """
        Initialize multi-model image generation system.

        All models are loaded eagerly during construction.

        Args:
            models (List[str]): Model names; each is appended to
                ``default_model_path`` to form the name passed to the loader.
            default_model_path (str): Prefix prepended to every model name.
        """
        self.models = models
        self.default_model_path = default_model_path
        # Keys are 1-based positions in ``models`` (ints, not strings).
        self.model_functions: Dict[int, Callable] = {}
        self._initialize_models()

    @staticmethod
    def _fallback_fn(txt):
        """No-op generator used in place of a model that failed to load."""
        return None

    def _initialize_models(self) -> None:
        """
        Load every configured model, substituting a no-op on failure.

        Strategy:
            - Attempt to load each model via ``gr.Interface.load``
            - If loading raises, log the error and register an interface
              whose function always returns ``None``

        Models are registered in ``self.model_functions`` under their
        1-based position in ``self.models``.
        """
        # Imported locally so this block does not depend on a file-level import.
        import logging

        logger = logging.getLogger(__name__)

        for model_idx, model_path in enumerate(self.models, 1):
            try:
                model_fn = gr.Interface.load(
                    f"{self.default_model_path}{model_path}",
                    live=False,
                    preprocess=True,
                    postprocess=False
                )
            except Exception:
                # Deliberately broad: any load failure must not prevent the
                # remaining models from loading. Log it instead of hiding it.
                logger.exception(
                    "Failed to load model %r; using no-op fallback",
                    model_path,
                )
                model_fn = gr.Interface(
                    fn=self._fallback_fn,
                    inputs=["text"],
                    outputs=["image"]
                )
            self.model_functions[model_idx] = model_fn

    def generate_with_model(
        self,
        model_idx: int,
        prompt: str
    ) -> Optional[object]:
        """
        Generate an image with the selected model, falling back to model 1.

        Args:
            model_idx (int): 1-based index of the model to use. Any value
                accepted by ``int()`` works (the UI passes the value of a
                Number component); values that cannot be converted select
                model 1.
            prompt (str): Generation prompt

        Returns:
            Whatever the selected model callable returns for ``prompt``, or
            ``None`` when neither the requested model nor model 1 is
            registered. Exceptions raised by the model callable propagate.
        """
        try:
            # Registry keys are ints, so normalise before the lookup.
            key = int(model_idx)
        except (TypeError, ValueError, OverflowError):
            key = 1

        selected_model = self.model_functions.get(key)
        if selected_model is None:
            selected_model = self.model_functions.get(1)
        if selected_model is None:
            return None

        return selected_model(prompt)

    def create_gradio_interface(self) -> "gr.Blocks":
        """
        Create Gradio interface for multi-model image generation.

        The UI has a prompt box, Generate/Clear buttons, one image output per
        model, and a hidden row of components that track when a task started.

        Returns:
            The assembled (not yet launched) Gradio Blocks interface
        """
        with gr.Blocks(title="Multi-Model Stable Diffusion", theme="Nymbo/Nymbo_Theme") as interface:
            with gr.Column(scale=12):
                with gr.Row():
                    primary_prompt = gr.Textbox(label="Generation Prompt", value="")

                with gr.Row():
                    run_btn = gr.Button("Generate", variant="primary")
                    clear_btn = gr.Button("Clear")

            # One image output per model, keyed by the 1-based model index.
            sd_outputs = {}
            for model_idx, model_path in enumerate(self.models, 1):
                with gr.Column(scale=3, min_width=320):
                    with gr.Box():
                        sd_outputs[model_idx] = gr.Image(label=model_path)

            # Hidden bookkeeping components for task tracking.
            with gr.Row(visible=False):
                start_box = gr.Number(interactive=False)
                end_box = gr.Number(interactive=False)
                task_status_box = gr.Textbox(value=0, interactive=False)

            def start_task():
                """Stamp both time boxes with now and reset the status to 0."""
                t_stamp = time.time()
                return (
                    gr.update(value=t_stamp),
                    gr.update(value=t_stamp),
                    gr.update(value=0)
                )

            def check_task_status(cnt, t_stamp):
                """
                Update (start_box, task_status_box) from their current state.

                Args:
                    cnt: current value of ``start_box``
                    t_stamp: value of ``end_box`` (written by ``start_task``)
                """
                current_time = time.time()
                timeout = t_stamp + self._TASK_TIMEOUT_SECONDS

                if current_time > timeout and t_stamp != 0:
                    # Timed out: reset start_box and set the status to 1.
                    return gr.update(value=0), gr.update(value=1)

                return (
                    gr.update(value=current_time if cnt != 0 else 0),
                    gr.update(value=0)
                )

            def clear_interface():
                """Return ``None`` for the prompt and for every image output."""
                return (None,) * (len(self.models) + 1)

            start_box.change(
                check_task_status,
                [start_box, end_box],
                [start_box, task_status_box],
                every=1,
                show_progress=False
            )

            primary_prompt.submit(start_task, None, [start_box, end_box, task_status_box])
            run_btn.click(start_task, None, [start_box, end_box, task_status_box])

            # One click listener per model so each output fills independently.
            generation_tasks = []
            for model_idx, output in sd_outputs.items():
                generation_tasks.append(
                    run_btn.click(
                        self.generate_with_model,
                        # The Number only carries the model index into the
                        # handler; keep it hidden so it is not rendered as a
                        # stray input in the page.
                        inputs=[gr.Number(model_idx, visible=False), primary_prompt],
                        outputs=[output]
                    )
                )

            # Clearing also cancels any generation events still running.
            clear_btn.click(
                clear_interface,
                None,
                [primary_prompt, *sd_outputs.values()],
                cancels=generation_tasks
            )

        return interface

    def launch(self, **kwargs):
        """
        Build the interface, enable its queue, and launch it.

        Args:
            **kwargs: Forwarded unchanged to the interface's ``launch()``
        """
        interface = self.create_gradio_interface()
        interface.queue(concurrency_count=600, status_update_rate=0.1)
        interface.launch(**kwargs)
|
|
def main():
    """
    Run the demo: build a generator for a fixed list of model identifiers
    and launch its interface.
    """
    demo_models = [
        "doohickey/neopian-diffusion",
        "dxli/duck_toy",
        "dxli/bear_plushie",
        "haor/Evt_V4-preview",
        "Yntec/Dreamscapes_n_Dragonfire_v2",
    ]

    # Construct and launch in one step; the instance is not needed afterwards.
    MultiModelImageGenerator(demo_models).launch(inline=True, show_api=False)


# Start the demo only when this file is executed as a script.
if __name__ == "__main__":
    main()