Spaces:
Sleeping
Sleeping
| import os | |
| import tempfile | |
| import logging | |
| from typing import Tuple, Dict | |
| import gradio as gr | |
| from fastapi import FastAPI, UploadFile, File, Form, Header, HTTPException, Depends | |
| from fastapi.responses import StreamingResponse, JSONResponse | |
| from fastapi.testclient import TestClient | |
| import io | |
| from spaces import GPU | |
| from huggingface_hub import snapshot_download | |
| from PIL import Image | |
| import json | |
| try: | |
| import firebase_admin | |
| from firebase_admin import credentials, auth as fb_auth | |
| except Exception: # firebase optional; enabled when installed | |
| firebase_admin = None | |
| credentials = None | |
| fb_auth = None | |
| FIREBASE_APP = None | |
def _init_firebase_if_possible() -> None:
    """Initialise the Firebase Admin app once, if the SDK and credentials exist.

    Credentials come from the ``FIREBASE_CREDENTIALS_JSON`` environment
    variable, which may hold either a path to a service-account file or the
    raw JSON document. When the variable is empty, the file
    ``firebase_service_account.json`` in the working directory is used if it
    exists. Any failure is logged as a warning and leaves ``FIREBASE_APP`` as
    ``None``; this function never raises for initialisation errors.
    """
    global FIREBASE_APP

    # Already initialised by an earlier call.
    if FIREBASE_APP is not None:
        return

    # The optional import at the top of the file failed.
    if firebase_admin is None:
        logger.info("firebase-admin not installed; skipping Firebase init")
        return

    env_value = os.getenv("FIREBASE_CREDENTIALS_JSON", "").strip()
    default_path = "firebase_service_account.json"

    try:
        if env_value:
            # The env var is treated as a path when it names an existing file,
            # otherwise it is parsed as inline JSON.
            material = env_value if os.path.exists(env_value) else json.loads(env_value)
            certificate = credentials.Certificate(material)
        elif os.path.exists(default_path):
            certificate = credentials.Certificate(default_path)
        else:
            certificate = None

        if certificate is None:
            logger.info("No Firebase credentials provided; skipping Firebase init")
        else:
            FIREBASE_APP = firebase_admin.initialize_app(certificate)
            logger.info("Firebase initialized successfully")
    except Exception as exc:
        logger.warning("Firebase init failed: %s", exc)
        FIREBASE_APP = None
| # Configure environment BEFORE importing any torch-dependent modules | |
| os.environ.setdefault("CUDA_VISIBLE_DEVICES", "") | |
| os.environ.setdefault("TORCH_CUDA_ARCH_LIST", "8.0") | |
| from runners.simple_runner import SimpleRunner | |
| # ----------------------------------------------------------------------------- | |
| # Logging (use lazy % formatting as requested) | |
| # ----------------------------------------------------------------------------- | |
| logging.basicConfig(level=logging.INFO) | |
| logger = logging.getLogger("sfe-app") | |
| # ----------------------------------------------------------------------------- | |
| # Model bootstrap (load once and reuse) | |
| # ----------------------------------------------------------------------------- | |
| RUNNER: SimpleRunner | None = None | |
def ensure_weights():
    """Download the pretrained weights when any required file is missing.

    Checks every path in the required-weights list. If all of them exist
    nothing is downloaded. Otherwise a snapshot of the Hugging Face model repo
    is fetched into the current directory, using ``HF_TOKEN`` or
    ``HUGGINGFACE_TOKEN`` from the environment when set.

    Download errors are logged and swallowed (the function returns ``None`` in
    every case); files that are still missing afterwards are reported with a
    warning so the subsequent model load fails with context in the log.
    """
    need = [
        "pretrained_models/sfe_editor_light.pt",
        "pretrained_models/stylegan2-ffhq-config-f.pt",
        "pretrained_models/e4e_ffhq_encode.pt",
        "pretrained_models/stylegan2-ffhq-config-f.pkl",
        "pretrained_models/shape_predictor_68_face_landmarks.dat",
        "pretrained_models/fs3.npy",
        "pretrained_models/delta_mapper.pt",
        "pretrained_models/iresnet50-7f187506.pth",
        "pretrained_models/model_ir_se50.pth",
        "pretrained_models/CurricularFace_Backbone.pth",
        "pretrained_models/face_parsing.farl.lapa.main_ema_136500_jit191.pt",
        "pretrained_models/mobilenet0.25_Final.pth",
        "pretrained_models/moco_v2_800ep_pretrain.pt",
        "pretrained_models/79999_iter.pth",
    ]

    # Download unless EVERY file is present: a partially populated directory
    # must not be mistaken for a complete one.
    missing = [path for path in need if not os.path.exists(path)]
    if not missing:
        logger.info("All weights already exist, skipping download")
        return

    repo_id = "LogicGoInfotechSpaces/Smile_Changer_pre_model"
    logger.info(
        "Missing %d weight file(s); downloading snapshot from %s",
        len(missing),
        repo_id,
    )
    try:
        snapshot_download(
            repo_id=repo_id,
            local_dir=".",
            allow_patterns=["**/*"],
            token=os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_TOKEN"),
        )
    except Exception as e:
        logger.error("Download failed: %s", e)
        return
    logger.info("Download completed successfully")

    # Debug aid: list everything that ended up under pretrained_models/.
    if os.path.exists("pretrained_models"):
        logger.info("Files in pretrained_models directory:")
        try:
            for root, _dirs, files in os.walk("pretrained_models"):
                for file in files:
                    full_path = os.path.join(root, file)
                    logger.info(" %s (size: %d bytes)", full_path, os.path.getsize(full_path))
        except Exception as e:
            logger.error("Error listing files: %s", e)
    else:
        logger.error("pretrained_models directory does not exist!")

    # Report the final state of each required file.
    for file_path in need:
        if os.path.exists(file_path):
            logger.info("File %s found successfully", file_path)
        else:
            logger.warning("File %s still not found after download", file_path)
