# my-mnist-hf / scripts/export_to_hub.py
# Author: dsaint31 — commit fab639f (verified): "Release custom MNIST model"
# hf_custom_proj/scripts/export_to_hub.py
"""
๊ฐœ๋ฐœ ํ”„๋กœ์ ํŠธ(hf_custom_proj/)์—์„œ ํ•™์Šต ์‚ฐ์ถœ๋ฌผ์€ artifacts/์— ๋ณด๊ด€ํ•˜๊ณ ,
๋ฐฐํฌ(Hub/๋กœ์ปฌ from_pretrained)์šฉ repo ๋ฃจํŠธ๋Š” dist/my-mnist-hf/ ๋กœ ๊ตฌ์„ฑํ•ฉ๋‹ˆ๋‹ค.
๋ชฉํ‘œ (๋‹น์‹ ์ด ์š”๊ตฌํ•œ "HF ์Šคํƒ€์ผ"์„ ๋งŒ์กฑ)
- ๊ฐœ๋ฐœ: src/ ๋ ˆ์ด์•„์›ƒ ์œ ์ง€ (pip install -e . / python -m examples.*)
- ๋ฐฐํฌ: dist/ ๋Š” Hub / local from_pretrained ๋ชจ๋‘ ๋™์ž‘
- ์ž„์‹œ์ฒ˜๋ฐฉ X: dist ๊ตฌ์กฐ๋ฅผ HF dynamic module ๋กœ๋”ฉ ๊ทœ์น™์— "์ •ํ•ฉ"๋˜๊ฒŒ ์ƒ์„ฑ
์ค‘์š” ํฌ์ธํŠธ (ํ˜„์žฌ ๋‹น์‹  ์ƒํ™ฉ์˜ ํ•ต์‹ฌ)
- preprocessor_config.json์˜ auto_map์ด
"AutoImageProcessor": "image_processing_my_mnist.MyMNISTImageProcessor"
์ฒ˜๋Ÿผ "ํŒจํ‚ค์ง€๋ช… ์—†์ด" ์ €์žฅ๋˜๋Š” ๊ฒฝ์šฐ๊ฐ€ ์žˆ์Œ.
- ์ด ๊ฒฝ์šฐ transformers๋Š” dist ๋ฃจํŠธ์—์„œ image_processing_my_mnist.py ๋ฅผ ์ฐพ๋Š”๋‹ค.
- ๋”ฐ๋ผ์„œ dist ๋ฃจํŠธ์— configuration/modeling/image_processing *.py ๋ฅผ flatํ•˜๊ฒŒ ๋‘”๋‹ค.
"""
from __future__ import annotations
import argparse
import shutil
import sys
from pathlib import Path
# -----------------------------
# path utils
# -----------------------------
def project_root_from_this_file() -> Path:
    """Return the project root: the grandparent of this script's resolved path.

    This script lives in ``<root>/scripts/``, so going up two levels from the
    file itself yields ``<root>``.
    """
    this_script = Path(__file__).resolve()
    scripts_dir = this_script.parent
    return scripts_dir.parent
def rmtree_if_exists(p: Path) -> None:
    """Remove whatever exists at ``p``; do nothing when the path is absent.

    - Real directories are removed recursively with ``shutil.rmtree``.
    - Regular files and symlinks (including dangling ones) are unlinked.
      ``shutil.rmtree`` raises on a plain file and refuses symlinks, and
      ``Path.exists()`` is False for a dangling link, so those cases are
      handled explicitly. A symlink is removed without touching its target.
    """
    if p.is_dir() and not p.is_symlink():
        shutil.rmtree(p)
    elif p.is_symlink() or p.exists():
        p.unlink()
