# StickerMaker / app.py
# Oranblock's picture
# Update app.py
# d33031f verified
# raw
# history blame
# 5.19 kB
import json
import os
import random
import re
import uuid
from typing import Optional, Tuple

import face_recognition  # More robust face detection library
import gradio as gr
import numpy as np
import spaces
import torch
from diffusers import DiffusionPipeline
from PIL import Image
# Check if GPU is available; fallback to CPU if needed
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Half precision is only kept on the GPU. On the CPU fallback the fp16
# checkpoint is upcast to float32 at load time, because float16 inference on
# CPU is poorly supported by PyTorch.
model_dtype = torch.float16 if device.type == "cuda" else torch.float32

# Initialize the AI model for sticker generation
pipe = DiffusionPipeline.from_pretrained(
    "SG161222/RealVisXL_V3.0_Turbo",  # or any model of your choice
    torch_dtype=model_dtype,
    use_safetensors=True,
    variant="fp16",  # which weight files to download; independent of model_dtype
).to(device)
def face_to_sticker(image_path: str) -> Tuple[Optional[str], str]:
    """Crop the first detected face out of an image and save it as a PNG.

    Args:
        image_path: Path of the image to scan. May be ``None`` or empty when
            the caller has no image to offer (e.g. nothing was uploaded).

    Returns:
        A ``(path, message)`` pair. ``path`` is the file name of the saved
        256x256 face crop, or ``None`` when no image was supplied or no face
        was found. ``message`` is a human-readable status string.
    """
    # Bail out before touching face_recognition when there is no file to load.
    if not image_path:
        return None, "No image provided. Please upload a clear image with a visible face."

    img = face_recognition.load_image_file(image_path)
    face_locations = face_recognition.face_locations(img)
    if not face_locations:
        return None, "No face detected. Please upload a clear image with a visible face."

    # Each location is a (top, right, bottom, left) box; only the first is used.
    top, right, bottom, left = face_locations[0]
    face_crop = Image.fromarray(img[top:bottom, left:right]).resize((256, 256))

    # A random UUID keeps concurrent requests from overwriting each other.
    face_img_path = f"{uuid.uuid4()}.png"
    face_crop.save(face_img_path)
    return face_img_path, "Face successfully converted to a sticker."
def generate_prompt(clothing: str, pose: str, mood: str) -> str:
    """Build the text prompt for a sticker from the three user selections.

    The clothing, pose and mood values are inserted verbatim into a fixed
    sentence template.
    """
    template = "sticker of a person wearing {clothing} clothes, in a {pose} pose, looking {mood}."
    return template.format(clothing=clothing, pose=pose, mood=mood)
# Upper bound for randomly drawn seeds (largest signed 32-bit integer).
MAX_SEED = np.iinfo(np.int32).max


def generate_stickers(prompt: str, face_image: Optional[str] = None, guidance_scale: float = 7.5, randomize_seed: bool = False):
    """Generate sticker images from a text prompt with the diffusion pipeline.

    Args:
        prompt: Text description of the sticker to generate.
        face_image: Optional path to a face crop. It is currently NOT passed
            to the pipeline; it is accepted so callers that have a face image
            and callers that only have a prompt can share this function.
        guidance_scale: Guidance scale forwarded to the pipeline.
        randomize_seed: When true, draw a random seed in ``[0, MAX_SEED]``;
            otherwise use the fixed seed 42 so results are reproducible.

    Returns:
        A ``(image_paths, seed)`` tuple: the file names of the saved images
        and the seed that was used.
    """
    # Adjust seed for variability
    seed = random.randint(0, MAX_SEED) if randomize_seed else 42
    generator = torch.Generator(device).manual_seed(seed)

    # Prepare AI model options
    options = {
        "prompt": prompt,
        "width": 512,
        "height": 512,
        "guidance_scale": guidance_scale,
        "num_inference_steps": 50,
        "generator": generator,
    }
    # NOTE: face conditioning is not wired up. If the chosen model supports
    # conditioning on an input image, load ``face_image`` here and add it to
    # ``options`` (e.g. under an "image" key).

    images = pipe(**options).images
    image_paths = [save_image(img) for img in images]
    return image_paths, seed
