Spaces:
Sleeping
Sleeping
| import sys | |
| import os | |
| import warnings | |
| import tempfile | |
| import gradio as gr | |
| import torch | |
| from PIL import Image | |
| import numpy as np | |
| from transformers import ( | |
| CLIPTextModelWithProjection, | |
| CLIPTokenizer, | |
| ) | |
| from diffusers.models.autoencoders.vq_model import VQModel | |
| from src.transformer import SymmetricTransformer2DModel | |
| from src.pipeline import UnifiedPipeline | |
| from src.scheduler import Scheduler | |
| from train.trainer_utils import load_images_to_tensor | |
# Silence FutureWarning output so the console log stays readable.
warnings.simplefilter("ignore", category=FutureWarning)
# Point Gradio at a writable temp directory.
def setup_gradio_temp_dir():
    """Pick a writable directory for Gradio temp files and export it.

    Candidates are tried in order:

    1. ``gradio_tmp`` under the current working directory,
    2. ``.gradio_tmp`` under the user's home directory,
    3. a fresh ``gradio_``-prefixed directory from ``tempfile.mkdtemp``.

    The first candidate that can be created and written to is stored in the
    ``GRADIO_TEMP_DIR`` environment variable.

    Returns:
        The path of the selected directory.

    Raises:
        RuntimeError: If no candidate is writable.
    """
    # Candidates are zero-argument callables so that the mkdtemp fallback only
    # creates a directory once the earlier candidates have actually failed
    # (building the list eagerly would leave a stray temp dir on every run).
    candidates = (
        lambda: os.path.join(os.getcwd(), "gradio_tmp"),  # Project directory
        lambda: os.path.join(os.path.expanduser("~"), ".gradio_tmp"),  # Home directory
        lambda: tempfile.mkdtemp(prefix="gradio_"),  # System temp with unique name
    )

    for make_candidate in candidates:
        temp_dir = None
        try:
            temp_dir = make_candidate()
            os.makedirs(temp_dir, exist_ok=True)
            # Probe write permission with a uniquely named file that is
            # removed automatically, even when the write itself fails.
            with tempfile.TemporaryFile(dir=temp_dir) as probe:
                probe.write(b"test")
        except OSError as e:  # PermissionError is a subclass of OSError
            print(f"β οΈ Cannot use {temp_dir}: {e}")
            continue

        os.environ["GRADIO_TEMP_DIR"] = temp_dir
        print(f"β Gradio temp directory set to: {temp_dir}")
        return temp_dir

    raise RuntimeError("Could not find a writable directory for Gradio temp files")