def get_runner() -> SimpleRunner:
    """Return the shared ``SimpleRunner``, creating it on first use.

    The first call makes sure the weights are available via
    ``ensure_weights()`` and constructs the runner from the light editor
    checkpoint; later calls return the instance cached in ``RUNNER``.
    """
    global RUNNER

    if RUNNER is not None:
        return RUNNER

    editor_checkpoint = "pretrained_models/sfe_editor_light.pt"
    logger.info("Getting runner - calling ensure_weights()")
    ensure_weights()
    logger.info("Initializing SimpleRunner with %s", editor_checkpoint)
    RUNNER = SimpleRunner(editor_ckpt_pth=editor_checkpoint)
    logger.info("SimpleRunner initialized successfully")
    return RUNNER
# -----------------------------------------------------------------------------
# Attribute catalog and recommended ranges
# -----------------------------------------------------------------------------
# Key:   friendly attribute name shown to users.
# Value: (internal editing name, (min power, max power)) where the power pair
#        is the recommended slider range for that edit.
ATTRIBUTE_MAP: Dict[str, Tuple[str, Tuple[float, float]]] = {
    # --- Face semantics ---
    "Smile": ("fs_smiling", (-10.0, 10.0)),
    "Age": ("age", (-10.0, 10.0)),  # interfacegan_directions
    # stylespace_directions (positive adds femininity)
    "Female features": ("gender", (-10.0, 7.0)),
    # --- Facial hair ---
    # trimmed_beard removes beard for positive power; negative values ADD beard.
    "Beard": ("trimmed_beard", (-30.0, 30.0)),
    # goatee removes goatee for positive power; negative values tend to ADD it.
    "Mustache/Goatee": ("goatee", (-7.0, 7.0)),
    # --- Accessories & cosmetics ---
    "Glasses": ("fs_glasses", (-20.0, 30.0)),
    "Makeup": ("fs_makeup", (-10.0, 15.0)),
    # --- Hair style (pretrained mappers) ---
    "Curly hair": ("curly_hair", (0.0, 0.12)),  # styleclip_directions
    "Afro": ("afro", (0.0, 0.14)),
    # --- Hair color via global text mapper ---
    # Custom prompts of the same form can be typed into the UI text box.
    "Orange hair (text)": ("styleclip_global_a face_a face with orange hair_0.18", (0.0, 0.2)),
    "Blonde hair (text)": ("styleclip_global_a face_a face with blonde hair_0.18", (0.0, 0.2)),
}


def recommended_range(attr_name: str) -> Tuple[float, float]:
    """Return the ``(min, max)`` power range recommended for *attr_name*.

    Raises:
        KeyError: if *attr_name* is not a key of ``ATTRIBUTE_MAP``.
    """
    return ATTRIBUTE_MAP[attr_name][1]
