# generate_game / app.py
# Author: chrisjcc — commit 34f1bac ("Remove unused code"), verified
import os
import io
from PIL import Image
import base64
import torch
import spaces
from transformers import pipeline
from diffusers import EulerDiscreteScheduler
from diffusers import StableDiffusionPipeline
import gradio as gr
# Run inference on GPU when one is available, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Hugging Face access token, required to download gated models.
hf_api_key = os.environ.get('HF_API_KEY')

# Stable Diffusion checkpoint to load.
model_id = "sd-legacy/stable-diffusion-v1-5"

# Use the Euler scheduler instead of the pipeline's default scheduler.
scheduler = EulerDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler")

# Image-to-text pipeline backed by the BLIP captioning model.
get_completion = pipeline("image-to-text", model="Salesforce/blip-image-captioning-base")

# Text-to-image Stable Diffusion pipeline.
pipe = StableDiffusionPipeline.from_pretrained(
    model_id,
    # float16 halves GPU memory use; CPU kernels require float32.
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    scheduler=scheduler,
    token=hf_api_key,  # `use_auth_token` is deprecated in favor of `token`
)
pipe = pipe.to(device)
# Caption generate function
@spaces.GPU(duration=120)  # No-op outside ZeroGPU environments, so the app stays portable.
def captioner(image):
    """Return a text caption for *image* using the BLIP captioning pipeline.

    The image-to-text pipeline accepts a PIL image directly and returns a
    list of candidate captions; only the first one is used.
    """
    candidates = get_completion(image)
    best = candidates[0]
    return best['generated_text']
# Image generate function
@spaces.GPU(duration=120)  # No-op outside ZeroGPU environments, so the app stays portable.
def generate(prompt, steps=25):
    """Generate an image from a text prompt with Stable Diffusion.

    Args:
        prompt: Text description of the image to generate.
        steps: Number of denoising steps (default 25). The original code
            accepted this argument but silently ignored it, always running
            25 steps; it is now honored. The default also lets Gradio call
            this handler with only the prompt wired as an input.

    Returns:
        The first generated image as a PIL image.
    """
    # Coerce to int in case the UI delivers the slider value as a float/str.
    num_steps = int(steps) if steps else 25
    output = pipe(
        prompt,
        negative_prompt=None,
        num_inference_steps=num_steps,
    )
    return output.images[0]
# Build the Gradio interface: caption an uploaded image, then generate a
# new image from that caption.
with gr.Blocks() as demo:
    gr.Markdown("# Describe-and-Generate game 🖍️")
    image_upload = gr.Image(label="Your first image", type="pil")
    btn_caption = gr.Button("Generate caption")
    caption = gr.Textbox(label="Generated caption")
    # Expose the step count as an input. The original wiring passed only
    # [caption] to generate(), which declares two parameters — Gradio
    # raises at click time when inputs don't cover the required params.
    steps = gr.Slider(minimum=1, maximum=100, value=25, step=1, label="Inference steps")
    btn_image = gr.Button("Generate image")
    image_output = gr.Image(label="Generated Image")

    btn_caption.click(fn=captioner, inputs=[image_upload], outputs=[caption])
    btn_image.click(fn=generate, inputs=[caption, steps], outputs=[image_output])

# Launch the app; share=True creates a public link.
demo.launch(
    share=True,
)