hf-models / config.py
DimasMP3
Rename app configuration file to avoid Space conflicts
b975c79
from __future__ import annotations

import json
import os
from dataclasses import dataclass, field
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, List, Optional
# Optional JSON configuration file, expected in the same directory as this module.
CONFIG_FILE = Path(__file__).parent / "app-config.json"

# Class labels used when neither MODEL_LABELS nor the config file supplies any.
# The order presumably matches the model's output indices; TODO confirm.
DEFAULT_LABELS: List[str] = "Heart Oblong Oval Round Square".split()

# CORS origins used when none are configured (a local front-end on port 3000).
DEFAULT_ALLOWED_ORIGINS: List[str] = [
    f"http://{host}:3000" for host in ("localhost", "127.0.0.1")
]

# Port used when neither PORT nor the config file provides a valid one.
DEFAULT_PORT = 7860
def _load_config_file(path: Optional[Path] = None) -> Dict[str, Any]:
    """Read the optional JSON configuration file.

    Args:
        path: File to read. Defaults to ``CONFIG_FILE``, looked up at call time.

    Returns:
        The parsed top-level JSON object, or ``{}`` when the file does not exist.

    Raises:
        ValueError: If the file is not valid JSON, or its top level is not a
            JSON object.
    """
    target = CONFIG_FILE if path is None else Path(path)
    try:
        # Use an explicit encoding because the default is locale-dependent.
        # "utf-8-sig" also accepts files saved with a UTF-8 BOM, which
        # json.loads would otherwise reject.
        text = target.read_text(encoding="utf-8-sig")
    except FileNotFoundError:
        # The file is optional (EAFP avoids an exists()/read race).
        return {}
    try:
        data = json.loads(text)
    except json.JSONDecodeError as exc:
        raise ValueError(f"Konfigurasi JSON tidak valid: {target}") from exc
    if not isinstance(data, dict):
        # Callers use .get() on the result, so a list or scalar would fail
        # later with an unhelpful AttributeError; reject it here instead.
        raise ValueError(f"Konfigurasi JSON harus berupa objek: {target}")
    return data
def _parse_labels(raw: Any) -> List[str]:
    """Normalise a label specification into a non-empty list of class names.

    Accepts a comma-separated string (e.g. from ``MODEL_LABELS``) or a
    list/tuple (e.g. from the JSON config). Items are stripped of surrounding
    whitespace, and blank or ``None`` items are dropped; order is preserved.

    Returns:
        The parsed labels. Returns a fresh copy of ``DEFAULT_LABELS`` when
        *raw* is ``None``, of an unsupported type, or yields no labels.
    """
    if isinstance(raw, str):
        items: List[Any] = raw.split(",")
    elif isinstance(raw, (list, tuple)):
        items = list(raw)
    else:
        return DEFAULT_LABELS.copy()
    # Skip None explicitly; otherwise str(None) would become the label "None".
    labels = [
        text
        for text in (str(item).strip() for item in items if item is not None)
        if text
    ]
    return labels or DEFAULT_LABELS.copy()
def _parse_origins(raw: Any) -> List[str]:
    """Normalise a CORS origin specification into a non-empty list.

    Accepts a comma-separated string (e.g. from ``CORS_ALLOWED_ORIGINS``) or a
    list/tuple (e.g. from the JSON config). Each item is stripped of
    surrounding whitespace and any trailing ``/``. Blank or ``None`` items are
    dropped, and order is preserved.

    Returns:
        The parsed origins. Returns a fresh copy of ``DEFAULT_ALLOWED_ORIGINS``
        when *raw* is ``None``, of an unsupported type, or yields no origins.
    """
    if isinstance(raw, str):
        items: List[Any] = raw.split(",")
    elif isinstance(raw, (list, tuple)):
        items = list(raw)
    else:
        return DEFAULT_ALLOWED_ORIGINS.copy()
    # A browser Origin header (scheme://host[:port]) never ends in "/".
    # The CORS layer presumably compares exact strings, so a configured
    # trailing slash would never match; normalise it away. Skip None so that
    # str(None) does not become the origin "None".
    origins = [
        text
        for text in (
            str(item).strip().rstrip("/") for item in items if item is not None
        )
        if text
    ]
    return origins or DEFAULT_ALLOWED_ORIGINS.copy()
def _parse_bool(raw: Any, fallback: bool = False) -> bool:
    """Interpret *raw* as a boolean flag.

    Booleans are returned unchanged. Other values are converted with ``str``,
    stripped and lower-cased:

    - ``1/true/yes/on`` gives ``True``.
    - ``0/false/no/off`` gives ``False``.

    ``None``, empty and unrecognised values return *fallback*. This matches
    how ``_parse_int`` treats unparseable input.
    """
    if raw is None:
        return fallback
    if isinstance(raw, bool):
        return raw
    token = str(raw).strip().lower()
    if token in {"1", "true", "yes", "on"}:
        return True
    if token in {"0", "false", "no", "off"}:
        return False
    return fallback
