File size: 15,310 Bytes
a9bd396
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
import argparse
import os
import re
import subprocess
from datetime import date, datetime
from urllib.error import HTTPError
from urllib.request import Request, urlopen

from huggingface_hub import paper_info


# Repository root; assumes the script runs from inside the repo (everything
# after a "utils" component of the CWD is dropped) — TODO confirm for other CWDs.
ROOT = os.getcwd().split("utils")[0]
# Directory holding the English model-card markdown files.
DOCS_PATH = os.path.join(ROOT, "docs/source/en/model_doc")
# Directory holding the per-model source packages (one subdir per model).
MODELS_PATH = os.path.join(ROOT, "src/transformers/models")
GITHUB_REPO_URL = "https://github.com/huggingface/transformers"
# Base URL used to probe whether a file already exists on the main branch.
GITHUB_RAW_URL = "https://raw.githubusercontent.com/huggingface/transformers/main"

# License header prepended to a model card when it has no `-->` marker yet.
# NOTE: this is a runtime string written into files; do not edit its wording.
COPYRIGHT_DISCLAIMER = """<!--Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->"""

# Papers that exist on arXiv but are not indexed by the Hugging Face papers
# API, keyed by model-card filename; their dates must be resolved manually.
ARXIV_PAPERS_NOT_IN_HF_PAPERS = {
    "gemma3n.md": "2506.06644",
    "xmod.md": "2205.06266",
}


def check_file_exists_on_github(file_path: str) -> bool:
    """Check whether a file exists on the main branch of the GitHub repository.

    Args:
        file_path: Path relative to the repository root (absolute local paths
            under ``ROOT`` are converted automatically).

    Returns:
        True if the file exists on GitHub main (or if the check could not be
        completed), False only on a confirmed 404.

    Note:
        Network/timeout problems and non-404 HTTP errors deliberately report
        True so a flaky connection never breaks the script.
    """
    # Turn an absolute local path into a repo-relative one when needed.
    relative_path = file_path[len(ROOT):].lstrip("/") if file_path.startswith(ROOT) else file_path

    # Raw-content URL for the file on the main branch.
    url = f"{GITHUB_RAW_URL}/{relative_path}"

    try:
        # HEAD is enough to test existence and avoids downloading the body.
        probe = Request(url, method="HEAD")
        probe.add_header("User-Agent", "transformers-add-dates-script")
        with urlopen(probe, timeout=10) as response:
            return response.status == 200
    except HTTPError as error:
        # 404 is the only definitive "file does not exist" answer.
        return error.code != 404
    except Exception:
        # Network/timeout issue: assume the file exists and let the caller
        # fall back to local git history.
        return True


def get_modified_cards() -> list[str]:
    """Get the list of model names from modified files in docs/source/en/model_doc/"""

    branch = subprocess.check_output(["git", "branch", "--show-current"], text=True).strip()
    if branch == "main":
        # On main, only uncommitted changes are considered.
        diff_output = subprocess.check_output(["git", "diff", "--name-only", "HEAD"], text=True)
    else:
        # On a feature branch, diff against the fork point with main.
        fork_point = subprocess.check_output("git merge-base main HEAD".split()).decode("utf-8")
        diff_output = subprocess.check_output(f"git diff --name-only {fork_point}".split()).decode("utf-8")

    names: list[str] = []
    for changed in diff_output.strip().split("\n"):
        if not changed:
            continue
        # Only markdown files under the model_doc directory are model cards.
        if not (changed.startswith("docs/source/en/model_doc/") and changed.endswith(".md")):
            continue
        # Skip files that were deleted (present in the diff but gone locally).
        if not os.path.exists(os.path.join(ROOT, changed)):
            continue
        stem = os.path.splitext(os.path.basename(changed))[0]
        if stem not in ("auto", "timm_wrapper"):
            names.append(stem)

    return names