# Executed at import time, ahead of the definitions that build the Gradio UI.
setup_gradio_temp_dir()
class MudditInterface:
    """Own the Muddit pipeline and provide the callbacks used by the Gradio UI.

    Constructing an instance loads every model component and assembles them
    into a ``UnifiedPipeline`` placed on CUDA when available, otherwise CPU.
    """

    # Negative prompt substituted when the caller passes an empty one.
    DEFAULT_NEGATIVE_PROMPT = (
        "worst quality, low quality, low res, blurry, distortion, watermark, "
        "logo, signature, text, jpeg artifacts, signature, sketch, duplicate, "
        "ugly, identifying mark"
    )

    def __init__(self, model_path="MeissonFlow/Meissonic", transformer_path="QingyuShi/Muddit"):
        """Record the checkpoint locations and load all models.

        Args:
            model_path: Location passed to ``from_pretrained`` for the VQ
                model, text encoder, tokenizer and scheduler.
            transformer_path: Location passed to ``from_pretrained`` for the
                transformer; ``model_path`` is used when this is falsy.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model_path = model_path
        self.transformer_path = transformer_path or model_path
        print("Loading models...")
        self.load_models()
        print("Models loaded successfully!")

    def load_models(self):
        """Load each component, build ``self.pipe`` and move it to ``self.device``.

        Raises:
            Exception: Any loading error is printed and then re-raised.
        """
        try:
            print("π₯ Loading transformer model...")
            self.model = SymmetricTransformer2DModel.from_pretrained(
                self.transformer_path,
                subfolder="transformer",
            )
            print("π₯ Loading VQ model...")
            self.vq_model = VQModel.from_pretrained(
                self.model_path,
                subfolder="vqvae",
            )
            print("π₯ Loading text encoder...")
            self.text_encoder = CLIPTextModelWithProjection.from_pretrained(
                self.model_path,
                subfolder="text_encoder",
            )
            print("π₯ Loading tokenizer...")
            self.tokenizer = CLIPTokenizer.from_pretrained(
                self.model_path,
                subfolder="tokenizer",
            )
            print("π₯ Loading scheduler...")
            self.scheduler = Scheduler.from_pretrained(
                self.model_path,
                subfolder="scheduler",
            )
            print("π§ Assembling pipeline...")
            self.pipe = UnifiedPipeline(
                vqvae=self.vq_model,
                tokenizer=self.tokenizer,
                text_encoder=self.text_encoder,
                transformer=self.model,
                scheduler=self.scheduler,
            )
            print(f"π Moving models to {self.device}...")
            self.pipe.to(self.device)
        except Exception as e:
            print(f"β Error loading models: {str(e)}")
            raise

    def text_to_image(self, prompt, negative_prompt, height, width, steps, cfg_scale, seed):
        """Generate one image from a text prompt.

        Args:
            prompt: Text description; passed to the pipeline as a one-item list.
            negative_prompt: Text to steer away from; an empty value is
                replaced by ``DEFAULT_NEGATIVE_PROMPT``.
            height: Forwarded to the pipeline as ``height``.
            width: Forwarded to the pipeline as ``width``.
            steps: Forwarded as ``num_inference_steps``.
            cfg_scale: Forwarded as ``guidance_scale``.
            seed: ``-1`` (or ``None``, e.g. a cleared number field) leaves the
                generator unset; any other value seeds torch's RNG.

        Returns:
            The first entry of ``output.images``, or ``None`` when the
            pipeline output has no images.

        Raises:
            gr.Error: If anything fails, so that Gradio can show the message
                to the user.
        """
        try:
            if seed is None or int(seed) == -1:
                generator = None
            else:
                generator = torch.manual_seed(int(seed))

            if not negative_prompt:
                negative_prompt = self.DEFAULT_NEGATIVE_PROMPT

            output = self.pipe(
                prompt=[prompt],
                negative_prompt=negative_prompt,
                height=height,
                width=width,
                guidance_scale=cfg_scale,
                num_inference_steps=steps,
                mask_token_embedding=None,
                generator=generator,
            )
            if hasattr(output, "images") and len(output.images) > 0:
                return output.images[0]
            return None
        except Exception as e:
            # gr.Error only reaches the UI when it is raised; merely
            # constructing it (as before) silently dropped the failure.
            raise gr.Error(f"Error generating image: {str(e)}") from e

    def image_to_text(self, image, question, height, width, steps, cfg_scale):
        """Answer a question about an uploaded image.

        Args:
            image: A PIL image or a NumPy array (converted with
                ``Image.fromarray``); ``None`` yields a prompt to upload one.
            question: Question text, repeated once per loaded image.
            height: Used both for ``target_size`` when loading the image and
                as the pipeline ``height``.
            width: Used both for ``target_size`` when loading the image and
                as the pipeline ``width``.
            steps: Forwarded as ``num_inference_steps``.
            cfg_scale: Forwarded as ``guidance_scale``.

        Returns:
            The first entry of ``output.prompts``, or a human-readable status
            or error string; this method does not raise.
        """
        try:
            if image is None:
                return "Please upload an image."

            if isinstance(image, np.ndarray):
                image = Image.fromarray(image)

            # JPEG cannot encode alpha or palette modes; convert those so the
            # save below does not fail. Modes JPEG accepts are left untouched.
            if image.mode not in ("RGB", "L", "CMYK"):
                image = image.convert("RGB")

            # load_images_to_tensor takes a file path, so round-trip through a
            # private scratch directory. A unique directory keeps concurrent
            # requests from overwriting each other's file and is removed even
            # when saving or loading fails.
            with tempfile.TemporaryDirectory(prefix="muddit_vqa_") as scratch_dir:
                temp_path = os.path.join(scratch_dir, "temp_image.jpg")
                image.save(temp_path)
                images = load_images_to_tensor(temp_path, target_size=(height, width))

            if images is None:
                return "Failed to process the image."

            questions = [question] * len(images)
            output = self.pipe(
                prompt=questions,
                image=images,
                height=height,
                width=width,
                guidance_scale=cfg_scale,
                num_inference_steps=steps,
                mask_token_embedding=None,
            )
            if hasattr(output, "prompts") and len(output.prompts) > 0:
                return output.prompts[0]
            return "No response generated."
        except Exception as e:
            return f"Error processing image: {str(e)}"
def create_muddit_interface():
    """Build the Gradio app with a Text-to-Image tab and a VQA tab.

    A ``MudditInterface`` is created first; its ``text_to_image`` and
    ``image_to_text`` methods are wired to the two tab buttons.

    Returns:
        The assembled ``gr.Blocks`` object (not yet launched).
    """

    def size_sliders():
        # Width first, then Height; call inside the row that should hold them.
        width = gr.Slider(minimum=256, maximum=1024, value=1024, step=64, label="Width")
        height = gr.Slider(minimum=256, maximum=1024, value=1024, step=64, label="Height")
        return width, height

    def sampling_sliders():
        # Inference Steps first, then CFG Scale; call inside the target row.
        steps = gr.Slider(minimum=1, maximum=100, value=64, step=1, label="Inference Steps")
        cfg = gr.Slider(minimum=1.0, maximum=20.0, value=9.0, step=0.5, label="CFG Scale")
        return steps, cfg

    # Instantiating the backend loads the models before any UI is declared.
    backend = MudditInterface()

    # Markdown shown in the collapsed examples accordion; the empty first and
    # last items produce the leading and trailing newline.
    examples_markdown = "\n".join(
        [
            "",
            "### Text-to-Image Examples:",
            '- "A majestic night sky awash with billowing clouds, sparkling with a million twinkling stars"',
            '- "A hyper realistic image of a chimpanzee with a glass-enclosed brain on his head"',
            '- "A samurai in a stylized cyberpunk outfit adorned with intricate steampunk gear"',
            "### VQA Examples:",
            '- "What objects do you see in this image?"',
            '- "How many people are in the picture?"',
            '- "What is the main subject of this image?"',
            '- "Describe the scene in detail"',
            "",
        ]
    )

    with gr.Blocks(title="Muddit Interface", theme=gr.themes.Soft()) as demo:
        gr.Markdown("# π¨ Muddit Interface")
        gr.Markdown("Generate images from text or ask questions about images using Muddit.")

        with gr.Tabs():
            # ---- Tab 1: text prompt -> image ----
            with gr.TabItem("πΌοΈ Text-to-Image"):
                gr.Markdown("### Generate images from text descriptions")
                with gr.Row():
                    with gr.Column(scale=1):
                        t2i_prompt = gr.Textbox(
                            label="Prompt",
                            placeholder="A majestic night sky awash with billowing clouds, sparkling with a million twinkling stars",
                            lines=3,
                        )
                        t2i_negative = gr.Textbox(
                            label="Negative Prompt (optional)",
                            placeholder="worst quality, low quality, blurry...",
                            lines=2,
                        )
                        with gr.Row():
                            t2i_width, t2i_height = size_sliders()
                        with gr.Row():
                            t2i_steps, t2i_cfg = sampling_sliders()
                        t2i_seed = gr.Number(
                            label="Seed (-1 for random)",
                            value=42,
                            precision=0,
                        )
                        t2i_generate = gr.Button("π¨ Generate Image", variant="primary")
                    with gr.Column(scale=1):
                        t2i_output = gr.Image(label="Generated Image", type="pil")

                # Input order mirrors MudditInterface.text_to_image's
                # parameters (height before width).
                t2i_inputs = [
                    t2i_prompt,
                    t2i_negative,
                    t2i_height,
                    t2i_width,
                    t2i_steps,
                    t2i_cfg,
                    t2i_seed,
                ]
                t2i_generate.click(
                    fn=backend.text_to_image,
                    inputs=t2i_inputs,
                    outputs=[t2i_output],
                )

            # ---- Tab 2: image + question -> answer ----
            with gr.TabItem("β Visual Question Answering"):
                gr.Markdown("### Ask questions about images")
                with gr.Row():
                    with gr.Column(scale=1):
                        vqa_image = gr.Image(label="Upload Image", type="pil")
                        vqa_question = gr.Textbox(
                            label="Question",
                            placeholder="What do you see in this image?",
                            lines=2,
                        )
                        with gr.Row():
                            vqa_width, vqa_height = size_sliders()
                        with gr.Row():
                            vqa_steps, vqa_cfg = sampling_sliders()
                        vqa_submit = gr.Button("π€ Ask Question", variant="primary")
                    with gr.Column(scale=1):
                        vqa_output = gr.Textbox(
                            label="Answer",
                            lines=5,
                            interactive=False,
                        )

                # Input order mirrors MudditInterface.image_to_text's
                # parameters (height before width).
                vqa_inputs = [
                    vqa_image,
                    vqa_question,
                    vqa_height,
                    vqa_width,
                    vqa_steps,
                    vqa_cfg,
                ]
                vqa_submit.click(
                    fn=backend.image_to_text,
                    inputs=vqa_inputs,
                    outputs=[vqa_output],
                )

        # Collapsed list of sample prompts and questions.
        with gr.Accordion("π Examples", open=False):
            gr.Markdown(examples_markdown)

    return demo
# Script entry point: build the interface and start serving it.
if __name__ == "__main__":
    # The app is bound to the module-level name ``demo`` before launching.
    demo = create_muddit_interface()
    demo.launch()