def run_edit(
    image: Image.Image,
    attribute: str,
    strength: float,
    align_face: bool,
    use_bg_mask: bool,
    custom_text_edit: str,
) -> Image.Image:
    """Run a single attribute edit and return the edited image.

    Args:
        image: Input picture; it is converted to RGB before being handed to
            the runner.
        attribute: Friendly attribute name; must be a key of ``ATTRIBUTE_MAP``.
        strength: Requested edit power; clamped into the attribute's
            recommended range.
        align_face: Forwarded to the runner as ``align``.
        use_bg_mask: Forwarded to the runner as ``use_mask``.
        custom_text_edit: Optional replacement for the internal editing name.
            Only honoured when *attribute* ends with ``"(text)"`` and the
            value is non-blank.

    Returns:
        The edited image as an RGB ``PIL.Image.Image`` held fully in memory.

    Raises:
        KeyError: if *attribute* is not in ``ATTRIBUTE_MAP``.
    """
    # Resolve the attribute first so an unknown name fails before the
    # (expensive) model bootstrap in get_runner().
    edit_name, (lo, hi) = ATTRIBUTE_MAP[attribute]
    runner = get_runner()

    # Allow overriding the default text prompt for the "(text)" attributes.
    override = (custom_text_edit or "").strip()
    if override and attribute.endswith("(text)"):
        edit_name = override

    clipped_strength = max(lo, min(hi, strength))
    if clipped_strength != strength:
        logger.info("Clipped strength from %s to %s for %s", strength, clipped_strength, attribute)

    # The runner works on file paths, so round-trip through a temp directory.
    with tempfile.TemporaryDirectory() as tmpdir:
        inp_path = os.path.join(tmpdir, "input.jpg")
        out_path = os.path.join(tmpdir, "edited.jpg")
        image.convert("RGB").save(inp_path)
        logger.info("Editing %s with power %s", edit_name, clipped_strength)
        runner.edit(
            orig_img_pth=inp_path,
            editing_name=edit_name,
            edited_power=clipped_strength,
            save_pth=out_path,
            align=align_face,
            use_mask=use_bg_mask,
        )
        # convert() returns an in-memory copy; closing the source handle here
        # releases the file before the temp directory is deleted.
        with Image.open(out_path) as edited:
            return edited.convert("RGB")
def build_ui() -> gr.Blocks:
    """Build the Gradio interface for interactive attribute editing.

    The layout has an input column (image, attribute dropdown, strength
    slider, two checkboxes, custom text prompt, run button) and an output
    column with the edited image. Changing the attribute re-bounds the
    strength slider to that attribute's recommended range; clicking the button
    calls ``run_edit`` with the current control values.

    Returns:
        The assembled ``gr.Blocks`` app (not launched).
    """
    with gr.Blocks(css="footer {visibility: hidden}") as demo:
        gr.Markdown("""
**StyleFeatureEditor – Facial Attribute Editing**
Upload a face and apply edits like smile, age, beard, hair style/color, glasses, and makeup.
**Tips:**
- **Beard/Goatee**: Use **negative values** to ADD facial hair, positive values to remove
- **Smile**: Positive values add smile, negative values remove smile
- **Age**: Positive values make older, negative values make younger
- **Glasses**: Positive values add glasses, negative values remove glasses
""")
        with gr.Row():
            with gr.Column():
                inp = gr.Image(type="pil", label="Input face", sources=["upload", "clipboard"])
                attr = gr.Dropdown(
                    choices=list(ATTRIBUTE_MAP.keys()),
                    value="Smile",
                    label="Attribute",
                )
                strength = gr.Slider(-15, 15, value=5, step=0.01, label="Strength (p)")
                align_face = gr.Checkbox(value=True, label="Align face before editing")
                use_bg_mask = gr.Checkbox(value=False, label="Use background mask (reduce artifacts)")
                custom_text = gr.Textbox(
                    value="",
                    label="Custom text edit (StyleCLIP Global Mapper)",
                    placeholder="styleclip_global_a face_a face with black hair_0.18",
                )
                run_btn = gr.Button("Run edit")
            with gr.Column():
                out = gr.Image(type="pil", label="Edited output")

        def _on_attr_change(name: str, current: float | None):
            """Re-bound the slider for *name*, keeping the user's value if it fits."""
            lo, hi = recommended_range(name)
            # `current` is the live slider value delivered by the event; the
            # component object's own `.value` only holds the construction-time
            # default and would discard whatever the user had selected.
            base = 0 if current is None else current
            return gr.Slider(minimum=lo, maximum=hi, value=max(lo, min(hi, base)))

        # The slider is both an input (current value) and the output (new bounds).
        attr.change(_on_attr_change, inputs=[attr, strength], outputs=strength)
        run_btn.click(
            fn=run_edit,
            inputs=[inp, attr, strength, align_face, use_bg_mask, custom_text],
            outputs=out,
        )
    return demo