def get_paper_link(model_card: str | None, path: str | None) -> str:
    """Get the first paper link from the model card content."""

    if model_card is not None and not model_card.endswith(".md"):
        model_card = f"{model_card}.md"
    file_path = path or os.path.join(DOCS_PATH, f"{model_card}")
    model_card = os.path.basename(file_path)
    with open(file_path, "r", encoding="utf-8") as f:
        content = f.read()

    # Find known paper links
    paper_ids = re.findall(r"https://huggingface\.co/papers/\d+\.\d+", content)
    paper_ids += re.findall(r"https://arxiv\.org/abs/\d+\.\d+", content)
    paper_ids += re.findall(r"https://arxiv\.org/pdf/\d+\.\d+", content)

    if len(paper_ids) == 0:
        return "No_paper"

    return paper_ids[0]


def get_first_commit_date(model_name: str) -> str:
    """Get the first commit date of the model's init file or model.md.

    This date is treated as the date the model was added to HF transformers.

    Args:
        model_name: Model name or model card filename (a trailing ``.md`` is
            stripped). Note: the previous ``str | None`` annotation was wrong —
            ``None`` would crash on ``.endswith``.

    Returns:
        ISO date string (``YYYY-MM-DD``): the first local git commit touching
        the file, or today's date when the file is not yet on GitHub main.
    """
    if model_name.endswith(".md"):
        model_name = model_name[:-3]

    # Source directories use underscores where doc names may use hyphens.
    model_name_src = model_name.replace("-", "_")
    file_path = os.path.join(MODELS_PATH, model_name_src, "__init__.py")

    # If the init file is not found (only true for legacy models), the doc's
    # first commit date is used instead.
    if not os.path.exists(file_path):
        file_path = os.path.join(DOCS_PATH, f"{model_name}.md")

    # Check if file exists on GitHub main branch
    if not check_file_exists_on_github(file_path):
        # Brand-new file (not on GitHub main yet): use today's date.
        final_date = date.today().isoformat()
    else:
        # --reverse lists oldest commits first; the first line holds the
        # original addition date.
        final_date = subprocess.check_output(
            ["git", "log", "--reverse", "--pretty=format:%ad", "--date=iso", file_path], text=True
        )
    return final_date.strip().split("\n")[0][:10]


def get_release_date(link: str) -> str:
    """Resolve the release date for a paper link.

    Args:
        link: A ``https://huggingface.co/papers/...`` URL, an arXiv URL, or
            any other string (e.g. the ``"No_paper"`` sentinel).

    Returns:
        ISO date string when the Hugging Face papers API knows the paper,
        otherwise the literal ``"{release_date}"`` placeholder for a human
        to fill in. (Previously the failure paths implicitly returned
        ``None``, which ended up written into docs as the string ``"None"``.)
    """
    if link.startswith("https://huggingface.co/papers/"):
        paper_id = link.replace("https://huggingface.co/papers/", "")
        try:
            info = paper_info(paper_id)
            return info.published_at.date().isoformat()
        except Exception:
            # API lookup failed; fall through to the placeholder below.
            pass

    # arXiv-only papers and unrecognized links get a manual-fill placeholder.
    return r"{release_date}"


def replace_paper_links(file_path: str) -> bool:
    """Replace arxiv links with huggingface links if valid, and replace hf.co with huggingface.co"""

    with open(file_path, "r", encoding="utf-8") as f:
        original = f.read()

    # Normalize the short hf.co domain to the canonical one.
    updated = original.replace("https://hf.co/", "https://huggingface.co/")

    # Collect every arxiv paper id (abs links first, then pdf links).
    paper_ids = re.findall(r"https://arxiv\.org/abs/(\d+\.\d+)", updated)
    paper_ids.extend(re.findall(r"https://arxiv\.org/pdf/(\d+\.\d+)", updated))

    for paper_id in paper_ids:
        try:
            # Raises when the paper is unknown to the Hub; keep the arxiv
            # link in that case.
            paper_info(paper_id)
        except Exception:
            continue
        old_link = f"https://arxiv.org/abs/{paper_id}"
        if old_link not in updated:
            old_link = f"https://arxiv.org/pdf/{paper_id}"
        updated = updated.replace(old_link, f"https://huggingface.co/papers/{paper_id}")

    # Write back only if something actually changed.
    if updated == original:
        return False
    with open(file_path, "w", encoding="utf-8") as f:
        f.write(updated)
    return True


