# SnapScribe / app.py
# Author: Aklavya
# Last change: "Update app.py" (commit 35a6af1, verified)
import os
import uuid
import random
from typing import Tuple, Optional
import gradio as gr
import numpy as np
from PIL import Image
import torch
import spaces
from diffusers import (
StableDiffusionXLPipeline,
StableDiffusionPipeline,
EulerAncestralDiscreteScheduler,
)
# Gated SDXL checkpoint: loading it needs accepted repo access plus HF_TOKEN.
PRIMARY_MODEL_ID = "SG161222/RealVisXL_V5.0_Lightning"  # requires access + token
# Public fallback so the app still runs when the primary model can't be loaded.
FALLBACK_MODEL_ID = "stabilityai/sd-turbo"  # public, fast 1.5-turbo
def apply_style(style_name: str, positive: str, negative: str = "") -> Tuple[str, str]:
    """Wrap the user's prompt in a named style preset.

    Args:
        style_name: Preset key; unknown names fall back to the default preset.
        positive: The user's prompt, substituted into the preset template.
        negative: Optional extra negative prompt, appended to the preset's own.

    Returns:
        A ``(positive_prompt, negative_prompt)`` pair ready for the pipeline.
    """
    presets = {
        "3840 x 2160": (
            "hyper-realistic image of {prompt}. lifelike, authentic, natural colors, "
            "true-to-life details, landscape image, realistic lighting, immersive, highly detailed",
            "unrealistic, low resolution, artificial, over-saturated, distorted, fake",
        ),
        "Style Zero": ("{prompt}", ""),
    }
    default_name = "3840 x 2160"
    template, preset_negative = presets.get(style_name, presets[default_name])
    if negative:
        combined_negative = (preset_negative + " " + negative).strip()
    else:
        combined_negative = preset_negative.strip()
    return template.replace("{prompt}", positive), combined_negative
def _enable_performance_knobs():
if torch.cuda.is_available():
torch.backends.cuda.matmul.allow_tf32 = True
torch.set_grad_enabled(False)
def _try_load_realvis(hf_token: Optional[str]):
    """Load the gated RealVisXL SDXL pipeline (fp16 on GPU, fp32 on CPU).

    Propagates whatever ``from_pretrained`` raises (e.g. auth errors on the
    gated repo) so the caller can decide to fall back.
    """
    on_gpu = torch.cuda.is_available()
    pipe = StableDiffusionXLPipeline.from_pretrained(
        PRIMARY_MODEL_ID,
        torch_dtype=torch.float16 if on_gpu else torch.float32,
        use_safetensors=True,
        add_watermarker=False,
        token=hf_token,  # <- IMPORTANT: gated repo needs the token
    )
    pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
    # xformers attention is optional; keep the default implementation if missing.
    try:
        pipe.enable_xformers_memory_efficient_attention()
    except Exception:
        pass
    target = torch.device("cuda:0" if on_gpu else "cpu")
    return pipe.to(target)
def _try_load_fallback():
    """Load the public fallback pipeline (sd-turbo: SD 1.5 turbo, quick & public)."""
    on_gpu = torch.cuda.is_available()
    pipe = StableDiffusionPipeline.from_pretrained(
        FALLBACK_MODEL_ID,
        torch_dtype=torch.float16 if on_gpu else torch.float32,
        use_safetensors=True,
    )
    # xformers attention is optional; keep the default implementation if missing.
    try:
        pipe.enable_xformers_memory_efficient_attention()
    except Exception:
        pass
    target = torch.device("cuda:0" if on_gpu else "cpu")
    return pipe.to(target)
def load_and_prepare_model():
    """Load RealVisXL when access allows, otherwise fall back to public sd-turbo.

    Reads the HF_TOKEN environment variable (Space secret) for the gated repo.
    Never raises on model-load failure — the fallback keeps the app usable.
    """
    _enable_performance_knobs()
    hf_token = os.getenv("HF_TOKEN", "").strip() or None
    try:
        # Try RealVis first.
        return _try_load_realvis(hf_token)
    except Exception as e:
        lowered = str(e).lower()
        auth_markers = ("401", "403", "unauthorized", "forbidden")
        if any(marker in lowered for marker in auth_markers):
            # Clear hint in server logs; UI will still work via fallback.
            print(
                "\n[WARNING] Could not load RealVisXL (auth). "
                "Make sure you've requested access and set HF_TOKEN in Space secrets.\n"
            )
        else:
            print(f"\n[WARNING] RealVisXL failed to load: {e}\n")
    # Fallback to sd-turbo so the app still runs.
    print("[INFO] Falling back to stabilityai/sd-turbo (public).")
    return _try_load_fallback()