| # Build Gradio UI | |
| demo = build_ui() | |
| # ----------------------------- | |
| # REST API (FastAPI) endpoints | |
| # ----------------------------- | |
| api = FastAPI(title="Smile Changer API") | |
def _require_auth(authorization: str | None = Header(default=None)):
    """Authenticate a request from its ``Authorization: Bearer <token>`` header.

    Two token kinds are accepted:

    * the static token from the ``API_AUTH_TOKEN`` environment variable, and
    * a Firebase ID token, when firebase-admin is installed and initialised.

    Returns:
        ``{"auth": "static"}`` for the static token, or
        ``{"auth": "firebase", "claims": ..., "uid": ...}`` for a verified
        Firebase token.

    Raises:
        HTTPException: status 401 when the header is missing or malformed, the
            token is empty, or the token matches neither scheme.
    """
    # Function-scope import, in line with the other local imports in this file.
    import hmac

    if not authorization or not authorization.startswith("Bearer "):
        raise HTTPException(status_code=401, detail="Missing or invalid Authorization header")

    token = authorization.split(" ", 1)[1].strip()
    if not token:
        # "Bearer " with nothing after it must never authenticate, even when
        # API_AUTH_TOKEN is set to an empty string.
        raise HTTPException(status_code=401, detail="Missing or invalid Authorization header")

    # NOTE(review): SECURITY - when API_AUTH_TOKEN is unset this falls back to
    # a token hard-coded in the source. Kept for backward compatibility with
    # existing clients; deployments should always set API_AUTH_TOKEN.
    expected = os.getenv("API_AUTH_TOKEN", "logicgo_123")

    # Static token: constant-time comparison, and only when one is configured.
    if expected and hmac.compare_digest(token.encode("utf-8"), expected.encode("utf-8")):
        return {"auth": "static"}

    # Firebase ID token verification (if configured).
    _init_firebase_if_possible()
    if firebase_admin is not None and fb_auth is not None and FIREBASE_APP is not None:
        try:
            claims = fb_auth.verify_id_token(token)
            return {"auth": "firebase", "claims": claims, "uid": claims.get("uid")}
        except Exception as e:
            logger.warning("Firebase token verification failed: %s", e)

    # Neither scheme accepted the token.
    raise HTTPException(status_code=401, detail="Invalid token")
def root_index():
    """Return a small self-description of the service.

    The payload names the API, reports status ``"ok"``, points at the UI
    mount path and lists the documented endpoints with a one-line summary.

    NOTE(review): no route registration for this handler is visible in this
    view of the file - confirm where it is attached to ``api``.
    """
    endpoints = {
        "GET /health": "public health",
        "GET /api/health": "public health (alias)",
        "GET /api/ping": "auth check",
        "GET /api/attributes": "list attributes",
        "POST /api/edit": "generic edit",
        "POST /api/edit/{attribute}": "edit by attribute name",
    }
    auth_hint = "set API_AUTH_TOKEN to require Authorization: Bearer <token> (except /health)"
    return {
        "name": "Smile Changer API",
        "status": "ok",
        "ui": "/app",
        "endpoints": endpoints,
        "auth": auth_hint,
    }
# Registered at the root-level path that root_index() advertises as
# "GET /health"; without the decorator the handler is never reachable.
@api.get("/health")
def health_root():
    """Public liveness probe (no authentication); always reports ``ok``."""
    return {"status": "ok"}
# Registered at the path root_index() advertises as "GET /api/attributes".
@api.get("/api/attributes")
def list_attributes(_: None = Depends(_require_auth)):
    """List every editable attribute with its internal name and power range.

    Requires authentication via ``_require_auth``.

    Returns:
        A ``JSONResponse`` mapping each friendly attribute name to
        ``{"internal": <editing name>, "min": <low>, "max": <high>}``.
    """
    catalog = {
        label: {"internal": edit_name, "min": lo, "max": hi}
        for label, (edit_name, (lo, hi)) in ATTRIBUTE_MAP.items()
    }
    return JSONResponse(catalog)
# Registered at the path root_index() advertises as "GET /api/health" and
# that _self_check() requests; without the decorator that request gets a 404.
@api.get("/api/health")
def health():
    """Public liveness probe under the ``/api`` prefix (no authentication)."""
    return {"status": "ok"}
