Spaces:
Sleeping
Sleeping
| import os | |
| import tempfile | |
| import logging | |
| from typing import Tuple, Dict | |
| import gradio as gr | |
| from huggingface_hub import snapshot_download | |
| from PIL import Image | |
| from runners.simple_runner import SimpleRunner | |
| # ----------------------------------------------------------------------------- | |
| # Logging (use lazy % formatting as requested) | |
| # ----------------------------------------------------------------------------- | |
# Application-wide logger; every log call in this file goes through it with
# lazy %-style arguments. The root logger is configured to emit INFO and above.
logger = logging.getLogger("sfe-app")
logging.basicConfig(level=logging.INFO)
| # ----------------------------------------------------------------------------- | |
| # Model bootstrap (load once and reuse) | |
| # ----------------------------------------------------------------------------- | |
# Process-wide cached SimpleRunner; remains None until get_runner() builds it.
# The annotation is quoted so it is not evaluated at import time: an unquoted
# `SimpleRunner | None` at module level raises TypeError on Python < 3.10.
RUNNER: "SimpleRunner | None" = None
def ensure_weights():
    """Download the pretrained checkpoints unless both are already on disk.

    Checks for the editor and StyleGAN2 checkpoint files under
    ``pretrained_models/`` (relative to the current working directory). If
    either is absent, a snapshot of the Hugging Face model repo is downloaded
    into that directory. Returns ``None`` in both cases.
    """
    required_checkpoints = (
        "pretrained_models/sfe_editor_light.pt",
        "pretrained_models/stylegan2-ffhq-config-f.pt",
    )
    for checkpoint in required_checkpoints:
        if not os.path.exists(checkpoint):
            # At least one file is missing: fall through to the download.
            break
    else:
        # Loop finished without a break, so every checkpoint is present.
        return

    hub_repo = "LogicGoInfotechSpaces/Smile_Changer_pre_model"
    logger.info("Missing weights; downloading snapshot from %s", hub_repo)
    snapshot_download(
        repo_id=hub_repo,
        local_dir="pretrained_models",
        local_dir_use_symlinks=False,
        allow_patterns=["**/*"],
    )
def get_runner() -> SimpleRunner:
    """Return the shared SimpleRunner, constructing it on the first call.

    The instance is stored in the module-level ``RUNNER`` so later calls
    reuse it instead of building a new one.
    """
    global RUNNER
    if RUNNER is not None:
        return RUNNER

    # First use: make sure the checkpoints are on disk, then build and cache.
    ensure_weights()
    editor_checkpoint = "pretrained_models/sfe_editor_light.pt"
    logger.info("Initializing SimpleRunner with %s", editor_checkpoint)
    RUNNER = SimpleRunner(editor_ckpt_pth=editor_checkpoint)
    return RUNNER
| # ----------------------------------------------------------------------------- | |
| # Attribute catalog and recommended ranges | |
| # ----------------------------------------------------------------------------- | |
| # Each entry maps a friendly attribute name to the internal editing name and a | |
| # recommended power range for the slider. | |
ATTRIBUTE_MAP: Dict[str, Tuple[str, Tuple[float, float]]] = {
    # friendly label -> (internal editing name, (min power, max power))
    #
    # --- Face semantics ---
    "Smile": ("fs_smiling", (-10.0, 10.0)),
    "Age": ("age", (-10.0, 10.0)),  # from interfacegan_directions
    # from stylespace_directions; positive power adds femininity
    "Female features": ("gender", (-10.0, 7.0)),
    #
    # --- Facial hair ---
    # Positive power on trimmed_beard removes the beard; go negative to add one.
    "Beard": ("trimmed_beard", (-30.0, 30.0)),
    # Positive power on goatee removes it; negative power tends to add one.
    "Mustache/Goatee": ("goatee", (-7.0, 7.0)),
    #
    # --- Accessories & cosmetics ---
    "Glasses": ("fs_glasses", (-20.0, 30.0)),
    "Makeup": ("fs_makeup", (-10.0, 15.0)),
    #
    # --- Hair style (pretrained mappers) ---
    "Curly hair": ("curly_hair", (0.0, 0.12)),  # from styleclip_directions
    "Afro": ("afro", (0.0, 0.14)),
    #
    # --- Hair color via the global text mapper ---
    # These are defaults; run_edit lets a custom prompt replace the editing
    # name for any label ending in "(text)".
    "Orange hair (text)": ("styleclip_global_a face_a face with orange hair_0.18", (0.0, 0.2)),
    "Blonde hair (text)": ("styleclip_global_a face_a face with blonde hair_0.18", (0.0, 0.2)),
}
def recommended_range(attr_name: str) -> Tuple[float, float]:
    """Return the ``(low, high)`` power bounds stored for *attr_name*.

    Raises:
        KeyError: if *attr_name* is not a key of ``ATTRIBUTE_MAP``.
    """
    # Each map value is (editing name, bounds); only the bounds are needed.
    _, bounds = ATTRIBUTE_MAP[attr_name]
    return bounds
