# Captured from a Hugging Face Space page; the page header reported the Space status as "Runtime error".
| import os | |
| import pandas as pd | |
| import torch | |
| import gc | |
| import re | |
| import random | |
| from tqdm.auto import tqdm | |
| from collections import deque | |
| from optimum.quanto import freeze, qfloat8, quantize | |
| from diffusers import FlowMatchEulerDiscreteScheduler, AutoencoderKL | |
| from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel | |
| from diffusers.pipelines.flux.pipeline_flux import FluxPipeline | |
| from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast | |
| from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer | |
| import gradio as gr | |
| from accelerate import Accelerator | |
| import huggingface_hub # Ensure this import is correct | |
# Instantiate the Accelerator and choose the dtype passed to every loader below.
accelerator = Accelerator()
dtype = torch.bfloat16

# Set environment variables for local path
os.environ.update({'FLUX_DEV': '.', 'AE': '.'})

# Hub repository and revision the FLUX components are loaded from.
bfl_repo = 'black-forest-labs/FLUX.1-schnell'
revision = 'refs/pr/1'

scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
    bfl_repo,
    subfolder='scheduler',
    revision=revision,
)

# CLIP text encoder / tokenizer pair.
text_encoder = CLIPTextModel.from_pretrained(
    'openai/clip-vit-large-patch14',
    torch_dtype=dtype,
)
tokenizer = CLIPTokenizer.from_pretrained(
    'openai/clip-vit-large-patch14',
    torch_dtype=dtype,
)

# T5 text encoder / tokenizer pair from the FLUX repository.
text_encoder_2 = T5EncoderModel.from_pretrained(
    bfl_repo,
    subfolder='text_encoder_2',
    torch_dtype=dtype,
    revision=revision,
)
tokenizer_2 = T5TokenizerFast.from_pretrained(
    bfl_repo,
    subfolder='tokenizer_2',
    torch_dtype=dtype,
    revision=revision,
)

vae = AutoencoderKL.from_pretrained(
    bfl_repo,
    subfolder='vae',
    torch_dtype=dtype,
    revision=revision,
)
transformer = FluxTransformer2DModel.from_pretrained(
    bfl_repo,
    subfolder='transformer',
    torch_dtype=dtype,
    revision=revision,
)

# Quantize the weights of the two largest modules to qfloat8, then freeze them.
quantize(transformer, weights=qfloat8)
freeze(transformer)
quantize(text_encoder_2, weights=qfloat8)
freeze(text_encoder_2)

# The pipeline is built with None placeholders for the two quantized modules,
# which are attached as attributes afterwards.
# NOTE(review): presumably this keeps the constructor from touching the
# quantized modules -- confirm before simplifying.
pipe = FluxPipeline(
    scheduler=scheduler,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    text_encoder_2=None,
    tokenizer_2=tokenizer_2,
    vae=vae,
    transformer=None,
)
pipe.text_encoder_2 = text_encoder_2
pipe.transformer = transformer
pipe.enable_model_cpu_offload()

# Create a directory to save the generated images
output_dir = 'generated_images'
os.makedirs(output_dir, exist_ok=True)
# Function to generate a detailed visual description prompt
def generate_description_prompt(subject, user_prompt, text_generator):
    """Ask ``text_generator`` for a bracketed visual description.

    Args:
        subject: Text the description should differ from; also used in the
            error message.
        user_prompt: What the description should be about.
        text_generator: Callable invoked as
            ``text_generator(prompt, max_length=160, num_return_sequences=1,
            truncation=True)``; its first result must be a mapping with a
            ``'generated_text'`` entry.

    Returns:
        The generated text with every echo of the prompt (and the whitespace
        following it) removed and the ends stripped, or ``None`` when that
        text is empty or the generator raised.
    """
    prompt = f"write concise vivid visual description enclosed in brackets like [ <description> ] less than 100 words of {user_prompt} different from {subject}. "
    try:
        outputs = text_generator(prompt, max_length=160, num_return_sequences=1, truncation=True)
        generated_text = outputs[0]['generated_text']
        # Remove the prompt from the generated text
        generated_description = re.sub(rf'{re.escape(prompt)}\s*', '', generated_text).strip()
    except Exception as e:
        # Best effort: report the failure and let the caller carry on.
        print(f"Error generating description for subject '{subject}': {e}")
        return None
    return generated_description or None
# Function to parse descriptions from a given text
def parse_descriptions(text):
    """Return the bracketed descriptions found in ``text``.

    A description is any run of characters between ``[`` and ``]`` that itself
    contains no brackets. Only descriptions with at least 3 whitespace-separated
    words are kept; each kept description is stripped of surrounding whitespace.
    """
    kept = []
    for match in re.finditer(r'\[([^\[\]]+)\]', text):
        candidate = match.group(1)
        # Filter descriptions with at least 3 words
        if len(candidate.split()) >= 3:
            kept.append(candidate.strip())
    return kept
# Shared module-level state for the description and image generators.

# Queue to store parsed descriptions
parsed_descriptions_queue = deque()

# Seed words pool (starts empty) and the subjects that have already been used.
seed_words = []
used_words = set()