def save_image(img: Image.Image) -> str:
    """Write *img* to a freshly named PNG file and return that file name.

    The name is a random UUID4 followed by ``.png``; it carries no directory
    component, so the file is written relative to the current working
    directory.
    """
    target = "{}.png".format(uuid.uuid4())
    img.save(target)
    return target
def stick_me_workflow(image, clothing, pose, mood, randomize_seed: bool):
    """Run the full 'Stick Me' flow: face crop -> prompt -> generated stickers.

    Args:
        image: Path of the uploaded image, or ``None`` if nothing was uploaded.
        clothing: Selected clothing style, inserted into the prompt.
        pose: Selected pose, inserted into the prompt.
        mood: Selected mood, inserted into the prompt.
        randomize_seed: Whether the generator should use a random seed.

    Returns:
        The list of generated sticker image paths.

    Raises:
        gr.Error: If no image was uploaded or no face could be extracted.
    """
    if not image:
        raise gr.Error("No image provided. Please upload a clear image with a visible face.")

    # Convert the uploaded image to a face sticker
    face_path, message = face_to_sticker(image)
    if face_path is None:
        # This handler's only output is a Gallery, which cannot display a
        # plain status string; gr.Error shows the message in the UI instead.
        raise gr.Error(message)

    # Generate a descriptive prompt based on user selections
    prompt = generate_prompt(clothing, pose, mood)

    # Generate stickers using the diffusion model with the extracted face and prompt
    stickers, _seed = generate_stickers(prompt, face_path, randomize_seed=randomize_seed)
    return stickers
def on_fallback_to_cpu():
    """Return a CPU-fallback warning, or an empty string when CUDA is usable.

    The check is simply ``torch.cuda.is_available()``; the warning text is
    returned whenever that reports ``False``.
    """
    cuda_ready = torch.cuda.is_available()
    return "" if cuda_ready else "Warning: GPU quota exceeded. Running on CPU, which will be significantly slower."
# Gradio interface setup


def _generate_from_prompt(text):
    """Adapter for the prompt-only button.

    ``generate_stickers`` also takes a face image and returns
    ``(paths, seed)``, while this button supplies just a prompt and fills a
    single Gallery. Pass no face image and keep only the image paths.
    """
    paths, _seed = generate_stickers(text, None)
    return paths


with gr.Blocks() as demo:
    gr.Markdown("# Sticker Generator with 'Stick Me' Feature")

    # GPU Quota Handling: the warning is only visible when CUDA is unavailable.
    gpu_warning = gr.Markdown(on_fallback_to_cpu(), visible=not torch.cuda.is_available())

    # New Stick Me Option
    with gr.Row():
        face_input = gr.Image(label="Upload Your Image for 'Stick Me'", type="filepath")
        clothing = gr.Dropdown(["Casual", "Formal", "Sports"], label="Choose Clothing")
        pose = gr.Dropdown(["Standing", "Sitting", "Running"], label="Choose Pose")
        mood = gr.Dropdown(["Happy", "Serious", "Excited"], label="Choose Mood")
        randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)

    stick_me_button = gr.Button("Generate Stick Me Stickers")
    stick_me_result = gr.Gallery(label="Your Stick Me Stickers")
    stick_me_button.click(
        fn=stick_me_workflow,
        inputs=[face_input, clothing, pose, mood, randomize_seed],
        outputs=[stick_me_result],
    )

    gr.Markdown("# Generate Regular Stickers")
    prompt = gr.Textbox(label="Enter a Prompt for Sticker Creation", placeholder="Cute bunny", max_lines=1)
    generate_button = gr.Button("Generate Stickers")
    result = gr.Gallery(label="Generated Stickers")
    generate_button.click(
        fn=_generate_from_prompt,
        inputs=[prompt],
        outputs=[result],
    )

demo.launch()