Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import argparse | |
| from pathlib import Path | |
| from tagger_core.service import TaggerService | |
| from tagger_infra.onnx_engine import OnnxTaggerEngine | |
| from tagger_infra.csv_repo import WDTagsRepository | |
| from tagger_ui.gradio_app import GradioUI, CSS # <-- importa CSS | |
def build_service(model_dir: Path | str) -> TaggerService:
    """Build a TaggerService from the model files stored in *model_dir*.

    The directory must contain ``model.onnx`` (loaded by OnnxTaggerEngine)
    and ``selected_tags.csv`` (loaded by WDTagsRepository).

    Args:
        model_dir: Directory holding both files. A ``Path`` or a plain string
            is accepted; it is normalised with ``Path()``.

    Returns:
        A TaggerService wired to the ONNX engine and the tags repository.

    Raises:
        FileNotFoundError: If ``model.onnx`` or ``selected_tags.csv`` is
            missing or is not a regular file. ``model.onnx`` is checked first.
    """
    model_dir = Path(model_dir)
    model_path = model_dir / "model.onnx"
    csv_path = model_dir / "selected_tags.csv"

    # Validate both inputs up front so the user gets a clear message instead
    # of an opaque failure from inside the engine/repository constructors.
    # is_file() (unlike exists()) also rejects a directory with that name.
    for required in (model_path, csv_path):
        if not required.is_file():
            raise FileNotFoundError(f"No existe: {required}")

    engine = OnnxTaggerEngine(model_path)
    repo = WDTagsRepository(csv_path)
    return TaggerService(engine=engine, tags_repo=repo)
def main(argv: list[str] | None = None) -> None:
    """Parse CLI options, build the tagger service and launch the Gradio UI.

    Args:
        argv: Argument list to parse instead of ``sys.argv[1:]``. ``None``
            (the default) keeps the normal command-line behaviour; passing a
            list allows calling ``main`` programmatically.

    Raises:
        FileNotFoundError: Propagated from ``build_service`` when the model
            directory lacks ``model.onnx`` or ``selected_tags.csv``.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--model-dir", default="wdtagger_model", help="Carpeta con model.onnx + selected_tags.csv")
    # 0.0.0.0 binds every network interface, not just localhost.
    ap.add_argument("--host", default="0.0.0.0")
    ap.add_argument("--port", type=int, default=7860)
    ap.add_argument("--share", action="store_true")
    # Auto-opening a browser is useless on headless hosts; allow opting out
    # while keeping the original default (open the browser).
    ap.add_argument(
        "--no-browser",
        dest="inbrowser",
        action="store_false",
        help="No abrir el navegador automáticamente al iniciar",
    )
    args = ap.parse_args(argv)

    service = build_service(Path(args.model_dir))
    ui = GradioUI(service).build()

    # Gradio 6+ takes `css` in launch() rather than at Blocks construction.
    # NOTE(review): older Gradio releases may reject this keyword; make sure
    # the dependency is pinned to a version that accepts it.
    ui.launch(
        server_name=args.host,
        server_port=args.port,
        inbrowser=args.inbrowser,
        share=args.share,
        css=CSS,
    )
if __name__ == "__main__":
    # Script entry point. main() returns None, so SystemExit(None) ends the
    # process with status 0 unless main() raises first.
    raise SystemExit(main())