# While True, the description loop stops at its next iteration.
paused = False
def generate_and_store_descriptions(user_prompt, batch_size=100, max_iterations=50):
    """Generate descriptions for ``user_prompt`` and queue the parsed results.

    Subjects are picked at random from the module-level ``seed_words`` pool,
    skipping anything already in ``used_words``. Each successful generation
    marks its subject as used and adds the ASCII-cleaned description to
    ``seed_words`` so it can serve as a later subject. On iterations whose
    index is a multiple of 3 (including 0) the description is parsed, the
    bracketed parts are appended to ``parsed_descriptions_queue`` and the
    function returns immediately.

    Args:
        user_prompt: What the generated descriptions should be about. A
            non-empty prompt that is not yet known is added to ``seed_words``
            so that a first subject exists.
        batch_size: Unused; kept so existing callers keep working.
        max_iterations: Upper bound on the number of loop iterations.

    Returns:
        A list snapshot of ``parsed_descriptions_queue``.
    """
    # ``seed_words`` starts out empty at module level, so without this the
    # loop below would never find a subject and nothing would be generated.
    if user_prompt and user_prompt not in used_words and user_prompt not in seed_words:
        seed_words.append(user_prompt)

    # Decide whether there is any work before paying for the model load.
    if paused or max_iterations <= 0:
        return list(parsed_descriptions_queue)
    if all(word in used_words for word in seed_words):
        print("No more available subjects to use.")
        return list(parsed_descriptions_queue)

    # Initialize the text generation pipeline with 16-bit precision
    print("Initializing the text generation pipeline with 16-bit precision...")
    model_name = 'meta-llama/Meta-Llama-3.1-8B-Instruct'
    model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map='auto')
    # Named ``llm_tokenizer`` so it does not shadow the module-level CLIP ``tokenizer``.
    llm_tokenizer = AutoTokenizer.from_pretrained(model_name)
    text_generator = pipeline('text-generation', model=model, tokenizer=llm_tokenizer)
    print("Text generation pipeline initialized with 16-bit precision.")

    iteration_count = 0
    while iteration_count < max_iterations:
        if paused:
            break

        # Select a subject that has not been used
        available_subjects = [word for word in seed_words if word not in used_words]
        if not available_subjects:
            print("No more available subjects to use.")
            break

        subject = random.choice(available_subjects)
        generated_description = generate_description_prompt(subject, user_prompt, text_generator)

        if generated_description:
            # Remove any offending symbols (anything outside ASCII is dropped)
            clean_description = generated_description.encode('ascii', 'ignore').decode('ascii')

            # Print the generated description to the command line
            print(f"Generated description for subject '{subject}': {clean_description}")

            # Update used words and seed words
            used_words.add(subject)
            seed_words.append(clean_description)  # Add the generated description to the seed bank array

            # Parse and append descriptions every 3 iterations
            if iteration_count % 3 == 0:
                parsed_descriptions_queue.extend(parse_descriptions(clean_description))
                # Return the parsed descriptions to update the Gradio UI
                return list(parsed_descriptions_queue)

        # Failed attempts count too, so a generator that keeps failing
        # cannot keep the loop running past ``max_iterations``.
        iteration_count += 1

    return list(parsed_descriptions_queue)
def generate_images_from_parsed_descriptions():
    """Render queued descriptions with the module-level FLUX pipeline ``pipe``.

    At most 13 descriptions are rendered per call. When the queue holds 13 or
    more, the first 13 are removed from ``parsed_descriptions_queue``; when it
    holds fewer, all of them are rendered and the queue is left unchanged.

    Returns:
        A list with the ``.images`` of every pipeline call, in prompt order.
        Empty when the queue is empty.
    """
    batch_limit = 13

    # If there are fewer than 13 descriptions, use whatever is available
    if len(parsed_descriptions_queue) < batch_limit:
        # NOTE(review): this branch does not consume the queue, so the same
        # descriptions are rendered again on the next call -- confirm intended.
        prompts = list(parsed_descriptions_queue)
    else:
        prompts = [parsed_descriptions_queue.popleft() for _ in range(batch_limit)]

    # Generate images from the parsed descriptions
    images = []
    for prompt in prompts:
        # FluxPipeline.__call__ has no ``num_images`` keyword (it raises
        # TypeError); the per-prompt image count is ``num_images_per_prompt``.
        images.extend(pipe(prompt, num_images_per_prompt=1).images)
    return images
# Create Gradio Interface

# Tab 1: a two-line text box feeding the description generator; the returned
# list is shown as JSON.
description_interface = gr.Interface(
    fn=generate_and_store_descriptions,
    inputs=gr.Textbox(
        lines=2,
        placeholder="Enter a prompt for descriptions...",
    ),
    outputs="json",
)

# Tab 2: no inputs; renders the queued descriptions into a gallery.
image_interface = gr.Interface(
    fn=generate_images_from_parsed_descriptions,
    inputs=None,
    outputs=gr.Gallery(),
)

# Combine both tabs and start the app.
gr.TabbedInterface(
    [description_interface, image_interface],
    ["Generate Descriptions", "Generate Images"],
).launch()