# Gentags / app.py
# Author: JairoDanielMT — commit 26d9b55 ("add")
from __future__ import annotations
import argparse
from pathlib import Path
from tagger_core.service import TaggerService
from tagger_infra.onnx_engine import OnnxTaggerEngine
from tagger_infra.csv_repo import WDTagsRepository
from tagger_ui.gradio_app import GradioUI, CSS # <-- importa CSS
def build_service(model_dir: Path) -> TaggerService:
    """Assemble a TaggerService from the model artifacts stored in *model_dir*.

    The directory must contain ``model.onnx`` (handed to OnnxTaggerEngine) and
    ``selected_tags.csv`` (handed to WDTagsRepository).

    Raises:
        FileNotFoundError: If either file is missing. The ONNX model is checked
            first, so it is the one reported when both are absent.
    """
    onnx_file = model_dir / "model.onnx"
    tags_csv = model_dir / "selected_tags.csv"

    # Validate up front so a missing artifact is reported with its full path
    # before either loader is constructed.
    for required in (onnx_file, tags_csv):
        if not required.exists():
            raise FileNotFoundError(f"No existe: {required}")

    # Keyword arguments are evaluated left to right, so the engine is still
    # built before the tags repository.
    return TaggerService(
        engine=OnnxTaggerEngine(onnx_file),
        tags_repo=WDTagsRepository(tags_csv),
    )
def main(argv: list[str] | None = None) -> None:
    """Parse command-line options, build the tagger service and launch the Gradio UI.

    Args:
        argv: Arguments to parse instead of ``sys.argv[1:]``. ``None`` (the
            default) keeps the normal command-line behaviour; passing a list
            allows programmatic invocation and testing.

    Raises:
        FileNotFoundError: If the model directory lacks ``model.onnx`` or
            ``selected_tags.csv`` (raised by ``build_service``).
        SystemExit: If argument parsing fails or ``--help`` is requested.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--model-dir", default="wdtagger_model", help="Carpeta con model.onnx + selected_tags.csv")
    ap.add_argument("--host", default="0.0.0.0")
    ap.add_argument("--port", type=int, default=7860)
    ap.add_argument("--share", action="store_true")
    # Auto-opening a browser is pointless on a headless/remote host; the flag
    # is opt-out so the default still opens one, as before.
    ap.add_argument(
        "--no-browser",
        action="store_true",
        help="No abrir el navegador automáticamente",
    )
    args = ap.parse_args(argv)

    # A relative --model-dir (including the default) is resolved against the
    # current working directory, not this file's location.
    service = build_service(Path(args.model_dir))
    ui = GradioUI(service).build()
    # Gradio 6+ takes app-level `css` in launch().
    # NOTE(review): presumably older Gradio releases do not accept `css` here —
    # confirm the pinned gradio version.
    ui.launch(
        server_name=args.host,
        server_port=args.port,
        inbrowser=not args.no_browser,
        share=args.share,
        css=CSS,
    )


if __name__ == "__main__":
    main()