Update app.py
app.py CHANGED
@@ -14,9 +14,7 @@ from torchvision import transforms
 from torchvision.transforms import functional as TF
 from tqdm import trange
 from transformers import CLIPProcessor, CLIPModel
-
-# from diffusion_models import Diffusion # Swapped Diffusion model for DALL·E 2 based model - REMOVED
-from huggingface_hub import hf_hub_url, cached_download
+from huggingface_hub import hf_hub_download # FIXED: Replaced deprecated function
 import gradio as gr # 🎨 The magic canvas for AI-powered image generation!
 import math
 
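Note: cached_download() and hf_hub_url() were deprecated and later removed from huggingface_hub, so hf_hub_download is the supported replacement. A minimal before/after sketch, assuming a hypothetical repo (some-org/some-model is not a real location):

    from huggingface_hub import hf_hub_download

    # Old pattern (no longer available in current huggingface_hub releases):
    #   path = cached_download(hf_hub_url(repo_id="some-org/some-model", filename="model.ckpt"))
    # New pattern: one call resolves, downloads, and caches the file.
    path = hf_hub_download(repo_id="some-org/some-model", filename="model.ckpt")
    print(path)  # local path inside the Hugging Face cache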
@@ -130,8 +128,9 @@ def ddpm_sample(model, x, steps, **kwargs):
 # NOTE: The HuggingFace URLs you provided might be placeholders.
 # Make sure these point to the correct model files.
 try:
-    vqvae_model_path = cached_download(hf_hub_url(repo_id="dalle-mini/vqgan_imagenet_f16_16384", filename="flax_model.msgpack"))
-    diffusion_model_path = cached_download(hf_hub_url(repo_id="huggingface/dalle-2", filename="diffusion_model.ckpt"))
+    # FIXED: Using the new hf_hub_download function with keyword arguments
+    vqvae_model_path = hf_hub_download(repo_id="dalle-mini/vqgan_imagenet_f16_16384", filename="flax_model.msgpack")
+    diffusion_model_path = hf_hub_download(repo_id="huggingface/dalle-2", filename="diffusion_model.ckpt")
 except Exception as e:
     print(f"Could not download models. Please ensure the HuggingFace URLs are correct.")
     print("Using placeholder models which will not produce good images.")
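The bare except above hides why a download failed. A small helper along these lines would surface the error while keeping the fallback behavior; fetch_checkpoint is a hypothetical name, and the repo ids are the same possibly-placeholder values from the diff:

    from huggingface_hub import hf_hub_download

    def fetch_checkpoint(repo_id: str, filename: str):
        """Return a local cached path for the file, or None if it cannot be fetched."""
        try:
            return hf_hub_download(repo_id=repo_id, filename=filename)
        except Exception as err:  # missing repo/file, network or auth errors
            print(f"Could not fetch {filename} from {repo_id}: {err}")
            return None

    vqvae_model_path = fetch_checkpoint("dalle-mini/vqgan_imagenet_f16_16384", "flax_model.msgpack")
    diffusion_model_path = fetch_checkpoint("huggingface/dalle-2", "diffusion_model.ckpt")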
@@ -213,7 +212,7 @@ def generate(n=1, prompts=['a red circle'], images=[], seed=42, steps=15, method
         target_embeds.append(text_embed)
         weights.append(1.0)
 
-    #
+    # Correctly process image prompts from Gradio
     # Assign a default weight for image prompts
     image_prompt_weight = 1.0
     for image_path in images:
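For context on what processing an image prompt involves, given the CLIPProcessor/CLIPModel imports at the top of app.py: a minimal sketch of turning a Gradio filepath into a unit-norm CLIP image embedding. The checkpoint name and the helper are assumptions, not taken from app.py:

    import torch
    from PIL import Image
    from transformers import CLIPProcessor, CLIPModel

    clip = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")  # assumed checkpoint
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

    def embed_image(image_path: str) -> torch.Tensor:
        """Encode one image prompt into a unit-norm CLIP embedding."""
        image = Image.open(image_path).convert("RGB")
        inputs = processor(images=image, return_tensors="pt")
        with torch.no_grad():
            embed = clip.get_image_features(**inputs)
        return embed / embed.norm(dim=-1, keepdim=True)

    # Mirrors the loop in generate(): one embedding plus one weight per image prompt.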
@@ -250,7 +249,6 @@ def generate(n=1, prompts=['a red circle'], images=[], seed=42, steps=15, method
         return v
 
     # 🎞️ Run the sampler to generate images
-    # **FIXED**: Call sampling functions directly without the 'sampling.' prefix
     def run(x, steps):
         if method == 'ddpm':
             return ddpm_sample(cfg_model_fn, x, steps)
@@ -310,7 +308,6 @@ iface = gr.Interface(
     fn=gen_ims,
     inputs=[
         gr.Textbox(label="Text prompt"),
-        # **FIXED**: Removed deprecated 'optional=True' argument
         gr.Image(label="Image prompt", type='filepath')
     ],
     outputs=gr.Image(type="pil", label="Generated Image"),
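On the removed optional=True: since Gradio 3, an input component is optional by default and simply passes None when left empty, so optionality moves into the function body. A minimal runnable sketch with a stub generator (the returned blank image is a stand-in, not app.py's sampler output):

    import gradio as gr
    from PIL import Image

    def gen_ims(prompt, image_path):
        # Gradio passes None when the image input is left empty,
        # so no component-level flag is needed.
        images = [image_path] if image_path else []
        print(f"prompt={prompt!r}, image prompts={images}")
        return Image.new("RGB", (256, 256))  # stand-in for the real sampler output

    iface = gr.Interface(
        fn=gen_ims,
        inputs=[
            gr.Textbox(label="Text prompt"),
            gr.Image(label="Image prompt", type="filepath"),
        ],
        outputs=gr.Image(type="pil", label="Generated Image"),
    )

    if __name__ == "__main__":
        iface.launch()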