# img2img / app.py
# VirginiaZane's picture
# Update app.py
# b7380ce verified
import os
import math
import torch
import gradio as gr
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification
from diffusers import AutoPipelineForImage2Image
# -------------------------
# 1) Age estimation model
# -------------------------
AGE_MODEL_ID = "nateraw/vit-age-classifier"

# Preprocessor and classifier are loaded once at import time and shared by all
# calls to estimate_age(); .eval() switches the classifier to evaluation mode
# and returns the same module, so it can be chained onto the load.
age_processor = AutoImageProcessor.from_pretrained(AGE_MODEL_ID)
age_model = AutoModelForImageClassification.from_pretrained(AGE_MODEL_ID).eval()
def _label_to_age(label: str) -> float:
label = label.strip().replace("(", "").replace(")", "")
if "-" in label:
a, b = label.split("-")
try:
return (float(a) + float(b)) / 2.0
except:
pass
try:
return float(label)
except:
return float("nan")
@torch.inference_mode()
def estimate_age(image: Image.Image) -> dict:
    """Estimate the apparent age in *image* with the ViT age classifier.

    Args:
        image: Input PIL image. It is converted to RGB first, as in
            stylize_to_comic, so RGBA/grayscale/palette images are accepted.

    Returns:
        A dict with:
          * "expected_age": the probability-weighted mean of the numeric ages
            of all classes whose label could be parsed (see _label_to_age),
            rounded to one decimal, or None if no label is parseable;
          * "top5": list of "label: xx.x%" strings for the (up to) five most
            probable classes.
    """
    inputs = age_processor(images=image.convert("RGB"), return_tensors="pt")
    logits = age_model(**inputs).logits
    probs = torch.softmax(logits, dim=-1)[0]  # first (only) image in the batch
    id2label = age_model.config.id2label

    # Human-readable summary of the most likely classes.
    top = torch.topk(probs, k=min(5, probs.shape[0]))
    items = [
        f"{id2label[idx]}: {score * 100:.1f}%"
        for score, idx in zip(top.values.tolist(), top.indices.tolist())
    ]

    # Expectation over the full predicted distribution (not just the top-5,
    # which would bias the estimate toward the mode). Unparseable labels are
    # skipped and the remaining probability mass is renormalized.
    num = 0.0
    den = 0.0
    for idx, p in enumerate(probs.tolist()):
        age = _label_to_age(id2label[idx])
        if math.isnan(age):
            continue
        num += age * p
        den += p

    expected_age = round(num / den, 1) if den > 0 else None
    return {
        "expected_age": expected_age,
        "top5": items,
    }
# -------------------------
# 2) Comic-style generation (img2img)
# -------------------------
IMG2IMG_MODEL_ID = "stabilityai/sd-turbo"

# Half precision only when a CUDA device is present; full precision on CPU.
if torch.cuda.is_available():
    device = "cuda"
    dtype = torch.float16
else:
    device = "cpu"
    dtype = torch.float32

# Load the image-to-image pipeline once at import time and move it to `device`.
pipe = AutoPipelineForImage2Image.from_pretrained(IMG2IMG_MODEL_ID, torch_dtype=dtype)
pipe = pipe.to(device)
# Default positive prompt for comic stylization (also the UI textbox default).
DEFAULT_PROMPT = ", ".join(
    [
        "comic style",
        "manga",
        "cel-shaded",
        "bold ink outlines",
        "clean lineart",
        "high contrast",
        "professional illustration",
        "vibrant",
    ]
)
# Passed as negative_prompt to the img2img pipeline by stylize_to_comic.
NEG_PROMPT = ", ".join(
    ["realistic", "photorealistic", "blurry", "noisy", "artifacts", "watermark", "text"]
)
@torch.inference_mode()
def stylize_to_comic(
    image: Image.Image,
    prompt: str = DEFAULT_PROMPT,
    strength: float = 0.6,
    guidance_scale: float = 0.0,
    steps: int = 4,
    seed: int | None = 42,
) -> Image.Image:
    """Restyle *image* as a comic/manga illustration with the img2img pipeline.

    Args:
        image: Source picture; converted to RGB before use.
        prompt: Positive text prompt describing the target style.
        strength: How far to move away from the source image; must be in (0, 1].
        guidance_scale: Classifier-free guidance weight. In diffusers' Stable
            Diffusion pipelines guidance (and therefore NEG_PROMPT) is only
            active when guidance_scale > 1, so at the default 0.0 the negative
            prompt has no effect.
        steps: Requested number of inference steps. img2img only runs
            int(steps * strength) denoising steps, so the value is raised
            automatically when needed to guarantee at least one step.
        seed: Seed for reproducible output; None or a negative value gives
            non-deterministic sampling.

    Returns:
        The first generated PIL image.

    Raises:
        ValueError: If strength is not in (0, 1].
    """
    strength = float(strength)
    if not 0.0 < strength <= 1.0:  # also rejects NaN
        raise ValueError(f"strength must be in (0, 1], got {strength}")

    # With e.g. strength=0.1 and steps=4, int(4 * 0.1) == 0 denoising steps and
    # the pipeline fails. Raise steps to the smallest value giving >= 1 step;
    # the loop only corrects floating-point rounding in the ceil estimate.
    steps = max(int(steps), math.ceil(1.0 / strength))
    while int(steps * strength) < 1:
        steps += 1

    generator = (
        None
        if seed is None or seed < 0
        else torch.Generator(device=device).manual_seed(int(seed))
    )
    out = pipe(
        prompt=prompt,
        negative_prompt=NEG_PROMPT,
        image=image.convert("RGB"),
        strength=strength,
        num_inference_steps=steps,
        guidance_scale=float(guidance_scale),
        generator=generator,
    )
    return out.images[0]