def _normalize_model_card_name(model_card: str) -> str:
    """Ensure model card has .md extension"""
    return model_card if model_card.endswith(".md") else f"{model_card}.md"


def _should_skip_model_card(model_card: str) -> bool:
    """Check if model card should be skipped"""
    return model_card in ("auto.md", "timm_wrapper.md")


def _read_model_card_content(model_card: str) -> str:
    """Read and return the content of a model card"""
    card_path = os.path.join(DOCS_PATH, model_card)
    with open(card_path, "r", encoding="utf-8") as handle:
        return handle.read()


def _get_dates_pattern_match(content: str):
    """Search for the dates pattern in content and return match object"""
    pattern = r"\n\*This model was released on (.*) and added to Hugging Face Transformers on (\d{4}-\d{2}-\d{2})\.\*"
    return re.search(pattern, content)


def _dates_differ_significantly(date1: str, date2: str) -> bool:
    """Check if two dates differ by more than 1 day"""
    try:
        d1 = datetime.strptime(date1, "%Y-%m-%d")
        d2 = datetime.strptime(date2, "%Y-%m-%d")
        return abs((d1 - d2).days) > 1
    except Exception:
        return True  # If dates can't be parsed, consider them different


def check_missing_dates(model_card_list: list[str]) -> list[str]:
    """Check which model cards are missing release dates and return their names"""
    missing: list[str] = []

    for card in model_card_list:
        card = _normalize_model_card_name(card)
        if _should_skip_model_card(card):
            continue
        # A card is "missing" when the standard dates line is absent.
        if _get_dates_pattern_match(_read_model_card_content(card)) is None:
            missing.append(card)

    return missing


def check_incorrect_dates(model_card_list: list[str]) -> list[str]:
    """Check which model cards have incorrect HF commit dates and return their names"""
    incorrect: list[str] = []

    for card in model_card_list:
        card = _normalize_model_card_name(card)
        if _should_skip_model_card(card):
            continue

        match = _get_dates_pattern_match(_read_model_card_content(card))
        if match is None:
            # Cards without a dates line are handled by check_missing_dates.
            continue

        recorded_hf_date = match.group(2)
        actual_hf_date = get_first_commit_date(model_name=card)
        # Flag the card when the recorded date drifts more than a day from
        # what git history says.
        if _dates_differ_significantly(recorded_hf_date, actual_hf_date):
            incorrect.append(card)

    return incorrect


