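"""Stable Diffusion Style Explorer (Gradio Space).

Loads textual-inversion concepts from the sd-concepts-library and generates
images from them, with an optional CLIP-based "car guidance" refinement loop.
"""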
import os
import shutil
import warnings

import gradio as gr
import numpy as np
import torch
from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline
from huggingface_hub import hf_hub_download
from PIL import Image
from transformers import CLIPProcessor, CLIPModel

warnings.filterwarnings("ignore")

# Check if CUDA is available
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Load CLIP model for semantic guidance
print("Loading CLIP model for semantic guidance...")
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
# Dictionary of available concepts
CONCEPTS = {
    "canna-lily-flowers102": {
        "repo_id": "sd-concepts-library/canna-lily-flowers102",
        "type": "object",
        "description": "Canna lily flower style"
    },
    "samurai-jack": {
        "repo_id": "sd-concepts-library/samurai-jack",
        "type": "style",
        "description": "Samurai Jack animation style"
    },
    "babies-poster": {
        "repo_id": "sd-concepts-library/babies-poster",
        "type": "style",
        "description": "Babies poster art style"
    },
    "animal-toy": {
        "repo_id": "sd-concepts-library/animal-toy",
        "type": "object",
        "description": "Animal toy style"
    },
    "sword-lily-flowers102": {
        "repo_id": "sd-concepts-library/sword-lily-flowers102",
        "type": "object",
        "description": "Sword lily flower style"
    }
}
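# Each concept above is a textual-inversion embedding; once loaded into the
# pipeline, it is referenced in prompts via its placeholder token, e.g.
# "<samurai-jack>" (see generate_images below).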
def car_loss(image):
    """CLIP-based loss encouraging the presence of cars in the image.

    Returns a scalar tensor: lower values mean the image looks more
    "car-like" to CLIP.
    """
    # Normalize the input to a PIL image for the CLIP processor
    if isinstance(image, torch.Tensor):
        image = Image.fromarray(image.cpu().numpy().astype(np.uint8))
    elif isinstance(image, np.ndarray):
        image = Image.fromarray(image.astype(np.uint8))

    with torch.no_grad():
        # Score the image against a car / no-car text pair
        inputs = clip_processor(
            text=["a photo of a car", "a photo without cars"],
            images=image,
            return_tensors="pt",
            padding=True
        ).to(device)
        outputs = clip_model(**inputs)
        logits_per_image = outputs.logits_per_image
        # Higher similarity to "a photo of a car" is better
        car_score = logits_per_image[0][0]
        no_car_score = logits_per_image[0][1]
        # Negate so that more car-like images give a lower loss
        loss = -(car_score - no_car_score)
    return loss
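# Hypothetical sanity check for car_loss (not part of the app): a blank image
# should match "a photo without cars" more strongly, so the loss should come
# out positive, while a photo of a car should drive it negative:
#   blank = Image.new("RGB", (512, 512))
#   print(car_loss(blank).item())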
def generate_image(pipe, prompt, seed, guidance_scale=7.5, num_inference_steps=30, use_car_guidance=False):
    """Generate an image, optionally refined with CLIP-based car guidance."""
    generator = torch.Generator(device=device).manual_seed(seed)

    if use_car_guidance:
        try:
            # Stage 1: standard text-to-image generation with half the steps
            init_image = pipe(
                prompt,
                guidance_scale=guidance_scale,
                num_inference_steps=num_inference_steps // 2,
                generator=generator
            ).images[0]

            # Stage 2: iterative img2img refinement, reusing the base
            # pipeline's components so no extra weights are loaded
            img2img_pipe = StableDiffusionImg2ImgPipeline(
                vae=pipe.vae,
                text_encoder=pipe.text_encoder,
                tokenizer=pipe.tokenizer,
                unet=pipe.unet,
                scheduler=pipe.scheduler,
                safety_checker=None,
                feature_extractor=None,
                requires_safety_checker=False,
            ).to(device)

            strength = 0.75
            current_image = init_image
            for i in range(5):
                # car_loss is used only as a monitoring signal here; the
                # steering comes from the strengthened prompt below
                current_loss = car_loss(current_image)
                print(f"Refinement step {i}: car loss = {current_loss.item():.3f}")
                current_image = img2img_pipe(
                    prompt=prompt + ", with beautiful cars",
                    image=current_image,
                    strength=strength,
                    guidance_scale=guidance_scale,
                    generator=generator,
                ).images[0]
                # Decay strength so later passes make smaller edits
                strength *= 0.8
            return current_image
        except Exception as e:
            print(f"Error in car-guided generation: {e}")
            # Fall back to plain generation below

    return pipe(
        prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        generator=generator
    ).images[0]
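# Note: the "car guidance" above steers generation at the prompt level via
# repeated img2img passes; it is not gradient-based guidance. A true guided
# sampler would backpropagate car_loss through each denoising step to nudge
# the latents directly.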
# Cache for loaded pipelines so each concept's weights are loaded only once
loaded_models = {}

def get_model_with_concept(concept_name):
    """Get or load a pipeline with the specified concept's embedding applied."""
    if concept_name not in loaded_models:
        concept_info = CONCEPTS[concept_name]

        # Download the concept's textual-inversion embedding
        concept_path = f"concepts/{concept_name}.bin"
        os.makedirs("concepts", exist_ok=True)
        if not os.path.exists(concept_path):
            file = hf_hub_download(
                repo_id=concept_info["repo_id"],
                filename="learned_embeds.bin",
                repo_type="model"
            )
            shutil.copy(file, concept_path)

        # sd-concepts-library embeddings were trained against Stable Diffusion
        # 1.x's CLIP text encoder (768-dim), so load a 1.x base model; SD 2.x
        # uses a different text encoder and would reject these embeddings
        pipe = StableDiffusionPipeline.from_pretrained(
            "CompVis/stable-diffusion-v1-4",
            torch_dtype=torch.float32 if device == "cpu" else torch.float16,
            safety_checker=None
        ).to(device)
        pipe.load_textual_inversion(concept_path)
        loaded_models[concept_name] = pipe
    return loaded_models[concept_name]
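# Hypothetical direct usage, bypassing the Gradio UI:
#   pipe = get_model_with_concept("samurai-jack")
#   img = generate_image(pipe, "<samurai-jack> a castle at dusk", seed=42)
#   img.save("castle.png")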
def generate_images(concept_name, base_prompt, seed, use_car_guidance):
    """Generate an image using the selected concept."""
    try:
        # Get the pipeline with the concept's embedding loaded
        pipe = get_model_with_concept(concept_name)

        # Construct the prompt based on the concept type: objects are placed
        # into the scene, styles are applied to the whole prompt
        if CONCEPTS[concept_name]["type"] == "object":
            prompt = f"A {base_prompt} with a <{concept_name}>"
        else:
            prompt = f"<{concept_name}> {base_prompt}"

        # Generate the image
        image = generate_image(
            pipe=pipe,
            prompt=prompt,
            seed=int(seed),
            use_car_guidance=use_car_guidance
        )
        return image
    except Exception as e:
        raise gr.Error(f"Error generating image: {str(e)}")
# Create Gradio interface
with gr.Blocks(title="Stable Diffusion Style Explorer") as demo:
    gr.Markdown("""
    # Stable Diffusion Style Explorer

    Generate images using various concepts from the SD Concepts Library, with optional car guidance.

    ## How to use:
    1. Select a concept from the dropdown
    2. Enter a base prompt (or use the default)
    3. Set a seed for reproducibility
    4. Choose whether to use car guidance
    5. Click Generate!

    Check out the examples below to see different combinations of concepts and prompts!
    """)

    with gr.Row():
        with gr.Column():
            concept = gr.Dropdown(
                choices=list(CONCEPTS.keys()),
                value="samurai-jack",
                label="Select Concept"
            )
            concept_info = gr.Markdown(
                f"Selected concept: {CONCEPTS['samurai-jack']['description']} ({CONCEPTS['samurai-jack']['type']})"
            )
            prompt = gr.Textbox(
                value="A serene landscape with mountains and a lake at sunset",
                label="Base Prompt"
            )
            seed = gr.Number(
                value=42,
                label="Seed",
                precision=0
            )
            car_guidance = gr.Checkbox(
                value=False,
                label="Use Car Guidance"
            )
            generate_btn = gr.Button("Generate Image")
        with gr.Column():
            output_image = gr.Image(label="Generated Image")

    # Update the description when a new concept is selected; returning a plain
    # string updates the pre-declared Markdown component
    concept.change(
        fn=lambda x: f"Selected concept: {CONCEPTS[x]['description']} ({CONCEPTS[x]['type']})",
        inputs=[concept],
        outputs=[concept_info]
    )

    generate_btn.click(
        fn=generate_images,
        inputs=[concept, prompt, seed, car_guidance],
        outputs=[output_image]
    )
    # Gallery of pre-generated examples
    gr.Markdown("### 🖼️ Pre-generated Examples")

    with gr.Row():
        # Samurai Jack examples
        with gr.Column():
            gr.Markdown("**Samurai Jack Style**")
            gr.Image("samurai-jack_normal.png", label="Without Car Guidance")
            gr.Image("samurai-jack_car.png", label="With Car Guidance")

    with gr.Row():
        # Canna Lily examples
        with gr.Column():
            gr.Markdown("**Canna Lily Object**")
            gr.Image("canna-lily-flowers102_normal.png", label="Without Car Guidance")
            gr.Image("canna-lily-flowers102_car.png", label="With Car Guidance")

    with gr.Row():
        # Babies Poster examples
        with gr.Column():
            gr.Markdown("**Babies Poster Style**")
            gr.Image("babies-poster_normal.png", label="Without Car Guidance")
            gr.Image("babies-poster_car.png", label="With Car Guidance")

    with gr.Row():
        # Animal Toy examples
        with gr.Column():
            gr.Markdown("**Animal Toy Object**")
            gr.Image("animal-toy_normal.png", label="Without Car Guidance")
            gr.Image("animal-toy_car.png", label="With Car Guidance")

    with gr.Row():
        # Sword Lily examples
        with gr.Column():
            gr.Markdown("**Sword Lily Object**")
            gr.Image("sword-lily-flowers102_normal.png", label="Without Car Guidance")
            gr.Image("sword-lily-flowers102_car.png", label="With Car Guidance")

demo.launch()