# Registered at the path root_index() advertises as "GET /api/ping".
@api.get("/api/ping")
def ping(_: None = Depends(_require_auth)):
    """Authenticated no-op: succeeds only when ``_require_auth`` accepts the caller."""
    return {"status": "ok", "auth": True}
def me(user=Depends(_require_auth)):
    """Describe how the caller authenticated.

    Returns a ``JSONResponse`` with ``mode`` (the auth scheme reported by
    ``_require_auth``). For Firebase callers it also includes ``uid`` and a
    reduced ``claims`` dict limited to the email/name/picture/user_id/uid
    entries that are present, rather than the full claim set.

    NOTE(review): no route registration for this handler is visible in this
    view of the file - confirm where it is attached to ``api``.
    """
    mode = user.get("auth")
    info = {"mode": mode}
    if mode == "firebase":
        info["uid"] = user.get("uid")
        claims = user.get("claims", {})
        exposed_keys = ("email", "name", "picture", "user_id", "uid")
        # Only a whitelisted subset of claims is returned, and only when set.
        info["claims"] = {
            key: value
            for key in exposed_keys
            if (value := claims.get(key)) is not None
        }
    return JSONResponse(info)
def _self_check():
    """Request ``/api/health`` through an in-process test client and log the result.

    Logs the status code, plus the decoded body when the response declares a
    JSON content type. Any exception is logged as an error, never raised.

    NOTE(review): no caller or startup hook for this function is visible in
    this view of the file - confirm how it is triggered.
    """
    try:
        response = TestClient(api).get("/api/health")
        content_type = response.headers.get("content-type", " ")
        payload = response.json() if content_type.startswith("application/json") else ""
        logger.info("Self-check /api/health -> %s %s", response.status_code, payload)
    except Exception as exc:
        logger.error("Self-check failed: %s", exc)
# Registered at the path root_index() advertises as "POST /api/edit".
@api.post("/api/edit")
async def api_edit(
    file: UploadFile = File(...),
    attribute: str = Form(...),
    strength: float = Form(5.0),
    align_face: bool = Form(True),
    use_bg_mask: bool = Form(False),
    custom_text_edit: str = Form(""),
    _: None = Depends(_require_auth),
):
    """Apply one attribute edit to an uploaded image and return it as PNG.

    Args:
        file: The uploaded image.
        attribute: Friendly attribute name; must be a key of ``ATTRIBUTE_MAP``.
        strength: Requested edit power (``run_edit`` clamps it to the
            attribute's recommended range).
        align_face: Forwarded to ``run_edit``.
        use_bg_mask: Forwarded to ``run_edit``.
        custom_text_edit: Optional prompt override, forwarded to ``run_edit``.
        _: Authentication dependency; unused in the body. Sibling handlers
            that call this coroutine directly perform their own auth.

    Returns:
        A ``StreamingResponse`` with media type ``image/png``.

    Raises:
        HTTPException: status 400 for an unknown attribute or an upload that
            cannot be decoded as an image.
    """
    # Reject bad input as a client error instead of letting it surface as a 500.
    if attribute not in ATTRIBUTE_MAP:
        raise HTTPException(status_code=400, detail=f"Unknown attribute: {attribute!r}")

    data = await file.read()
    try:
        image = Image.open(io.BytesIO(data)).convert("RGB")
    except (OSError, ValueError) as exc:
        # PIL's UnidentifiedImageError is an OSError subclass.
        raise HTTPException(status_code=400, detail="Uploaded file is not a readable image") from exc

    result = run_edit(
        image=image,
        attribute=attribute,
        strength=strength,
        align_face=align_face,
        use_bg_mask=use_bg_mask,
        custom_text_edit=custom_text_edit,
    )

    buf = io.BytesIO()
    result.save(buf, format="PNG")
    buf.seek(0)
    return StreamingResponse(buf, media_type="image/png")