def copytree(
    src: Path,
    dst: Path,
    exclude: tuple[str, ...] = ("__pycache__", "*.pyc"),
) -> None:
    """Replace ``dst`` with a fresh recursive copy of the directory ``src``.

    Args:
        src: source directory; must exist.
        dst: destination. Any existing directory, file or symlink there is
            removed first, so the result mirrors ``src`` with no stale files.
        exclude: glob patterns (``shutil.ignore_patterns`` syntax) for entries
            to skip. By default Python bytecode caches are skipped, so
            ``__pycache__`` from running ``examples``/``scripts`` locally does
            not end up in the dist folder uploaded to the Hub. Pass ``()`` to
            copy everything.

    Raises:
        FileNotFoundError: if ``src`` does not exist.
    """
    if not src.exists():
        raise FileNotFoundError(f"Missing source path: {src}")
    # Clear the destination. A plain file or symlink cannot be rmtree'd.
    if dst.is_dir() and not dst.is_symlink():
        shutil.rmtree(dst)
    elif dst.is_symlink() or dst.exists():
        dst.unlink()
    ignore = shutil.ignore_patterns(*exclude) if exclude else None
    shutil.copytree(src, dst, ignore=ignore)
def copy2(src: Path, dst: Path) -> None:
    """Copy a single file (with metadata) to ``dst``, creating parent folders.

    Raises:
        FileNotFoundError: if ``src`` does not exist. Nothing is created then.
    """
    if src.exists():
        dst.parent.mkdir(parents=True, exist_ok=True)
        shutil.copy2(src, dst)
        return
    raise FileNotFoundError(f"Missing file: {src}")
def copy_if_exists(src: Path, dst: Path) -> None:
    """Best-effort copy of an optional file: do nothing when ``src`` is absent.

    When ``src`` exists it is copied with metadata and ``dst``'s parent
    folders are created as needed.
    """
    if not src.exists():
        return
    dst.parent.mkdir(parents=True, exist_ok=True)
    shutil.copy2(src, dst)
def write_text(dst: Path, text: str) -> None:
    """Write ``text`` to ``dst`` as UTF-8, overwriting it and creating parents."""
    target_dir = dst.parent
    target_dir.mkdir(parents=True, exist_ok=True)
    with dst.open("w", encoding="utf-8") as fh:
        fh.write(text)
# -----------------------------
# artifacts validate
# -----------------------------
def find_model_file(artifact_dir: Path) -> Path:
    """Locate the saved model-weights file inside ``artifact_dir``.

    ``model.safetensors`` takes precedence over ``pytorch_model.bin`` when
    both are present.

    Raises:
        FileNotFoundError: if neither weights file exists.
    """
    safetensors_path, bin_path = (
        artifact_dir / name for name in ("model.safetensors", "pytorch_model.bin")
    )
    found = next(
        (candidate for candidate in (safetensors_path, bin_path) if candidate.exists()),
        None,
    )
    if found is not None:
        return found
    raise FileNotFoundError(
        f"Missing model file in artifacts: expected '{safetensors_path.name}' or '{bin_path.name}'"
    )
def validate_artifacts(artifact_dir: Path) -> tuple[Path, Path, Path]:
    """Check that the training artifacts needed for the dist repo are present.

    Returns:
        ``(config.json path, model-weights path, preprocessor_config.json path)``.
        The weights file is resolved by ``find_model_file``
        (``model.safetensors`` or ``pytorch_model.bin``).

    Raises:
        FileNotFoundError: if ``artifact_dir`` does not exist, or if any required
            file is missing. All missing files, including the model weights, are
            listed in a single message together with a hint on how to save them.
    """
    # Hint appended to both error messages: how to save the artifacts after training.
    save_hint = (
        "๋จผ์ € ํ•™์Šต ํ›„ ์•„๋ž˜์ฒ˜๋Ÿผ ์ €์žฅํ•˜์„ธ์š”:\n"
        " save_dir = 'artifacts/my_mnist'\n"
        " trainer.save_model(save_dir)\n"
        " processor.save_pretrained(save_dir)\n"
    )
    if not artifact_dir.exists():
        raise FileNotFoundError(
            f"Artifacts directory does not exist: {artifact_dir}\n" + save_hint
        )
    config = artifact_dir / "config.json"
    preproc = artifact_dir / "preprocessor_config.json"
    missing: list[str] = []
    if not config.exists():
        missing.append("config.json")
    # Record a missing weights file in the same report. Letting find_model_file
    # raise here would hide the other missing files and the save hint.
    try:
        model = find_model_file(artifact_dir)
    except FileNotFoundError:
        model = None
        missing.append("model.safetensors (or pytorch_model.bin)")
    if not preproc.exists():
        missing.append("preprocessor_config.json")
    if missing:
        raise FileNotFoundError(
            "Artifacts ํด๋”์— ํ•™์Šต ์‚ฐ์ถœ๋ฌผ์ด ์—†์Šต๋‹ˆ๋‹ค(๋˜๋Š” ๋ถˆ์™„์ „ํ•ฉ๋‹ˆ๋‹ค).\n"
            f" artifacts: {artifact_dir}\n"
            " missing:\n"
            + "\n".join(f" - {m}" for m in missing)
            + "\n\n"
            + save_hint
        )
    # `missing` is empty here, so find_model_file succeeded and `model` is a Path.
    return config, model, preproc
