Spaces:
Sleeping
Sleeping
| import os | |
| import math | |
| import torch | |
| import gradio as gr | |
| from PIL import Image | |
| from transformers import AutoImageProcessor, AutoModelForImageClassification | |
| from diffusers import AutoPipelineForImage2Image | |
# -------------------------
# 1) Age estimation model
# -------------------------
AGE_MODEL_ID = "nateraw/vit-age-classifier"

# Preprocessor loaded from the same checkpoint id as the classifier.
age_processor = AutoImageProcessor.from_pretrained(AGE_MODEL_ID)

# Load the classifier and switch it to evaluation mode in a single expression
# (``eval()`` returns the module it was called on).
age_model = AutoModelForImageClassification.from_pretrained(AGE_MODEL_ID).eval()
| def _label_to_age(label: str) -> float: | |
| label = label.strip().replace("(", "").replace(")", "") | |
| if "-" in label: | |
| a, b = label.split("-") | |
| try: | |
| return (float(a) + float(b)) / 2.0 | |
| except: | |
| pass | |
| try: | |
| return float(label) | |
| except: | |
| return float("nan") | |
def estimate_age(image: Image.Image) -> dict:
    """Estimate an age for ``image`` with the module-level age classifier.

    The classifier's logits are soft-maxed, the (up to) five most probable
    classes are kept, and the expected age is the probability-weighted mean
    of the ages parsed from those class labels by ``_label_to_age``. Labels
    that cannot be parsed are left out of the mean.

    Args:
        image: Input picture as a PIL image.

    Returns:
        A dict with two keys:

        * ``"expected_age"``: the weighted age rounded to one decimal, or
          ``None`` when none of the candidate labels could be parsed.
        * ``"top5"``: list of ``"<label>: <percent>%"`` strings, most
          probable first.
    """
    # Normalise the colour mode the same way stylize_to_comic does, so
    # palette / grayscale / RGBA inputs reach the processor as RGB.
    image = image.convert("RGB")
    inputs = age_processor(images=image, return_tensors="pt")

    # Inference only: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        logits = age_model(**inputs).logits

    probs = torch.softmax(logits, dim=-1)[0]
    id2label = age_model.config.id2label
    topk = torch.topk(probs, k=min(5, probs.shape[0]))

    items, ages = [], []
    for score, idx in zip(topk.values.tolist(), topk.indices.tolist()):
        label = id2label[idx]
        ages.append((_label_to_age(label), score))
        items.append(f"{label}: {score*100:.1f}%")

    # Renormalise over the candidates whose label yielded a usable age.
    ages_valid = [(a, p) for a, p in ages if not math.isnan(a)]
    den = sum(p for _, p in ages_valid)
    if ages_valid and den > 0:
        expected_age = round(sum(a * p for a, p in ages_valid) / den, 1)
    else:
        expected_age = None

    return {
        "expected_age": expected_age,
        "top5": items,
    }
# -------------------------
# 2) Comic-style generation (img2img)
# -------------------------
IMG2IMG_MODEL_ID = "stabilityai/sd-turbo"

# Use CUDA with half precision when available; otherwise fall back to CPU
# with full precision.
if torch.cuda.is_available():
    device = "cuda"
    dtype = torch.float16
else:
    device = "cpu"
    dtype = torch.float32

pipe = AutoPipelineForImage2Image.from_pretrained(IMG2IMG_MODEL_ID, torch_dtype=dtype)
pipe = pipe.to(device)

