Spaces:
Runtime error
Runtime error
| import spaces # beginn | |
| import torch.multiprocessing as mp | |
| import torch | |
| import os | |
| import pandas as pd | |
| import gc | |
| import re | |
| import random | |
| from tqdm.auto import tqdm | |
| from collections import deque | |
| from optimum.quanto import freeze, qfloat8, quantize | |
| from diffusers import FlowMatchEulerDiscreteScheduler, AutoencoderKL | |
| from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel | |
| from diffusers.pipelines.flux.pipeline_flux import FluxPipeline | |
| from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast | |
| from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer | |
| import gradio as gr | |
| from accelerate import Accelerator | |
# Make sure the multiprocessing start method is 'spawn'. force=True is required
# because set_start_method raises RuntimeError when a different method has
# already been set (get_start_method(allow_none=True) returns it rather than None).
if mp.get_start_method(allow_none=True) != 'spawn':
    mp.set_start_method('spawn', force=True)

# Instantiate the Accelerator
accelerator = Accelerator()
dtype = torch.bfloat16

# Set environment variables for local path
os.environ['FLUX_DEV'] = '.'
os.environ['AE'] = '.'

# Load the FLUX.1-schnell components (plus the CLIP-L text encoder) in bfloat16.
bfl_repo = 'black-forest-labs/FLUX.1-schnell'
revision = 'refs/pr/1'
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(bfl_repo, subfolder='scheduler', revision=revision)
text_encoder = CLIPTextModel.from_pretrained('openai/clip-vit-large-patch14', torch_dtype=dtype)
# Tokenizers carry no tensors, so no torch_dtype is passed to them.
tokenizer = CLIPTokenizer.from_pretrained('openai/clip-vit-large-patch14')
text_encoder_2 = T5EncoderModel.from_pretrained(bfl_repo, subfolder='text_encoder_2', torch_dtype=dtype, revision=revision)
tokenizer_2 = T5TokenizerFast.from_pretrained(bfl_repo, subfolder='tokenizer_2', revision=revision)
vae = AutoencoderKL.from_pretrained(bfl_repo, subfolder='vae', torch_dtype=dtype, revision=revision)
transformer = FluxTransformer2DModel.from_pretrained(bfl_repo, subfolder='transformer', torch_dtype=dtype, revision=revision)

# Quantize the transformer and T5 encoder weights to float8 (optimum-quanto)
# and freeze them so the quantized weights are materialized.
quantize(transformer, weights=qfloat8)
freeze(transformer)
quantize(text_encoder_2, weights=qfloat8)
freeze(text_encoder_2)

# Build the pipeline without the quantized modules, then attach them afterwards.
pipe = FluxPipeline(
    scheduler=scheduler,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    text_encoder_2=None,
    tokenizer_2=tokenizer_2,
    vae=vae,
    transformer=None,
)
pipe.text_encoder_2 = text_encoder_2
pipe.transformer = transformer
# Use diffusers' model CPU offloading for the pipeline components.
pipe.enable_model_cpu_offload()

# Create a directory to save the generated images
output_dir = 'generated_images'
os.makedirs(output_dir, exist_ok=True)
# Function to generate a detailed visual description prompt
def generate_description_prompt(subject, user_prompt, text_generator):
    """Ask ``text_generator`` for a short, bracketed visual description.

    The instruction asks for a description of ``user_prompt`` that differs from
    ``subject``, enclosed in ``[ ... ]`` and under 100 words.

    Args:
        subject: Text the description should differ from (a seed word, or an
            earlier generated description that was fed back as a subject).
        user_prompt: What the description should depict.
        text_generator: Callable following the Hugging Face text-generation
            pipeline convention: called with the prompt and generation kwargs,
            returning a list of dicts with a ``'generated_text'`` key.

    Returns:
        The generated text with the echoed instruction removed and whitespace
        stripped, or ``None`` if nothing was generated or the generator raised.
    """
    prompt = f"write concise vivid visual description enclosed in brackets like [ <description> ] less than 100 words of {user_prompt} different from {subject}. "
    try:
        # max_new_tokens bounds only the continuation. max_length would also count
        # the prompt, which can use up the whole budget once long earlier
        # descriptions are reused as subjects.
        outputs = text_generator(prompt, max_new_tokens=160, num_return_sequences=1)
        generated_text = outputs[0]['generated_text']
    except Exception as e:
        print(f"Error generating description for subject '{subject}': {e}")
        return None
    # The pipeline echoes the prompt at the start of its output by default; drop
    # only that leading echo, not other occurrences of the text.
    if generated_text.startswith(prompt):
        generated_text = generated_text[len(prompt):]
    generated_description = generated_text.strip()
    return generated_description or None
# Extract usable descriptions from generated text
def parse_descriptions(text):
    """Return every bracketed span in *text* that holds at least three words.

    Only spans without nested brackets are matched (``[[a b c]]`` yields
    ``'a b c'``). Each kept span is stripped of surrounding whitespace.
    """
    kept = []
    for match in re.finditer(r'\[([^\[\]]+)\]', text):
        body = match.group(1)
        # Fragments shorter than three words are not treated as descriptions
        if len(body.split()) < 3:
            continue
        kept.append(body.strip())
    return kept
# Limits: generation stops collecting once MAX_DESCRIPTIONS parsed descriptions
# are queued, and at most MAX_IMAGES of them are rendered.
MAX_DESCRIPTIONS = 30
MAX_IMAGES = 12