# -----------------------------
# dist build
# -----------------------------
def prepare_dist_repo_root(
*,
root: Path,
dist_repo_dir: Path,
artifact_dir: Path,
package_name: str,
copy_scripts_into_dist: bool,
requirements_text: str | None,
) -> None:
"""
dist repo ๋ฃจํŠธ๋ฅผ ์ƒ์„ฑํ•ฉ๋‹ˆ๋‹ค.
- src/<package_name>/ ์•ˆ์˜ "๋ฐฐํฌ์— ํ•„์š”ํ•œ ๋ชจ๋“ˆ ํŒŒ์ผ"์„ dist ๋ฃจํŠธ๋กœ flat ๋ณต์‚ฌ
(ํ˜„์žฌ auto_map์ด 'image_processing_my_mnist.MyMNISTImageProcessor' ํ˜•ํƒœ๋กœ ์ €์žฅ๋˜๋Š” ์ƒํ™ฉ์„
๊ฐ€์žฅ ํ™•์‹คํžˆ ๋งŒ์กฑ์‹œํ‚ค๋Š” ๋ฐฉ์‹)
- examples/๋Š” dist์— ๊ฐ™์ด ๋„ฃ์–ด ํ…Œ์ŠคํŠธ ํŽธ์˜ ์ œ๊ณต(์›์น˜ ์•Š์œผ๋ฉด ์ง€์›Œ๋„ ๋จ)
- artifacts์˜ config/model/preprocessor_config๋ฅผ dist ๋ฃจํŠธ๋กœ ๋ณต์‚ฌ
"""
rmtree_if_exists(dist_repo_dir)
dist_repo_dir.mkdir(parents=True, exist_ok=True)
# 1) ์ฝ”๋“œ: src/<package_name> ์—์„œ ํ•„์š”ํ•œ py๋ฅผ dist ๋ฃจํŠธ๋กœ "flat" ๋ณต์‚ฌ
src_pkg = root / "src" / package_name
if not src_pkg.exists():
raise FileNotFoundError(f"Missing src package dir: {src_pkg}")
required_py = [
"configuration_my_mnist.py",
"modeling_my_mnist.py",
"image_processing_my_mnist.py",
]
for fname in required_py:
copy2(src_pkg / fname, dist_repo_dir / fname)
# __init__.py๋Š” ํ•„์ˆ˜๋Š” ์•„๋‹ˆ์ง€๋งŒ ๊ฐ™์ด ๋‘๋ฉด ๋””๋ฒ„๊น…์— ์œ ๋ฆฌ
if (src_pkg / "__init__.py").exists():
copy2(src_pkg / "__init__.py", dist_repo_dir / "__init__.py")
# 2) ์˜ˆ์ œ/์Šคํฌ๋ฆฝํŠธ ๋ณต์‚ฌ (ํŽธ์˜์šฉ)
copytree(root / "examples", dist_repo_dir / "examples")
if copy_scripts_into_dist:
# dist ์•ˆ์˜ scripts๋Š” "๋ฐฐํฌ๋ฌผ"์ด๋ผ๊ธฐ๋ณด๋‹ค "์ฐธ๊ณ ์šฉ"์ž…๋‹ˆ๋‹ค.
copytree(root / "scripts", dist_repo_dir / "scripts")
# 3) ๋ฉ”ํƒ€ ํŒŒ์ผ
copy_if_exists(root / "README.md", dist_repo_dir / "README.md")
copy_if_exists(root / "LICENSE", dist_repo_dir / "LICENSE")
copy_if_exists(root / ".gitignore", dist_repo_dir / ".gitignore")
copy_if_exists(root / "pyproject.toml", dist_repo_dir / "pyproject.toml")
# 4) requirements.txt
req_dst = dist_repo_dir / "requirements.txt"
if requirements_text is not None:
write_text(req_dst, requirements_text)
else:
if (root / "requirements.txt").exists():
copy2(root / "requirements.txt", req_dst)
else:
write_text(
req_dst,
"\n".join(["torch", "torchvision", "transformers", "evaluate", "numpy", "Pillow", ""]),
)
# 5) ํ•™์Šต ์‚ฐ์ถœ๋ฌผ ๋ณต์‚ฌ
config, model, preproc = validate_artifacts(artifact_dir)
copy2(config, dist_repo_dir / "config.json")
copy2(preproc, dist_repo_dir / "preprocessor_config.json")
copy2(model, dist_repo_dir / model.name)
# ์•ˆ๋‚ด ๋ฉ”๋ชจ(์„ ํƒ)
write_text(
dist_repo_dir / "DIST_NOTE.txt",
"\n".join(
[
"This folder is a Hub/local from_pretrained() compatible repo root.",
"Dynamic module loading expects flat *.py modules at repo root (per auto_map).",
"",
]
),
)
# -----------------------------
# Hub upload (optional)
# -----------------------------
def upload_to_hub(
    *,
    dist_repo_dir: Path,
    repo_id: str,
    private: bool,
    repo_type: str,
    commit_message: str,
) -> None:
    """Create the Hub repo if needed, then upload ``dist_repo_dir`` as one commit.

    Args:
        dist_repo_dir: local folder whose contents become the repo root.
        repo_id: Hub repo id, e.g. ``"user_or_org/my-mnist-hf"``.
        private: passed to ``create_repo`` as the repo visibility.
        repo_type: Hub repo type passed to ``create_repo``/``upload_folder``.
        commit_message: message for the upload commit.

    Raises:
        RuntimeError: if ``huggingface_hub`` is not installed, no auth token is
            found, or repo creation / upload fails with a Hub HTTP error.
    """
    # Only a genuinely missing package should produce the "please install" message.
    # The original broad `except Exception` also caught failing optional imports
    # and misreported them as a missing install.
    try:
        from huggingface_hub import HfApi
    except ImportError as e:
        raise RuntimeError(
            "huggingface_hub๊ฐ€ ํ•„์š”ํ•ฉ๋‹ˆ๋‹ค. ์„ค์น˜ ํ›„ ๋‹ค์‹œ ์‹คํ–‰ํ•˜์„ธ์š”:\n"
            " pip install -U huggingface_hub\n"
        ) from e
    # Newer releases expose the exception in `errors`; older ones only in `utils`.
    try:
        from huggingface_hub.errors import HfHubHTTPError
    except ImportError:
        from huggingface_hub.utils import HfHubHTTPError
    # Prefer the public get_token helper. The legacy HfFolder API is only a
    # fallback for old releases (it is deprecated/removed in newer ones).
    try:
        from huggingface_hub import get_token
    except ImportError:
        try:
            from huggingface_hub.utils import get_token
        except ImportError:
            from huggingface_hub.hf_api import HfFolder

            get_token = HfFolder.get_token

    token = get_token()
    if not token:
        raise RuntimeError(
            "Hugging Face ์ธ์ฆ ํ† ํฐ์ด ์—†์Šต๋‹ˆ๋‹ค.\n"
            "๋‹ค์Œ ์ค‘ ํ•˜๋‚˜๋ฅผ ์ˆ˜ํ–‰ํ•˜์„ธ์š”:\n"
            " 1) ํ„ฐ๋ฏธ๋„์—์„œ: huggingface-cli login\n"
            " 2) ํ™˜๊ฒฝ๋ณ€์ˆ˜๋กœ: export HF_TOKEN=... (Colab/CI ํฌํ•จ)\n"
            " 3) ํŒŒ์ด์ฌ์—์„œ: from huggingface_hub import login; login('HF_TOKEN')\n"
        )

    # Use exactly the token validated above for every API call.
    api = HfApi(token=token)
    try:
        api.create_repo(repo_id=repo_id, repo_type=repo_type, exist_ok=True, private=private)
    except HfHubHTTPError as e:
        raise RuntimeError(
            f"Repo ์ƒ์„ฑ์— ์‹คํŒจํ–ˆ์Šต๋‹ˆ๋‹ค: {repo_id} (repo_type={repo_type})\n"
            "๊ถŒํ•œ(organization repo ์—ฌ๋ถ€), repo_id ์˜คํƒ€, ํ† ํฐ ๊ถŒํ•œ์„ ํ™•์ธํ•˜์„ธ์š”."
        ) from e
    try:
        api.upload_folder(
            repo_id=repo_id,
            repo_type=repo_type,
            folder_path=str(dist_repo_dir),
            commit_message=commit_message,
        )
    except HfHubHTTPError as e:
        raise RuntimeError(
            f"์—…๋กœ๋“œ์— ์‹คํŒจํ–ˆ์Šต๋‹ˆ๋‹ค: {repo_id} (repo_type={repo_type})\n"
            "ํ† ํฐ ๊ถŒํ•œ, ๋Œ€์šฉ๋Ÿ‰ ํŒŒ์ผ, ๋„คํŠธ์›Œํฌ ์ƒํƒœ๋ฅผ ํ™•์ธํ•˜์„ธ์š”."
        ) from e