# -------------------------
# 3) Gradio UI (both buttons are at the top)
# -------------------------
def ui_estimate_age(image):
    """Gradio handler for the age button: render estimate_age() as display text."""
    if image is None:
        return "请先上传图片"

    result = estimate_age(image)
    age = result["expected_age"]
    if age is None:
        return "年龄估计:解析失败(可能检测不到年龄标签)"

    top5_line = " | ".join(result["top5"])
    return f"年龄估计:≈ {age} 岁\nTop-5: {top5_line}"
def ui_stylize(image, prompt, strength, guidance, steps, seed):
    """Gradio handler for the comic button; returns None when no image is given."""
    if image is None:
        return None

    n_steps = int(steps)
    # A missing seed (None) falls back to the fixed default of 42.
    resolved_seed = 42 if seed is None else int(seed)
    return stylize_to_comic(
        image=image,
        prompt=prompt,
        strength=strength,
        guidance_scale=guidance,
        steps=n_steps,
        seed=resolved_seed,
    )
# Build the Gradio UI: a row of action buttons on top, then a two-column row
# with the input image and generation controls on the left and results on the
# right. Component creation order inside each context determines the layout.
with gr.Blocks(title="Age & Comicify Agent") as demo:
    gr.Markdown("# 🧠 Age & Comicify Agent\n上传图片 → ① 估计年龄 ② 生成漫画风格图片")
    # The two action buttons sit at the top of the page.
    with gr.Row():
        btn_est = gr.Button("🧮 估计年龄", variant="primary")
        btn_gen = gr.Button("🎨 生成漫画图片", variant="secondary")
    with gr.Row():
        # Left column: input image plus the img2img controls. The slider and
        # number defaults mirror stylize_to_comic's keyword defaults
        # (strength 0.6, guidance_scale 0.0, steps 4, seed 42).
        with gr.Column(scale=1):
            # type="pil" hands a PIL image (or None) to the handlers.
            in_img = gr.Image(label="上传图片", type="pil")
            prompt = gr.Textbox(label="风格提示词", value=DEFAULT_PROMPT)
            strength = gr.Slider(0.1, 0.9, value=0.6, step=0.05, label="风格强度(strength)")
            guidance = gr.Slider(0.0, 3.0, value=0.0, step=0.1, label="引导系数(guidance_scale)")
            steps = gr.Slider(2, 12, value=4, step=1, label="步数(num_inference_steps)")
            # precision=0 asks Gradio for an integer-valued seed.
            seed = gr.Number(value=42, precision=0, label="随机种子(固定可复现)")
        # Right column: outputs.
        with gr.Column(scale=1):
            age_txt = gr.Textbox(label="年龄估计结果")
            out_img = gr.Image(label="漫画风格输出")
    # Event wiring: each button triggers exactly one function.
    btn_est.click(fn=ui_estimate_age, inputs=[in_img], outputs=[age_txt])
    btn_gen.click(fn=ui_stylize, inputs=[in_img, prompt, strength, guidance, steps, seed], outputs=[out_img])
if __name__ == "__main__":
    # Optional: enable Gradio's request queue (concurrency handling), then start
    # the server. queue() returns the same Blocks app, so launch it afterwards.
    app = demo.queue()
    app.launch()