def run_edit(
    image: Image.Image,
    attribute: str,
    strength: float,
    align_face: bool,
    use_bg_mask: bool,
    custom_text_edit: str,
) -> Image.Image:
    """Run a single attribute edit and return the edited image.

    Args:
        image: Input face as a PIL image (``None`` when nothing was uploaded).
        attribute: A key of ``ATTRIBUTE_MAP`` selecting the edit.
        strength: Requested edit power; clamped to the attribute's range.
        align_face: Forwarded to the runner as ``align``.
        use_bg_mask: Forwarded to the runner as ``use_mask``.
        custom_text_edit: Optional editing name that replaces the default one
            when *attribute* ends with ``"(text)"``.

    Returns:
        The runner's output file re-opened as an RGB PIL image.

    Raises:
        gr.Error: if no image was supplied or *attribute* is not in
            ``ATTRIBUTE_MAP``.
    """
    # Validate first so bad requests fail with a readable message instead of
    # an AttributeError/KeyError, and before the model is loaded.
    if image is None:
        raise gr.Error("Please upload an input image first.")
    if attribute not in ATTRIBUTE_MAP:
        raise gr.Error(f"Unknown attribute: {attribute!r}")

    # Determine editing name and clip strength into the suggested range
    edit_name, (lo, hi) = ATTRIBUTE_MAP[attribute]
    custom_prompt = (custom_text_edit or "").strip()
    if custom_prompt and attribute.endswith("(text)"):
        # Allow overriding the default text prompt
        edit_name = custom_prompt

    clipped_strength = max(lo, min(hi, strength))
    if clipped_strength != strength:
        logger.info("Clipped strength from %s to %s for %s", strength, clipped_strength, attribute)

    runner = get_runner()

    # The runner works on file paths, so round-trip through a temp directory.
    with tempfile.TemporaryDirectory() as tmpdir:
        inp_path = os.path.join(tmpdir, "input.jpg")
        out_path = os.path.join(tmpdir, "edited.jpg")
        image.convert("RGB").save(inp_path)
        logger.info("Editing %s with power %s", edit_name, clipped_strength)
        runner.edit(
            orig_img_pth=inp_path,
            editing_name=edit_name,
            edited_power=clipped_strength,
            save_pth=out_path,
            align=align_face,
            use_mask=use_bg_mask,
        )
        # convert() yields an independent in-memory copy; the context manager
        # then closes the file handle before the temp directory is removed.
        with Image.open(out_path) as edited:
            return edited.convert("RGB")
def build_ui() -> gr.Blocks:
    """Build the Gradio Blocks interface and wire up its callbacks.

    Layout: an input column (image, attribute dropdown, strength slider, two
    checkboxes, custom-text box, run button) beside an output image column.
    Changing the attribute rescales the slider to that attribute's recommended
    range; clicking the button calls ``run_edit``.

    Returns:
        The constructed (not yet launched) ``gr.Blocks`` app.
    """
    default_attr = "Smile"
    # Start the slider on the default attribute's range so the UI bounds and
    # the clamping in run_edit agree before the dropdown is ever changed.
    default_lo, default_hi = recommended_range(default_attr)

    with gr.Blocks(css="footer {visibility: hidden}") as demo:
        gr.Markdown("""
        **StyleFeatureEditor β Facial Attribute Editing**
        Upload a face and apply edits like smile, age, beard, hair style/color, glasses, and makeup.
        Tip: For Beard/Goatee, negative strength tends to add facial hair.
        """)
        with gr.Row():
            with gr.Column():
                inp = gr.Image(type="pil", label="Input face", sources=["upload", "clipboard"])
                attr = gr.Dropdown(
                    choices=list(ATTRIBUTE_MAP.keys()),
                    value=default_attr,
                    label="Attribute",
                )
                strength = gr.Slider(default_lo, default_hi, value=5, step=0.01, label="Strength (p)")
                align_face = gr.Checkbox(value=False, label="Align face before editing")
                use_bg_mask = gr.Checkbox(value=False, label="Use background mask (reduce artifacts)")
                custom_text = gr.Textbox(
                    value="",
                    label="Custom text edit (StyleCLIP Global Mapper)",
                    placeholder="styleclip_global_a face_a face with black hair_0.18",
                )
                run_btn = gr.Button("Run edit")
            with gr.Column():
                out = gr.Image(type="pil", label="Edited output")

        # Update slider range based on attribute selection
        def _on_attr_change(name: str, current: float):
            """Rescale the slider to *name*'s range, clamping the live value."""
            if name not in ATTRIBUTE_MAP:
                # Dropdown cleared or unexpected value: leave the slider as is.
                return gr.update()
            lo, hi = recommended_range(name)
            # `current` is the slider's live value, passed in as an event
            # input. Reading `strength.value` would only give the value the
            # component was constructed with.
            new_val = max(lo, min(hi, 0 if current is None else current))
            # gr.update is used because the per-component `.update()` class
            # methods are not available in Gradio 4.
            return gr.update(minimum=lo, maximum=hi, value=new_val)

        attr.change(_on_attr_change, inputs=[attr, strength], outputs=strength)
        run_btn.click(
            fn=run_edit,
            inputs=[inp, attr, strength, align_face, use_bg_mask, custom_text],
            outputs=out,
        )
    return demo
if __name__ == "__main__":
    # No host/port arguments: on Spaces the platform manages them, and a local
    # run uses Gradio's defaults.
    build_ui().launch()