import gradio as gr
import torch
from PIL import Image
import os
from transformers import CLIPTokenizer, CLIPTextModel, AutoProcessor, T5EncoderModel, T5TokenizerFast
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler
from flux.transformer_flux import FluxTransformer2DModel
from flux.pipeline_flux_chameleon import FluxPipeline
import torch.nn as nn
import math
import logging
import sys
from qwen2_vl.modeling_qwen2_vl import Qwen2VLSimplifiedModel
from huggingface_hub import snapshot_download
import spaces
# Set up logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(sys.stdout)
    ]
)
logger = logging.getLogger(__name__)
MODEL_ID = "Djrango/Qwen2vl-Flux"
MODEL_CACHE_DIR = "model_cache"
# Pre-download all models
def download_models():
    logger.info("Starting model download...")
    try:
        # Download the full model repository
        snapshot_download(
            repo_id=MODEL_ID,
            local_dir=MODEL_CACHE_DIR,
            local_dir_use_symlinks=False
        )
        logger.info("Model download completed successfully")
    except Exception as e:
        logger.error(f"Error downloading models: {str(e)}")
        raise

# Download the models when the script starts
if not os.path.exists(MODEL_CACHE_DIR):
    download_models()
# Add aspect ratio options
ASPECT_RATIOS = {
    "1:1": (1024, 1024),
    "16:9": (1344, 768),
    "9:16": (768, 1344),
    "2.4:1": (1536, 640),
    "3:4": (896, 1152),
    "4:3": (1152, 896),
}

class Qwen2Connector(nn.Module):
    def __init__(self, input_dim=3584, output_dim=4096):
        super().__init__()
        self.linear = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        return self.linear(x)
class FluxInterface:
    def __init__(self, device="cuda" if torch.cuda.is_available() else "cpu"):
        self.device = device
        self.dtype = torch.bfloat16
        self.models = None
        self.MODEL_ID = "Djrango/Qwen2vl-Flux"

    def load_models(self):
        if self.models is not None:
            return

        logger.info("Starting model loading...")

        # Explicitly configure the PyTorch caching allocator
        torch.cuda.set_per_process_memory_fraction(0.95)  # allow up to 95% of GPU memory
        torch.cuda.max_memory_allocated = lambda *args, **kwargs: 0  # ignore the allocated-memory limit
        # Load FLUX components
        tokenizer = CLIPTokenizer.from_pretrained(os.path.join(MODEL_CACHE_DIR, "flux/tokenizer"))
        text_encoder = CLIPTextModel.from_pretrained(os.path.join(MODEL_CACHE_DIR, "flux/text_encoder")).to(self.dtype).to(self.device)
        text_encoder_two = T5EncoderModel.from_pretrained(os.path.join(MODEL_CACHE_DIR, "flux/text_encoder_2")).to(self.dtype).to(self.device)
        tokenizer_two = T5TokenizerFast.from_pretrained(os.path.join(MODEL_CACHE_DIR, "flux/tokenizer_2"))

        # Load VAE and transformer
        vae = AutoencoderKL.from_pretrained(os.path.join(MODEL_CACHE_DIR, "flux/vae")).to(self.dtype).to(self.device)
        transformer = FluxTransformer2DModel.from_pretrained(os.path.join(MODEL_CACHE_DIR, "flux/transformer")).to(self.dtype).to(self.device)
        scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(os.path.join(MODEL_CACHE_DIR, "flux/scheduler"), shift=1)

        # Load Qwen2VL components
        qwen2vl = Qwen2VLSimplifiedModel.from_pretrained(os.path.join(MODEL_CACHE_DIR, "qwen2-vl")).to(self.dtype).to(self.device)
        # Load the connector
        connector = Qwen2Connector().to(self.dtype).to(self.device)
        connector_path = os.path.join(MODEL_CACHE_DIR, "qwen2-vl/connector.pt")
        connector_state = torch.load(connector_path, map_location='cpu')
        connector_state = {k: v.to(self.dtype) for k, v in connector_state.items()}
        connector.load_state_dict(connector_state)
        connector = connector.to(self.device)

        # Load the T5 embedder
        self.t5_context_embedder = nn.Linear(4096, 3072).to(self.dtype).to(self.device)
        t5_embedder_path = os.path.join(MODEL_CACHE_DIR, "qwen2-vl/t5_embedder.pt")
        t5_embedder_state = torch.load(t5_embedder_path, map_location='cpu')
        t5_embedder_state = {k: v.to(self.dtype) for k, v in t5_embedder_state.items()}
        self.t5_context_embedder.load_state_dict(t5_embedder_state)
        self.t5_context_embedder = self.t5_context_embedder.to(self.device)
        # Set models to eval mode
        for model in [text_encoder, text_encoder_two, vae, transformer, qwen2vl, connector, self.t5_context_embedder]:
            model.requires_grad_(False)
            model.eval()

        logger.info("All models loaded successfully")

        self.models = {
            'tokenizer': tokenizer,
            'text_encoder': text_encoder,
            'text_encoder_two': text_encoder_two,
            'tokenizer_two': tokenizer_two,
            'vae': vae,
            'transformer': transformer,
            'scheduler': scheduler,
            'qwen2vl': qwen2vl,
            'connector': connector
        }

        # Initialize processor and pipeline
        self.qwen2vl_processor = AutoProcessor.from_pretrained(
            self.MODEL_ID,
            subfolder="qwen2-vl",
            min_pixels=256*28*28,
            max_pixels=256*28*28
        )
        self.pipeline = FluxPipeline(
            transformer=transformer,
            scheduler=scheduler,
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
        )
    def resize_image(self, img, max_pixels=1050000):
        if not isinstance(img, Image.Image):
            img = Image.fromarray(img)

        width, height = img.size
        num_pixels = width * height

        if num_pixels > max_pixels:
            scale = math.sqrt(max_pixels / num_pixels)
            new_width = int(width * scale)
            new_height = int(height * scale)
            # Round dimensions down to multiples of 8
            new_width = new_width - (new_width % 8)
            new_height = new_height - (new_height % 8)
            img = img.resize((new_width, new_height), Image.LANCZOS)

        return img
    def process_image(self, image):
        message = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image},
                    {"type": "text", "text": "Describe this image."},
                ]
            }
        ]
        text = self.qwen2vl_processor.apply_chat_template(message, tokenize=False, add_generation_prompt=True)

        with torch.no_grad():
            inputs = self.qwen2vl_processor(text=[text], images=[image], padding=True, return_tensors="pt").to(self.device)
            output_hidden_state, image_token_mask, image_grid_thw = self.models['qwen2vl'](**inputs)
            image_hidden_state = output_hidden_state[image_token_mask].view(1, -1, output_hidden_state.size(-1))
            image_hidden_state = self.models['connector'](image_hidden_state)

        return image_hidden_state, image_grid_thw

    def compute_t5_text_embeddings(self, prompt):
        """Compute T5 embeddings for text prompt"""
        if prompt == "":
            return None

        text_inputs = self.models['tokenizer_two'](
            prompt,
            padding="max_length",
            max_length=256,
            truncation=True,
            return_tensors="pt"
        ).to(self.device)

        prompt_embeds = self.models['text_encoder_two'](text_inputs.input_ids)[0]
        prompt_embeds = prompt_embeds.to(dtype=self.dtype, device=self.device)
        prompt_embeds = self.t5_context_embedder(prompt_embeds)

        return prompt_embeds

    def compute_text_embeddings(self, prompt=""):
        with torch.no_grad():
            text_inputs = self.models['tokenizer'](
                prompt,
                padding="max_length",
                max_length=77,
                truncation=True,
                return_tensors="pt"
            ).to(self.device)

            prompt_embeds = self.models['text_encoder'](
                text_inputs.input_ids,
                output_hidden_states=False
            )
            pooled_prompt_embeds = prompt_embeds.pooler_output.to(self.dtype)

        return pooled_prompt_embeds
    @spaces.GPU(duration=300)  # request 300 seconds of GPU time per call
    def generate(self, input_image, prompt="", guidance_scale=3.5, num_inference_steps=28, num_images=2, seed=None, aspect_ratio="1:1"):
        try:
            logger.info(f"Starting generation with prompt: {prompt}, guidance_scale: {guidance_scale}, steps: {num_inference_steps}")

            if input_image is None:
                raise ValueError("No input image provided")

            if seed is not None:
                torch.manual_seed(seed)
                logger.info(f"Set random seed to: {seed}")

            self.load_models()
            logger.info("Models loaded successfully")

            # Get dimensions from aspect ratio
            if aspect_ratio not in ASPECT_RATIOS:
                raise ValueError(f"Invalid aspect ratio. Choose from {list(ASPECT_RATIOS.keys())}")
            width, height = ASPECT_RATIOS[aspect_ratio]
            logger.info(f"Using dimensions: {width}x{height}")

            # Process input image
            try:
                input_image = self.resize_image(input_image)
                logger.info(f"Input image resized to: {input_image.size}")

                qwen2_hidden_state, image_grid_thw = self.process_image(input_image)
                logger.info("Input image processed successfully")
            except Exception as e:
                raise RuntimeError(f"Error processing input image: {str(e)}")

            try:
                pooled_prompt_embeds = self.compute_text_embeddings("")
                logger.info("Base text embeddings computed")

                # Get T5 embeddings if prompt is provided
                t5_prompt_embeds = self.compute_t5_text_embeddings(prompt)
                logger.info("T5 prompt embeddings computed")
            except Exception as e:
                raise RuntimeError(f"Error computing embeddings: {str(e)}")

            # Generate images
            try:
                output_images = self.pipeline(
                    prompt_embeds=qwen2_hidden_state.repeat(num_images, 1, 1),
                    pooled_prompt_embeds=pooled_prompt_embeds,
                    t5_prompt_embeds=t5_prompt_embeds.repeat(num_images, 1, 1) if t5_prompt_embeds is not None else None,
                    num_inference_steps=num_inference_steps,
                    guidance_scale=guidance_scale,
                    height=height,
                    width=width,
                ).images
                logger.info("Images generated successfully")

                return output_images
            except Exception as e:
                raise RuntimeError(f"Error generating images: {str(e)}")

        except Exception as e:
            logger.error(f"Error during generation: {str(e)}")
            raise gr.Error(f"Generation failed: {str(e)}")
# Initialize the interface
interface = FluxInterface()

# Create Gradio interface
with gr.Blocks(
    theme=gr.themes.Soft(),
    css="""
        .container {
            max-width: 1200px;
            margin: auto;
            padding: 0 20px;
        }
        .header {
            text-align: center;
            margin: 20px 0 40px 0;
            padding: 20px;
            background: #f7f7f7;
            border-radius: 12px;
        }
        .param-row {
            padding: 10px 0;
        }
        footer {
            margin-top: 40px;
            padding: 20px;
            border-top: 1px solid #eee;
        }
    """
) as demo:
    with gr.Column(elem_classes="container"):
        gr.Markdown(
            """
            <div class="header">
            # 🎨 Qwen2vl-Flux Image Variation Demo
            Generate creative variations of your images with optional text guidance
            </div>
            """
        )

        with gr.Row(equal_height=True):
            with gr.Column(scale=1):
                # Input Section
                input_image = gr.Image(
                    label="Upload Your Image",
                    type="pil",
                    height=384,
                    sources=["upload", "clipboard"]
                )

                with gr.Accordion("Advanced Settings", open=False):
                    with gr.Group():
                        prompt = gr.Textbox(
                            label="Text Prompt (Optional)",
                            placeholder="As Long As Possible...",
                            lines=3
                        )

                        with gr.Row(elem_classes="param-row"):
                            guidance = gr.Slider(
                                minimum=1,
                                maximum=10,
                                value=3.5,
                                step=0.5,
                                label="Guidance Scale",
                                info="Higher values follow prompt more closely"
                            )
                            steps = gr.Slider(
                                minimum=1,
                                maximum=50,
                                value=28,
                                step=1,
                                label="Sampling Steps",
                                info="More steps = better quality but slower"
                            )

                        with gr.Row(elem_classes="param-row"):
                            num_images = gr.Slider(
                                minimum=1,
                                maximum=4,
                                value=2,
                                step=1,
                                label="Number of Images",
                                info="Generate multiple variations at once"
                            )
                            seed = gr.Number(
                                label="Random Seed",
                                value=None,
                                precision=0,
                                info="Set for reproducible results"
                            )

                        aspect_ratio = gr.Radio(
                            label="Aspect Ratio",
                            choices=["1:1", "16:9", "9:16", "2.4:1", "3:4", "4:3"],
                            value="1:1",
                            info="Choose aspect ratio for generated images"
                        )

                submit_btn = gr.Button(
                    "🎨 Generate Variations",
                    variant="primary",
                    size="lg"
                )

            with gr.Column(scale=1):
                # Output Section
                output_gallery = gr.Gallery(
                    label="Generated Variations",
                    columns=2,
                    rows=2,
                    height=700,
                    object_fit="contain",
                    show_label=True,
                    allow_preview=True,
                    preview=True
                )
                error_message = gr.Textbox(visible=False)

        with gr.Row(elem_classes="footer"):
            gr.Markdown("""
                ### Tips:
                - 📸 Upload any image to get started
                - 💡 Add an optional text prompt to guide the generation
                - 🎯 Adjust guidance scale to control prompt influence
                - ⚙️ Increase steps for higher quality
                - 🎲 Use seeds for reproducible results
            """)
    # Set up the generation function
    def generate_with_error_handling(*args):
        try:
            logger.info("Starting image generation with args: %s", str(args))

            # Validate input arguments
            input_image, prompt, guidance, steps, num_images, seed, aspect_ratio = args
            logger.info(f"Input validation - Image: {type(input_image)}, Prompt: '{prompt}', "
                        f"Guidance: {guidance}, Steps: {steps}, Num Images: {num_images}, "
                        f"Seed: {seed}, Aspect Ratio: {aspect_ratio}")

            if input_image is None:
                raise ValueError("No input image provided")

            gr.Info("Starting image generation...")
            results = interface.generate(*args)
            logger.info("Generation completed successfully")
            gr.Info("Generation complete!")
            return [results, None]
        except Exception as e:
            error_msg = str(e)
            logger.error(f"Error in generate_with_error_handling: {error_msg}", exc_info=True)
            return [None, error_msg]
    submit_btn.click(
        fn=generate_with_error_handling,
        inputs=[
            input_image,
            prompt,
            guidance,
            steps,
            num_images,
            seed,
            aspect_ratio
        ],
        outputs=[
            output_gallery,
            error_message
        ],
        show_progress=True
    )

# Launch the app
if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",  # Listen on all network interfaces
        server_port=7860,       # Use a specific port
        share=False,            # Disable public URL sharing
    )
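
# A minimal client-side sketch (not part of the app itself): it assumes the app is running
# locally on port 7860, that the `gradio_client` package is installed, and that the click
# handler is exposed under Gradio's default endpoint name "/generate_with_error_handling".
# The file name and prompt below are illustrative placeholders.
#
#   from gradio_client import Client, handle_file
#
#   client = Client("http://localhost:7860")
#   images, error = client.predict(
#       handle_file("example.jpg"),   # input image
#       "a watercolor landscape",     # optional text prompt
#       3.5,                          # guidance scale
#       28,                           # sampling steps
#       2,                            # number of images
#       42,                           # random seed
#       "1:1",                        # aspect ratio
#       api_name="/generate_with_error_handling",
#   )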