# -----------------------------
# CLI
# -----------------------------
def build_argparser() -> argparse.ArgumentParser:
    """Build the CLI parser for building the dist repo and optionally pushing it.

    ``--repo-type`` is restricted to the Hub's repo types, so a typo fails at
    parse time. Without this it would only fail after the dist folder had been
    rebuilt and a network call made.
    """
    p = argparse.ArgumentParser(
        description="Build dist repo root (Hub-ready) from src/<pkg> + artifacts, and optionally push to Hub."
    )
    p.add_argument("--push", action="store_true", help="์ง€์ • ์‹œ Hub์— ์—…๋กœ๋“œ๊นŒ์ง€ ์ˆ˜ํ–‰ํ•ฉ๋‹ˆ๋‹ค.")
    p.add_argument("--repo-id", default=None, help='Hub repo id. ์˜ˆ: "YOUR_ID/my-mnist-hf" ( --push์ผ ๋•Œ ํ•„์ˆ˜ )')
    p.add_argument("--private", action="store_true", help="Hub repo๋ฅผ private๋กœ ์ƒ์„ฑ( --push์ผ ๋•Œ๋งŒ ์˜๋ฏธ )")
    p.add_argument(
        "--repo-type",
        default="model",
        choices=("model", "dataset", "space"),
        help='Hub repo type. ๊ธฐ๋ณธ: "model"',
    )
    p.add_argument("--commit-message", default="Release custom MNIST model", help="Hub ์ปค๋ฐ‹ ๋ฉ”์‹œ์ง€")
    p.add_argument("--artifact-dir", default="artifacts/my_mnist", help="ํ•™์Šต ์‚ฐ์ถœ๋ฌผ ํด๋”. ๊ธฐ๋ณธ: artifacts/my_mnist")
    p.add_argument("--dist-dir", default="dist/my-mnist-hf", help="๋ฐฐํฌ์šฉ dist ํด๋”. ๊ธฐ๋ณธ: dist/my-mnist-hf")
    p.add_argument(
        "--package-name",
        default="my_mnist_hf",
        help="src/ ์•„๋ž˜์˜ ํŒจํ‚ค์ง€ ํด๋”๋ช…. ๊ธฐ๋ณธ: my_mnist_hf",
    )
    p.add_argument("--no-copy-scripts", action="store_true", help="dist์— scripts/ ๋ณต์‚ฌ๋ฅผ ์ƒ๋žต")
    p.add_argument(
        "--requirements",
        default=None,
        help="requirements.txt ๋‚ด์šฉ์„ ๋ฌธ์ž์—ด๋กœ ์ง€์ •(ํŒŒ์ผ ์ƒ์„ฑ). ๋ฏธ์ง€์ •์ด๋ฉด ๋ฃจํŠธ requirements.txt ๋ณต์‚ฌ ๋˜๋Š” ๊ธฐ๋ณธ๊ฐ’ ์ƒ์„ฑ.",
    )
    return p
