File size: 1,951 Bytes
63089c1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
from __future__ import annotations

import os
from typing import Iterable, Optional


# Environment variables that may hold a Hugging Face access token, listed in
# lookup priority order (the first non-blank one is used by get_hf_token).
HF_TOKEN_ENV_NAMES = (
    # RunPod env name used by this project
    "hf_key",
    # Hugging Face standard
    "HF_TOKEN",
    # Alternative spellings, also accepted.
    "HUGGINGFACE_HUB_TOKEN",
    "HUGGING_FACE_HUB_TOKEN",
    "HUGGINGFACE_TOKEN",
    "HF_API_TOKEN",
)

# Environment variables that may hold the target repo id (or HF URL), listed
# in lookup priority order (consumed by get_hf_repo_id).
HF_REPO_ENV_NAMES = (
    # RunPod env name used by this project
    "hf_repo",
    # Alternative spellings, also accepted.
    "HF_REPO",
    "HF_REPO_ID",
    "HUGGINGFACE_REPO",
    "HUGGINGFACE_REPO_ID",
)


def first_env(names: Iterable[str]) -> Optional[str]:
    """Scan ``names`` in order and return the first env var that has content.

    A variable counts only if it is set and not blank. The returned value has
    leading/trailing whitespace stripped. ``None`` is returned when no
    candidate qualifies, including when ``names`` is empty.
    """
    for env_name in names:
        # Unset variables become "" so they fall through the same check as
        # whitespace-only ones.
        candidate = (os.environ.get(env_name) or "").strip()
        if candidate:
            return candidate
    return None


def get_hf_token() -> Optional[str]:
    """Look up a Hugging Face access token in the environment.

    The names in ``HF_TOKEN_ENV_NAMES`` are consulted in order, with the
    RunPod ``hf_key`` name first, so RunPod users can set ``hf_key=hf_...``.
    The first non-blank value is returned with surrounding whitespace
    removed. ``None`` means no token is configured.

    Nothing here prints or logs the token. The result is meant to be handed
    to the ``token`` argument of ``transformers`` / ``huggingface_hub``.
    """
    token = first_env(HF_TOKEN_ENV_NAMES)
    return token


def normalize_repo_id(repo_id_or_url: str) -> str:
    """Accept `shiowo/DINO-Protomorph` or a full HF URL and return a repo_id.

    Recognised input forms:

    * A plain repo id, either ``user/repo`` or a legacy single-segment id such
      as ``gpt2``. It is returned with surrounding whitespace and slashes
      trimmed.
    * ``http(s)://`` URLs on ``huggingface.co``, ``www.huggingface.co`` or the
      ``hf.co`` short host. The same hosts are also accepted without a scheme,
      and the host is matched case-insensitively.
    * An optional leading ``models/`` path segment.
    * Trailing ``/tree/<rev>``, ``/blob/<rev>/<path>``,
      ``/resolve/<rev>/<path>`` or ``/raw/<rev>/<path>`` segments, plus any
      ``?query`` or ``#fragment``.

    Any other path, such as ``datasets/user/repo``, is returned with only the
    trimming above applied.
    """
    value = repo_id_or_url.strip()
    # Repo ids never contain "?" or "#", so everything from the first of them
    # on is a query string / fragment (e.g. "?download=true", "#model-card").
    value = value.split("#", 1)[0].split("?", 1)[0]

    prefixes = (
        "https://huggingface.co/",
        "http://huggingface.co/",
        "https://www.huggingface.co/",
        "http://www.huggingface.co/",
        "https://hf.co/",
        "http://hf.co/",
        "huggingface.co/",
        "www.huggingface.co/",
        "hf.co/",
    )
    # Scheme and host are case-insensitive but the repo path is not: match on
    # a lowered copy, slice the original.
    lowered = value.lower()
    for prefix in prefixes:
        if lowered.startswith(prefix):
            value = value[len(prefix):]
            break

    value = value.strip("/")
    if value.startswith("models/"):
        value = value[len("models/"):]

    # Cut at the first whole path segment that is a marker followed by a
    # revision. Matching whole segments mirrors the old "/tree/" substring
    # rule while letting the repo-name check below work per segment.
    markers = ("tree", "blob", "resolve", "raw")
    segments = value.split("/")
    cut = next(
        (i for i in range(1, len(segments) - 1) if segments[i] in markers),
        None,
    )
    # In "user/tree/blob/main/..." the repo itself is named like a marker:
    # segment 1 is the repo name and segment 2 is the real marker.
    if cut == 1 and len(segments) > 3 and segments[2] in markers:
        cut = 2
    if cut is not None:
        value = "/".join(segments[:cut])
    return value


def get_hf_repo_id(default: Optional[str] = None) -> Optional[str]:
    """Return the configured Hugging Face repo id, or ``None`` if there is none.

    The names in ``HF_REPO_ENV_NAMES`` are checked in order, falling back to
    ``default`` when none is set. The chosen value may be a repo id or an HF
    URL and is passed through ``normalize_repo_id``.

    Args:
        default: Repo id or URL to use when no environment variable is set.

    Returns:
        The normalized repo id, or ``None`` when nothing usable is configured.
        This includes values that normalize to an empty string, such as a bare
        ``https://huggingface.co/`` or a whitespace-only ``default``.
    """
    value = first_env(HF_REPO_ENV_NAMES) or default
    if not value:
        return None
    repo_id = normalize_repo_id(value)
    # An empty id is unusable for HF API calls; report it as "not configured"
    # so `is None` checks in callers behave as the Optional contract implies.
    return repo_id or None