# Load the pipeline once at import time; every request reuses this instance.
model = load_and_prepare_model()
def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    """Return *seed* as an int, or a fresh random int32-range seed when requested."""
    if not randomize_seed:
        return int(seed)
    return random.randint(0, np.iinfo(np.int32).max)
def save_image(img: Image.Image) -> str:
    """Write *img* to a uniquely named PNG in the working directory; return its path."""
    # uuid4 hex keeps concurrent requests from clobbering each other's files.
    filename = f"{uuid.uuid4().hex}.png"
    img.save(filename)
    return filename
@spaces.GPU(duration=60, enable_queue=True)
def generate(
    prompt: str,
    seed: int = 1,
    width: int = 1024,
    height: int = 1024,
    guidance_scale: float = 3.0,
    num_inference_steps: int = 25,
    randomize_seed: bool = False,
):
    """Run the loaded pipeline on *prompt* and return the generated image's file path.

    Raises:
        gr.Error: when the prompt is empty or whitespace-only.
    """
    if not prompt or not prompt.strip():
        raise gr.Error("Please enter a prompt.")
    # Snap dimensions down to multiples of 8 (SD requirement), floor at 256.
    width = int(max(256, (width // 8) * 8))
    height = int(max(256, (height // 8) * 8))
    seed = randomize_seed_fn(seed, randomize_seed)
    rng = torch.Generator(device=model.device).manual_seed(seed)
    # Styled prompts work for both the SDXL pipeline and the sd-turbo fallback.
    positive_prompt, negative_prompt = apply_style("3840 x 2160", prompt)
    # Only the SDXL pipeline gets explicit dimensions; the fallback uses its defaults.
    is_sdxl = "xl" in model.__class__.__name__.lower()
    # Some pipelines (sd-turbo) ignore guidance/steps or behave differently;
    # passing them is still safe.
    result = model(
        prompt=positive_prompt,
        negative_prompt=negative_prompt,
        width=width if is_sdxl else None,
        height=height if is_sdxl else None,
        guidance_scale=float(guidance_scale),
        num_inference_steps=int(num_inference_steps),
        generator=rng,
        output_type="pil",
    )
    # Handle both diffusers return shapes (output object vs raw list).
    images = getattr(result, "images", result)
    return save_image(images[0])
# --- UI definition -----------------------------------------------------------
with gr.Blocks(theme="soft") as demo:
    # Title banner row.
    with gr.Row():
        with gr.Column(scale=12, elem_id="title_block"):
            gr.Markdown(
                "<h1 style='text-align:center; color:white; font-weight:bold; text-decoration:underline;'>SNAPSCRIBE</h1>"
            )
            gr.Markdown(
                "<h2 style='text-align:center; color:white; font-weight:bold; text-decoration:underline;'>Developed with ❤ by Aklavya</h2>"
            )
    with gr.Row():
        # Left column: prompt plus generation controls.
        with gr.Column(scale=3):
            prompt = gr.Textbox(
                label="Input Prompt",
                placeholder="Describe the image you want to create",
                lines=2,
            )
            seed = gr.Number(value=1, label="Seed", precision=0)
            randomize_seed = gr.Checkbox(value=True, label="Randomize Seed")
            # step=8 matches the multiple-of-8 snapping inside generate().
            width = gr.Slider(512, 1536, value=1024, step=8, label="Width")
            height = gr.Slider(512, 1536, value=1024, step=8, label="Height")
            guidance_scale = gr.Slider(1.0, 10.0, value=3.0, step=0.5, label="Guidance Scale")
            steps = gr.Slider(10, 35, value=25, step=1, label="Inference Steps")
            run_button = gr.Button("Generate Image", variant="primary")
            # Read-only sample prompts shown below the controls.
            example_prompts_text = (
                "Dew-covered spider web in morning sunlight, with blurred greenery\n"
                "--------------------------------------------\n"
                "Glass of cold water with ice cubes and condensation on a wooden table\n"
                "--------------------------------------------\n"
                "Coffee cup with latte art, steam rising, and morning sunlight\n"
                "--------------------------------------------\n"
                "Autumn forest with golden leaves, sunlight through trees, and a breeze"
            )
            gr.Textbox(
                value=example_prompts_text,
                lines=8,
                label="Sample Inputs",
                interactive=False,
            )
        # Right column: output image (filepath type matches generate()'s return).
        with gr.Column(scale=7):
            result_image = gr.Image(
                label="Generated Image",
                type="filepath",
                elem_id="output_image",
            )
    # Wire the button to generate(); inputs order must match its signature.
    run_button.click(
        fn=generate,
        inputs=[prompt, seed, width, height, guidance_scale, steps, randomize_seed],
        outputs=[result_image],
        api_name="generate",
    )

if __name__ == "__main__":
    demo.launch()