# Subjects available to prompt with (user seed words plus generated descriptions)
seed_words = list()
# Subjects that have already been used
used_words = set()
# Parsed descriptions collected so far
parsed_descriptions_queue = deque()
def generate_descriptions(user_prompt, seed_words_input, batch_size=100, max_iterations=50):
    """Generate visual descriptions with Llama 3.1 and return the parsed ones.

    Seed subjects are read from ``seed_words_input`` as double-quoted words
    (``"cat", "dog"``). If no quoted words are found, the input is treated as
    a comma-separated list. Each iteration picks an unused subject at random,
    asks the language model for a bracketed description of ``user_prompt``
    that differs from that subject, and appends the ASCII-cleaned description
    to the subject pool so it can serve as a later subject. Only on iterations
    0, 3, 6, ... are the bracketed descriptions parsed and queued.

    Generation stops after ``max_iterations`` iterations, once at least
    ``MAX_DESCRIPTIONS`` descriptions are queued (the last batch may
    overshoot), or when no unused subject remains.

    Args:
        user_prompt: What the descriptions should depict.
        seed_words_input: Raw seed-word text from the UI (may be empty).
        batch_size: Unused; kept for backward compatibility.
        max_iterations: Upper bound on generation attempts.

    Returns:
        list[str]: The parsed descriptions, in the order they were queued.
    """
    # The module-level pools outlive a single call. Reset them so one request's
    # subjects and results do not leak into the next.
    seed_words.clear()
    used_words.clear()
    parsed_descriptions_queue.clear()

    # Populate the seed pool from the user input before loading the model, so an
    # empty input does not pay for loading the LLM.
    raw_seeds = seed_words_input or ''
    seeds = re.findall(r'"(.*?)"', raw_seeds)
    if not seeds:
        seeds = raw_seeds.split(',')
    seed_words.extend(seed.strip() for seed in seeds if seed.strip())
    if not seed_words:
        print("No more available subjects to use.")
        return []

    print("Initializing the text generation pipeline with 16-bit precision...")
    model_name = 'meta-llama/Meta-Llama-3.1-8B-Instruct'
    model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map='auto')
    llm_tokenizer = AutoTokenizer.from_pretrained(model_name)
    text_generator = pipeline('text-generation', model=model, tokenizer=llm_tokenizer)
    print("Text generation pipeline initialized with 16-bit precision.")

    try:
        iteration_count = 0
        while iteration_count < max_iterations and len(parsed_descriptions_queue) < MAX_DESCRIPTIONS:
            # Select a subject that has not been used
            available_subjects = [word for word in seed_words if word not in used_words]
            if not available_subjects:
                print("No more available subjects to use.")
                break
            subject = random.choice(available_subjects)
            generated_description = generate_description_prompt(subject, user_prompt, text_generator)
            if generated_description:
                # Drop non-ASCII characters
                clean_description = generated_description.encode('ascii', 'ignore').decode('ascii')
                print(f"Generated description for subject '{subject}': {clean_description}")
                used_words.add(subject)
                # Feed the description back into the pool as a future subject
                seed_words.append(clean_description)
                if iteration_count % 3 == 0:
                    parsed_descriptions_queue.extend(parse_descriptions(clean_description))
            iteration_count += 1
    finally:
        # Release the 8B language model before image generation needs memory.
        # It is reloaded on the next call.
        del text_generator, model, llm_tokenizer
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    return list(parsed_descriptions_queue)
def generate_images(parsed_descriptions, max_images=None, image_pipeline=None):
    """Render one image per description with the FLUX pipeline.

    Uses the first ``max_images`` descriptions, or all of them if there are
    fewer. The input list is not modified.

    Args:
        parsed_descriptions: Text prompts, in priority order.
        max_images: Maximum number of prompts to render. Defaults to the
            module-level ``MAX_IMAGES``.
        image_pipeline: Callable used to render a prompt. Defaults to the
            module-level ``pipe``. It is called as
            ``image_pipeline(prompt, num_images_per_prompt=1)`` and must return
            an object with an ``images`` list.

    Returns:
        list: The generated images, in prompt order.
    """
    limit = MAX_IMAGES if max_images is None else max_images
    renderer = pipe if image_pipeline is None else image_pipeline
    images = []
    # Slice rather than pop(0) so the caller's list is left untouched.
    for prompt in parsed_descriptions[:limit]:
        # FluxPipeline's keyword is num_images_per_prompt; num_images is not accepted.
        images.extend(renderer(prompt, num_images_per_prompt=1).images)
    return images
# Gradio callback: prompt + seed words -> gallery images
def combined_function(user_prompt, seed_words_input):
    """Generate descriptions from the UI inputs, then render images from them.

    Returns the list produced by ``generate_images`` for the Gradio gallery.
    """
    descriptions = generate_descriptions(user_prompt, seed_words_input)
    return generate_images(descriptions)
if __name__ == '__main__':
    # torch.cuda.init() raises when CUDA is unavailable (no driver/device), so
    # only initialize CUDA eagerly when it is actually present.
    if torch.cuda.is_available():
        torch.cuda.init()
    # Two text inputs (description prompt, quoted seed words) -> image gallery
    interface = gr.Interface(
        fn=combined_function,
        inputs=[
            gr.Textbox(lines=2, placeholder="Enter a prompt for descriptions..."),
            gr.Textbox(lines=2, placeholder='Enter seed words in quotes, e.g., "cat", "dog", "sunset"...'),
        ],
        outputs=gr.Gallery(),
    )
    interface.launch()