# Author: HARRY07979 — last commit 8a325ae "Update app.py" (verified)
import os
import time
import torch
import gradio as gr
from diffusers import DiffusionPipeline, LCMScheduler
MODEL_ID = "HyHorX/LiteVision-v1"

# ===== CPU SAFE SETUP =====
# Inference only: autograd is never needed, so disable it globally.
torch.set_grad_enabled(False)
# os.cpu_count() returns None when the count cannot be determined, and
# torch.set_num_threads() requires an int, so fall back to a single thread.
torch.set_num_threads(os.cpu_count() or 1)

print("Loading model on CPU...")
pipe = DiffusionPipeline.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float32,  # MUST be float32 on CPU
)

# LCM scheduler, built from the loaded scheduler's config.
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)

# CPU only
pipe.to("cpu")
# Attention slicing: per diffusers docs, trades some speed for lower peak memory.
pipe.enable_attention_slicing()

# Safety checker ON (default, do not disable).
# An explicit raise is used instead of `assert`, which is stripped under
# `python -O`. getattr() avoids an AttributeError on pipelines that do not
# define this component at all.
if getattr(pipe, "safety_checker", None) is None:
    raise RuntimeError("Safety checker is missing")
print("Model loaded with safety checker enabled.")
def generate(prompt, steps, cfg, seed, progress=gr.Progress()):
    """Generate one 512x512 image from *prompt* with the module-level pipeline.

    Args:
        prompt: Text prompt passed to the pipeline.
        steps: Number of inference steps. Coerced to int, because Gradio
            widgets may deliver numeric values as floats.
        cfg: Guidance scale passed to the pipeline.
        seed: RNG seed. -1 (or None, e.g. a cleared input field) selects a
            random seed. Coerced to int, because
            ``torch.Generator.manual_seed`` rejects floats such as ``42.0``.
        progress: Gradio progress tracker, updated once per pipeline step.

    Returns:
        A ``(image, seed)`` tuple: the generated image (replaced by a black
        frame when the pipeline flags it as NSFW) and the integer seed used.
    """
    start = time.perf_counter()
    width = height = 512

    steps = int(steps)
    if seed is None or seed == -1:
        seed = torch.randint(0, 2**32 - 1, (1,)).item()
    seed = int(seed)
    generator = torch.Generator(device="cpu").manual_seed(seed)

    print("===== GENERATE =====")
    print("prompt :", prompt)
    print("steps :", steps)
    print("cfg :", cfg)
    print("seed :", seed)

    def cb(step, timestep, latents):
        # `step` is 0-based, so report 1-based progress to the UI.
        progress((step + 1) / steps, desc=f"Step {step+1}/{steps}")

    # NOTE(review): recent diffusers releases deprecate `callback` /
    # `callback_steps` in favour of `callback_on_step_end`; consider migrating.
    result = pipe(
        prompt=prompt,
        num_inference_steps=steps,  # LCM sweet spot: 4–6
        guidance_scale=cfg,  # LCM low CFG
        width=width,
        height=height,
        generator=generator,
        callback=cb,
        callback_steps=1,
    )
    image = result.images[0]

    # The per-image flag list may be absent or None (e.g. when no safety
    # check ran), so guard before indexing into it.
    nsfw_flags = getattr(result, "nsfw_content_detected", None)
    if nsfw_flags and nsfw_flags[0]:
        print("NSFW detected -> image replaced by black frame")
        image = torch.zeros((height, width, 3), dtype=torch.uint8).numpy()

    print(f"done in {time.perf_counter() - start:.2f}s")
    return image, seed
# Build the Gradio UI: prompt / steps / CFG / seed inputs, wired to generate().
with gr.Blocks() as demo:
    gr.Markdown("# LiteVision-v1 • CPU LCM Demo ")
    prompt = gr.Textbox(
        label="Prompt",
        lines=3,
        placeholder="Describe your image",
    )
    steps = gr.Slider(1, 8, value=6, step=1, label="Steps (LCM recommended 4–6)")
    cfg = gr.Slider(0.5, 3.0, value=1.5, step=0.1, label="CFG (LCM low CFG)")
    # precision=0 makes Gradio deliver an int; without it a seed such as 42
    # arrives as 42.0, which torch.Generator.manual_seed() rejects.
    seed = gr.Number(value=-1, precision=0, label="Seed (-1 = random)")
    run = gr.Button("Generate")
    out_img = gr.Image(label="Result")
    # Display the used seed as an integer so it can be copied back verbatim.
    out_seed = gr.Number(label="Used seed", precision=0)
    run.click(
        fn=generate,
        inputs=[prompt, steps, cfg, seed],
        outputs=[out_img, out_seed],
        queue=True,
    )

# Queue requests (generation is slow on CPU) and serve on all interfaces.
demo.queue()
demo.launch(server_name="0.0.0.0", server_port=7860)