def main() -> None:
    """CLI entry point: build the dist repo root, then optionally upload it to the Hub.

    Relative ``--artifact-dir`` / ``--dist-dir`` values are resolved against the
    project root (see ``project_root_from_this_file``), not the current working
    directory.

    Raises:
        ValueError: if ``--push`` is given without ``--repo-id``. This is checked
            before any files are touched.
    """
    args = build_argparser().parse_args()
    # Reject an incomplete --push request up front. The build below wipes and
    # recreates the dist folder, so failing afterwards would only waste that work.
    if args.push and not args.repo_id:
        raise ValueError("--push๋ฅผ ์‚ฌ์šฉํ•˜๋ ค๋ฉด --repo-id๋ฅผ ๋ฐ˜๋“œ์‹œ ์ง€์ •ํ•ด์•ผ ํ•ฉ๋‹ˆ๋‹ค.")

    root = project_root_from_this_file()
    artifact_dir = root / args.artifact_dir
    dist_repo_dir = root / args.dist_dir

    prepare_dist_repo_root(
        root=root,
        dist_repo_dir=dist_repo_dir,
        artifact_dir=artifact_dir,
        package_name=args.package_name,
        copy_scripts_into_dist=not args.no_copy_scripts,
        requirements_text=args.requirements,
    )
    print("\n[OK] dist ํด๋” ์ƒ์„ฑ ์™„๋ฃŒ")
    print(f" - artifacts : {artifact_dir}")
    print(f" - dist repo : {dist_repo_dir}")
    print(f" - package : {args.package_name}")

    if not args.push:
        print("\n[Info] --push๋ฅผ ์ง€์ •ํ•˜์ง€ ์•Š์•„ Hub ์—…๋กœ๋“œ๋Š” ์ƒ๋žตํ–ˆ์Šต๋‹ˆ๋‹ค(local-only).")
        print("\n๋กœ์ปฌ ๋กœ๋“œ ์˜ˆ์‹œ:")
        print(" from transformers import AutoModelForImageClassification, AutoImageProcessor")
        # Show the resolved path: the raw --dist-dir is relative to the project root,
        # not the user's cwd. repr() quotes it and escapes Windows backslashes.
        print(f" p = {str(dist_repo_dir)!r}")
        print(" processor = AutoImageProcessor.from_pretrained(p, trust_remote_code=True)")
        print(" model = AutoModelForImageClassification.from_pretrained(p, trust_remote_code=True)")
        return

    upload_to_hub(
        dist_repo_dir=dist_repo_dir,
        repo_id=args.repo_id,
        private=args.private,
        repo_type=args.repo_type,
        commit_message=args.commit_message,
    )
    print("\n[OK] Hub ์—…๋กœ๋“œ ์™„๋ฃŒ")
    print(f" - hub repo : {args.repo_id}")
    print("\nHub ๋กœ๋“œ ์˜ˆ์‹œ:")
    print(" from transformers import AutoModelForImageClassification, AutoImageProcessor")
    print(f" repo_id = '{args.repo_id}'")
    print(" processor = AutoImageProcessor.from_pretrained(repo_id, trust_remote_code=True)")
    print(" model = AutoModelForImageClassification.from_pretrained(repo_id, trust_remote_code=True)")
if __name__ == "__main__":
try:
main()
except Exception as e:
print(f"\n์˜ค๋ฅ˜: {e}\n", file=sys.stderr)
raise