def insert_dates(model_card_list: list[str]):
    """Insert or update release and commit dates in model cards.

    For each card: normalizes paper links in place, guarantees a copyright
    header exists, then writes or refreshes the standard dates line directly
    after the header's closing ``-->`` marker.

    Args:
        model_card_list: Model card names (with or without ``.md``).
    """
    for model_card in model_card_list:
        model_card = _normalize_model_card_name(model_card)
        if _should_skip_model_card(model_card):
            continue

        file_path = os.path.join(DOCS_PATH, model_card)

        # First replace arxiv paper links with hf paper link if possible
        replace_paper_links(file_path)

        # Read content and ensure copyright disclaimer exists
        content = _read_model_card_content(model_card)
        markers = list(re.finditer(r"-->", content))

        if len(markers) == 0:
            # No copyright marker found, adding disclaimer to the top
            content = COPYRIGHT_DISCLAIMER + "\n\n" + content
            with open(file_path, "w", encoding="utf-8") as f:
                f.write(content)
            # Re-scan so markers[0] now points at the freshly added "-->".
            markers = list(re.finditer(r"-->", content))

        # Get dates
        hf_commit_date = get_first_commit_date(model_name=model_card)
        paper_link = get_paper_link(model_card=model_card, path=file_path)

        if paper_link in ("No_paper", "blog"):
            release_date = r"{release_date}"
        else:
            release_date = get_release_date(paper_link)

        match = _get_dates_pattern_match(content)

        # Update or insert the dates line
        if match:
            # Preserve existing release date unless it's a placeholder
            existing_release_date = match.group(1)
            existing_hf_date = match.group(2)

            if existing_release_date not in (r"{release_date}", "None"):
                release_date = existing_release_date

            # Only rewrite the file when a date actually changed (avoids
            # no-op writes that would dirty the git working tree).
            if _dates_differ_significantly(existing_hf_date, hf_commit_date) or existing_release_date != release_date:
                old_line = match.group(0)
                new_line = f"\n*This model was released on {release_date} and added to Hugging Face Transformers on {hf_commit_date}.*"
                content = content.replace(old_line, new_line)
                with open(file_path, "w", encoding="utf-8") as f:
                    f.write(content)
        else:
            # Insert new dates line after copyright marker
            insert_index = markers[0].end()
            date_info = f"\n*This model was released on {release_date} and added to Hugging Face Transformers on {hf_commit_date}.*"
            content = content[:insert_index] + date_info + content[insert_index:]
            with open(file_path, "w", encoding="utf-8") as f:
                f.write(content)


def get_all_model_cards():
    """Get all model cards from the docs path"""

    # Every .md file in the docs directory is a model card, except the two
    # special entries that are not real models.
    cards = [
        os.path.splitext(entry)[0]
        for entry in os.listdir(DOCS_PATH)
        if entry.endswith(".md") and os.path.splitext(entry)[0] not in ("auto", "timm_wrapper")
    ]
    return sorted(cards)


def main(all=False, models=None, check_only=False):
    """Entry point: validate or insert dates on model cards.

    Args:
        all: Process every model card in the docs directory.
        models: Explicit list of model cards to process.
        check_only: Only validate existing dates; raise ValueError on problems.
    """
    if check_only:
        # Missing dates are checked across every card...
        missing_dates = check_missing_dates(get_all_model_cards())
        # ...while incorrect dates are only checked on modified cards.
        incorrect_dates = check_incorrect_dates(get_modified_cards())

        problematic_cards = missing_dates + incorrect_dates
        if problematic_cards:
            model_names = [card.replace(".md", "") for card in problematic_cards]
            raise ValueError(
                f"Missing or incorrect dates in the following model cards: {' '.join(problematic_cards)}\n"
                f"Run `python utils/add_dates.py --models {' '.join(model_names)}` to fix them."
            )
        return

    # Determine which model cards to process.
    if all:
        targets = get_all_model_cards()
    elif models:
        targets = models
    else:
        targets = get_modified_cards()
        if not targets:
            return

    insert_dates(targets)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Add release and commit dates to model cards")
    # The three modes are mutually exclusive; with no flag the script falls
    # back to processing only the git-modified model cards.
    group = parser.add_mutually_exclusive_group(required=False)
    group.add_argument("--models", nargs="+", help="Specify model cards to process (without .md extension)")
    group.add_argument("--all", action="store_true", help="Process all model cards in the docs directory")
    group.add_argument("--check-only", action="store_true", help="Check if the dates are already present")

    args = parser.parse_args()
    try:
        main(args.all, args.models, args.check_only)
    except subprocess.CalledProcessError as e:
        # Git failures (e.g. shallow clone, missing main branch) are reported
        # but deliberately do not fail the run.
        print(
            f"An error occurred while executing git commands but it can be ignored (git issue) most probably local: {e}"
        )