Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -83,68 +83,41 @@ def initialize_diffusers():
|
|
| 83 |
pipe = initialize_diffusers()
|
| 84 |
print("Models and checkpoints preloaded.")
|
| 85 |
|
| 86 |
-
def generate_description_prompt(
|
| 87 |
-
|
|
|
|
|
|
|
| 88 |
try:
|
| 89 |
-
generated_text = text_generator(
|
| 90 |
-
|
| 91 |
-
return
|
| 92 |
except Exception as e:
|
| 93 |
-
print(f"Error generating
|
| 94 |
return None
|
| 95 |
|
| 96 |
-
def parse_descriptions(text):
|
| 97 |
-
descriptions = re.findall(r'\[([^\[\]]+)\]', text)
|
| 98 |
-
descriptions = [desc.strip() for desc in descriptions if len(desc.split()) >= 3]
|
| 99 |
-
return descriptions
|
| 100 |
-
|
| 101 |
def format_descriptions(descriptions):
|
| 102 |
formatted_descriptions = "\n".join(descriptions)
|
| 103 |
return formatted_descriptions
|
| 104 |
|
| 105 |
@spaces.GPU
|
| 106 |
def generate_descriptions(user_prompt, seed_words_input, batch_size=100, max_iterations=1): # Set max_iterations to 1
|
| 107 |
-
descriptions =
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
seed_words.extend(re.findall(r'"(.*?)"', seed_words_input))
|
| 112 |
-
|
| 113 |
-
for _ in range(2): # Perform two iterations
|
| 114 |
-
while iteration_count < max_iterations and len(parsed_descriptions_queue) < MAX_DESCRIPTIONS:
|
| 115 |
-
available_subjects = [word for word in seed_words if word not in used_words]
|
| 116 |
-
if not available_subjects:
|
| 117 |
-
print("No more available subjects to use.")
|
| 118 |
-
break
|
| 119 |
-
|
| 120 |
-
subject = random.choice(available_subjects)
|
| 121 |
-
generated_description = generate_description_prompt(subject, user_prompt, text_generator)
|
| 122 |
-
|
| 123 |
-
if generated_description:
|
| 124 |
-
clean_description = generated_description.encode('ascii', 'ignore').decode('ascii')
|
| 125 |
-
description_queue.append({'subject': subject, 'description': clean_description})
|
| 126 |
-
|
| 127 |
-
print(f"Generated description for subject '{subject}': {clean_description}")
|
| 128 |
-
|
| 129 |
-
used_words.add(subject)
|
| 130 |
-
seed_words.append(clean_description)
|
| 131 |
-
|
| 132 |
-
parsed_descriptions = parse_descriptions(clean_description)
|
| 133 |
-
parsed_descriptions_queue.extend(parsed_descriptions)
|
| 134 |
-
|
| 135 |
-
iteration_count += 1
|
| 136 |
-
|
| 137 |
-
return list(parsed_descriptions_queue)
|
| 138 |
|
| 139 |
@spaces.GPU(duration=120)
|
| 140 |
def generate_images(parsed_descriptions, max_iterations=3): # Set max_iterations to 3
|
| 141 |
-
# Limit the number of descriptions passed to the image generator to
|
| 142 |
if len(parsed_descriptions) > MAX_IMAGES:
|
| 143 |
parsed_descriptions = parsed_descriptions[:MAX_IMAGES]
|
| 144 |
|
| 145 |
images = []
|
| 146 |
for prompt in parsed_descriptions:
|
| 147 |
-
|
|
|
|
|
|
|
|
|
|
| 148 |
|
| 149 |
return images
|
| 150 |
|
|
@@ -161,10 +134,10 @@ if __name__ == '__main__':
|
|
| 161 |
|
| 162 |
interface = gr.Interface(
|
| 163 |
fn=generate_and_display,
|
| 164 |
-
inputs=[gr.Textbox(lines=2, placeholder="Enter a prompt for descriptions..."), gr.Textbox(lines=2, placeholder='Enter
|
| 165 |
outputs=[gr.Textbox(label="Generated Descriptions"), gr.Gallery(label="Generated Images")],
|
| 166 |
live=False, # Set live to False
|
| 167 |
allow_flagging='never' # Disable flagging
|
| 168 |
)
|
| 169 |
|
| 170 |
-
interface.launch(share=True)
|
|
|
|
| 83 |
pipe = initialize_diffusers()
|
| 84 |
print("Models and checkpoints preloaded.")
|
| 85 |
|
| 86 |
+
def generate_description_prompt(user_prompt, text_generator):
|
| 87 |
+
injected_prompt = f"write three concise descriptions enclosed in brackets like [ <description> ] less than 100 words each of {user_prompt}. "
|
| 88 |
+
max_length = 110 # Set max token length to 110
|
| 89 |
+
|
| 90 |
try:
|
| 91 |
+
generated_text = text_generator(injected_prompt, max_length=max_length, num_return_sequences=1, truncation=True)[0]['generated_text']
|
| 92 |
+
generated_descriptions = re.findall(r'\[([^\[\]]+)\]', generated_text) # Extract descriptions enclosed in brackets
|
| 93 |
+
return generated_descriptions if generated_descriptions else None
|
| 94 |
except Exception as e:
|
| 95 |
+
print(f"Error generating descriptions: {e}")
|
| 96 |
return None
|
| 97 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 98 |
def format_descriptions(descriptions):
|
| 99 |
formatted_descriptions = "\n".join(descriptions)
|
| 100 |
return formatted_descriptions
|
| 101 |
|
| 102 |
@spaces.GPU
|
| 103 |
def generate_descriptions(user_prompt, seed_words_input, batch_size=100, max_iterations=1): # Set max_iterations to 1
|
| 104 |
+
descriptions = generate_description_prompt(user_prompt, text_generator)
|
| 105 |
+
if descriptions:
|
| 106 |
+
parsed_descriptions_queue.extend(descriptions)
|
| 107 |
+
return list(parsed_descriptions_queue)[:MAX_IMAGES]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 108 |
|
| 109 |
@spaces.GPU(duration=120)
|
| 110 |
def generate_images(parsed_descriptions, max_iterations=3): # Set max_iterations to 3
|
| 111 |
+
# Limit the number of descriptions passed to the image generator to MAX_IMAGES (3)
|
| 112 |
if len(parsed_descriptions) > MAX_IMAGES:
|
| 113 |
parsed_descriptions = parsed_descriptions[:MAX_IMAGES]
|
| 114 |
|
| 115 |
images = []
|
| 116 |
for prompt in parsed_descriptions:
|
| 117 |
+
try:
|
| 118 |
+
images.extend(pipe(prompt, num_inference_steps=4, height=1024, width=1024).images) # Set resolution to 512 x 512
|
| 119 |
+
except Exception as e:
|
| 120 |
+
print(f"Error generating image for prompt '{prompt}': {e}")
|
| 121 |
|
| 122 |
return images
|
| 123 |
|
|
|
|
| 134 |
|
| 135 |
interface = gr.Interface(
|
| 136 |
fn=generate_and_display,
|
| 137 |
+
inputs=[gr.Textbox(lines=2, placeholder="Enter a prompt for descriptions..."), gr.Textbox(lines=2, placeholder='Enter example in quotes, e.g., "cat", "dog", "sunset"...')],
|
| 138 |
outputs=[gr.Textbox(label="Generated Descriptions"), gr.Gallery(label="Generated Images")],
|
| 139 |
live=False, # Set live to False
|
| 140 |
allow_flagging='never' # Disable flagging
|
| 141 |
)
|
| 142 |
|
| 143 |
+
interface.launch(share=True)
|