# Registered at the path root_index() advertises as
# "POST /api/edit/{attribute}"; the placeholder must carry the parameter's
# actual name for FastAPI to bind it.
# NOTE(review): a single path segment cannot contain "/", so the
# "Mustache/Goatee" attribute is presumably not addressable through this
# route - confirm.
@api.post("/api/edit/{attribute_name}")
async def api_edit_by_attribute(
    attribute_name: str,
    file: UploadFile = File(...),
    strength: float = Form(5.0),
    align_face: bool = Form(True),
    use_bg_mask: bool = Form(False),
    custom_text_edit: str = Form(""),
    _: None = Depends(_require_auth),
):
    """Edit an uploaded image, taking the attribute name from the URL path.

    Authenticates via ``_require_auth`` and then delegates to ``api_edit``
    with ``attribute_name`` as the attribute; all other form fields are passed
    through unchanged.

    Returns:
        Whatever ``api_edit`` returns (a PNG ``StreamingResponse``).
    """
    return await api_edit(
        file=file,
        attribute=attribute_name,
        strength=strength,
        align_face=align_face,
        use_bg_mask=use_bg_mask,
        custom_text_edit=custom_text_edit,
    )
| # Convenience endpoints for each attribute | |
def _register_attribute_endpoint(path: str, attribute_value: str):
    """Register a POST route at *path* that edits one fixed attribute.

    Args:
        path: URL path to expose, e.g. ``"/api/smile"``.
        attribute_value: Friendly attribute name the route applies; passed to
            ``api_edit`` as ``attribute``.

    The generated handler accepts the same upload and form fields as
    ``api_edit`` (minus ``attribute``), requires authentication, and returns
    ``api_edit``'s response.
    """

    async def _endpoint(
        file: UploadFile = File(...),
        strength: float = Form(5.0),
        align_face: bool = Form(True),
        use_bg_mask: bool = Form(False),
        custom_text_edit: str = Form(""),
        _: None = Depends(_require_auth),
    ):
        return await api_edit(
            file=file,
            attribute=attribute_value,
            strength=strength,
            align_face=align_face,
            use_bg_mask=use_bg_mask,
            custom_text_edit=custom_text_edit,
        )

    # Attach the handler to the app; without this the closure is discarded and
    # `path` is never used, so no convenience route would exist.
    api.post(path)(_endpoint)
# One convenience route per attribute: (URL path, friendly attribute name).
# The attribute names are keys of ATTRIBUTE_MAP; order matches registration order.
for _route_path, _attribute_label in (
    ("/api/smile", "Smile"),
    ("/api/age", "Age"),
    ("/api/female-features", "Female features"),
    ("/api/beard", "Beard"),
    ("/api/mustache-goatee", "Mustache/Goatee"),
    ("/api/glasses", "Glasses"),
    ("/api/makeup", "Makeup"),
    ("/api/curly-hair", "Curly hair"),
    ("/api/afro", "Afro"),
    ("/api/orange-hair-text", "Orange hair (text)"),
    ("/api/blonde-hair-text", "Blonde hair (text)"),
):
    _register_attribute_endpoint(_route_path, _attribute_label)
async def api_image_edit(
    file: UploadFile = File(...),
    attribute: str = Form("Smile"),
    strength: float = Form(5.0),
    align_face: bool = Form(False),
    use_bg_mask: bool = Form(False),
    custom_text_edit: str = Form(""),
):
    """Edit an uploaded image and return it as PNG, defaulting to a smile edit.

    Differs from ``api_edit`` only in its defaults (``attribute="Smile"``,
    ``align_face=False``) and in declaring no authentication dependency; the
    actual work is delegated to ``api_edit``.

    NOTE(review): unlike the other edit handlers this one has no
    ``_require_auth`` dependency, and no route registration is visible in this
    view of the file - confirm both are intentional.
    """
    return await api_edit(
        file=file,
        attribute=attribute,
        strength=strength,
        align_face=align_face,
        use_bg_mask=use_bg_mask,
        custom_text_edit=custom_text_edit,
    )
| # Mount Gradio under /app and expose FastAPI at root for clean API base | |
| app = gr.mount_gradio_app(api, demo, path="/app") | |
# `GPU` is imported from `spaces` at the top of the file; applying it here is
# what makes this function visible to the platform's GPU startup check that
# the comment below refers to.
@GPU
def _warmup_gpu():
    """Return ``"ok"``; exists only so a GPU-decorated function is defined."""
    # CPU-only Space; this is a no-op to satisfy GPU startup checks
    return "ok"
if __name__ == "__main__":
    # Local run. On Spaces, the platform serves the FastAPI app automatically.
    try:
        import uvicorn

        uvicorn.run(app, host="0.0.0.0", port=7860)
    except Exception:
        # Log the full traceback and exit non-zero so a supervisor or shell
        # can tell that startup failed (printing and falling through exited 0).
        logger.exception("Failed to start uvicorn")
        raise SystemExit(1)