# Default positive prompt offered in the UI and used by stylize_to_comic.
DEFAULT_PROMPT = (
    "comic style, manga, cel-shaded, bold ink outlines, clean lineart, high contrast, "
    "professional illustration, vibrant"
)
# Negative prompt always passed to the pipeline.
NEG_PROMPT = "realistic, photorealistic, blurry, noisy, artifacts, watermark, text"
def stylize_to_comic(
    image: Image.Image,
    prompt: str = DEFAULT_PROMPT,
    strength: float = 0.6,
    guidance_scale: float = 0.0,
    steps: int = 4,
    seed: int | None = 42
) -> Image.Image:
    """Redraw ``image`` in a comic style with the module-level img2img pipeline.

    Args:
        image: Source picture; it is converted to RGB before use.
        prompt: Positive style prompt.
        strength: How strongly the source image is altered.
        guidance_scale: Guidance scale forwarded to the pipeline.
        steps: Requested number of inference steps. It is raised
            automatically when ``steps * strength`` would round down to
            zero effective denoising steps (see the note in the body).
        seed: Seed for reproducible output; ``None`` or a negative value
            leaves the generator unset (non-deterministic output).

    Returns:
        The first image produced by the pipeline.
    """
    strength = float(strength)
    steps = int(steps)

    # img2img only executes int(steps * strength) denoising steps. Low
    # strength with few steps (e.g. 0.1 x 4) gives zero steps, which makes
    # the pipeline fail, so bump the step count to the minimum that yields
    # at least one effective step. The count is never lowered.
    if strength > 0 and int(steps * strength) < 1:
        steps = max(steps, math.ceil(1.0 / strength))
        if int(steps * strength) < 1:  # guard against float rounding
            steps += 1

    if seed is None or seed < 0:
        generator = None
    else:
        generator = torch.Generator(device=device).manual_seed(int(seed))

    image = image.convert("RGB")
    out = pipe(
        prompt=prompt,
        negative_prompt=NEG_PROMPT,
        image=image,
        strength=strength,
        num_inference_steps=steps,
        guidance_scale=float(guidance_scale),
        generator=generator,
    )
    return out.images[0]
# -------------------------
# 3) Gradio interface (both buttons sit at the very top)
# -------------------------
def ui_estimate_age(image):
    """Gradio callback: return the age estimate for ``image`` as display text.

    Returns a prompt to upload a picture when ``image`` is ``None``, a
    failure message when no expected age could be computed, and otherwise
    the expected age followed by the top-5 class summary.
    """
    if image is None:
        return "请先上传图片"

    result = estimate_age(image)
    expected = result["expected_age"]
    if expected is None:
        return "年龄估计:解析失败(可能检测不到年龄标签)"

    header = f"年龄估计:≈ {expected} 岁\nTop-5: "
    return header + " | ".join(result["top5"])
def ui_stylize(image, prompt, strength, guidance, steps, seed):
    """Gradio callback: forward the UI control values to ``stylize_to_comic``.

    Returns ``None`` when no image was supplied. ``steps`` is cast to int,
    and a missing ``seed`` falls back to 42.
    """
    if image is None:
        return None

    step_count = int(steps)
    seed_value = 42 if seed is None else int(seed)
    return stylize_to_comic(
        image=image,
        prompt=prompt,
        strength=strength,
        guidance_scale=guidance,
        steps=step_count,
        seed=seed_value,
    )
with gr.Blocks(title="Age & Comicify Agent") as demo:
    gr.Markdown(value="# 🧠 Age & Comicify Agent\n上传图片 → ① 估计年龄 ② 生成漫画风格图片")

    # Top row: the two action buttons.
    with gr.Row():
        btn_est = gr.Button(value="🧮 估计年龄", variant="primary")
        btn_gen = gr.Button(value="🎨 生成漫画图片", variant="secondary")

    with gr.Row():
        # First column: input image and generation controls.
        with gr.Column(scale=1):
            in_img = gr.Image(label="上传图片", type="pil")
            prompt = gr.Textbox(label="风格提示词", value=DEFAULT_PROMPT)
            strength = gr.Slider(
                minimum=0.1, maximum=0.9, value=0.6, step=0.05,
                label="风格强度(strength)",
            )
            guidance = gr.Slider(
                minimum=0.0, maximum=3.0, value=0.0, step=0.1,
                label="引导系数(guidance_scale)",
            )
            steps = gr.Slider(
                minimum=2, maximum=12, value=4, step=1,
                label="步数(num_inference_steps)",
            )
            seed = gr.Number(value=42, precision=0, label="随机种子(固定可复现)")

        # Second column: results.
        with gr.Column(scale=1):
            age_txt = gr.Textbox(label="年龄估计结果")
            out_img = gr.Image(label="漫画风格输出")

    # Wiring: each button triggers exactly one function.
    btn_est.click(fn=ui_estimate_age, inputs=[in_img], outputs=[age_txt])

    stylize_inputs = [in_img, prompt, strength, guidance, steps, seed]
    btn_gen.click(fn=ui_stylize, inputs=stylize_inputs, outputs=[out_img])
if __name__ == "__main__":
    # Optional: enable the request queue (concurrency) before launching.
    demo.queue()
    demo.launch()