def _parse_int(raw: Any, fallback: int) -> int:
    """Parse *raw* as a strictly positive integer.

    Returns:
        ``int(raw)`` when that succeeds and is > 0. Otherwise returns
        *fallback*: for ``None``, booleans, values ``int()`` rejects
        (including infinities) and non-positive results.
    """
    # bool is a subclass of int, so int(True) == 1 would silently turn a
    # JSON `true` into the value 1.
    if raw is None or isinstance(raw, bool):
        return fallback
    try:
        value = int(raw)
    except (TypeError, ValueError, OverflowError):
        # OverflowError: int(float("inf")); json.loads accepts "Infinity".
        return fallback
    return value if value > 0 else fallback
def _resolve_model_path(file_value: Optional[str]) -> str:
    """Return the model file path as an absolute string.

    Precedence: the ``MODEL_PATH`` environment variable (if non-empty), then
    *file_value* (from the config file), then ``model/best_model.keras``.
    A leading ``~`` is expanded to the user's home directory. Any remaining
    relative path is resolved against the directory containing
    ``CONFIG_FILE``. Absolute paths are returned without further resolution.
    """
    candidate = os.environ.get("MODEL_PATH") or file_value or "model/best_model.keras"
    # str() guards against non-string JSON values (e.g. a number), which
    # Path() would reject with TypeError.
    path = Path(str(candidate)).expanduser()
    if not path.is_absolute():
        path = (CONFIG_FILE.parent / path).resolve()
    return str(path)
@dataclass(frozen=True)
class Settings:
    """Immutable, fully resolved application settings (built by get_settings()).

    Credential fields are excluded from ``repr()`` so that printing or
    logging a Settings instance does not leak secrets. They remain ordinary
    fields for construction, attribute access and equality.
    """

    model_path: str  # absolute path; see _resolve_model_path
    labels: List[str]  # class names; order presumably = model output index (TODO confirm)
    port: int  # from PORT / config "port", else DEFAULT_PORT
    cors_allowed_origins: List[str]  # from CORS_ALLOWED_ORIGINS / config, else defaults
    gradio_auth_token: str | None = field(repr=False)  # secret: hidden from repr
    gradio_username: str | None
    gradio_password: str | None = field(repr=False)  # secret: hidden from repr
    share: bool  # from GRADIO_SHARE / config "share"
# Contents of the optional JSON config file, read once at import time.
# A missing file yields {}. An invalid file raises ValueError, so importing
# this module fails instead of failing at first use.
_CONFIG_CACHE = _load_config_file()
def _env_or_config(env_name: str, config_key: str) -> Any:
    """Return env var *env_name* if set to a non-blank value, else ``_CONFIG_CACHE.get(config_key)``.

    A blank variable counts as unset, so a declared-but-empty variable does
    not mask the config-file value. Blank means empty or whitespace-only,
    which is common when deployment settings list a variable without a value.
    """
    env_value = os.environ.get(env_name)
    if env_value is not None and env_value.strip():
        return env_value
    return _CONFIG_CACHE.get(config_key)


@lru_cache(maxsize=1)
def get_settings() -> Settings:
    """Build the application settings, memoised for the life of the process.

    Each value is taken from, in order:

    1. its environment variable, if set to a non-blank value;
    2. the JSON config file;
    3. the built-in defaults applied by the parsing helpers.

    Because the zero-argument function is wrapped in ``lru_cache``, every
    call returns the same object. Environment changes made after the first
    call are not observed.
    """
    return Settings(
        # MODEL_PATH precedence is handled inside _resolve_model_path.
        model_path=_resolve_model_path(_CONFIG_CACHE.get("model_path")),
        labels=_parse_labels(_env_or_config("MODEL_LABELS", "labels")),
        port=_parse_int(_env_or_config("PORT", "port"), DEFAULT_PORT),
        cors_allowed_origins=_parse_origins(
            _env_or_config("CORS_ALLOWED_ORIGINS", "cors_allowed_origins")
        ),
        # `or None` normalises an empty config string to "not configured".
        gradio_auth_token=_env_or_config("GRADIO_AUTH_TOKEN", "gradio_auth_token") or None,
        gradio_username=_env_or_config("GRADIO_USERNAME", "gradio_username") or None,
        gradio_password=_env_or_config("GRADIO_PASSWORD", "gradio_password") or None,
        share=_parse_bool(_env_or_config("GRADIO_SHARE", "share"), False),
    )
# Module-level settings instance, built at import time. get_settings() is
# lru_cached, so this is the same object every later get_settings() call returns.
settings = get_settings()