# Provenance (copied from the Hugging Face file view of this script):
#   moe-routing-algorithm / moe-in-transformers.py
#   Author: ariG23498 (HF Staff)
#   Commit b9815de (verified): "Create moe-in-transformers.py"
"""
Analyze model introductions in the Transformers repo over the last ~2 years and
classify each introduced model as "moe" vs "dense" using a heuristic regex.
Outputs (in ./moe_dense_analysis):
- moe_dense_models_raw.csv : all models + inferred intro date + moe/dense label
- moe_dense_models_2y_window.csv : only models introduced in the last ~2 years
- moe_dense_2y_timeline.csv : monthly cumulative counts (moe/dense/total) over the window
- moe_dense_2y_timeline.png : plot of cumulative MoE counts over the window (the dense series is not plotted)
"""
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import calendar
import csv
import datetime as dt
import re
import subprocess
from pathlib import Path
import matplotlib
matplotlib.use("Agg") # headless backend for saving figures on CI/servers
import matplotlib.dates as mdates
import matplotlib.pyplot as plt
# -----------------------------------------------------------------------------
# Repo paths / output directory
#
# All paths are resolved from the current working directory, which must be the
# root of a transformers checkout (detected by the presence of the models tree).
# -----------------------------------------------------------------------------
repo = Path(".").resolve()
models_root = repo.joinpath("src/transformers/models")
if not models_root.exists():
    raise SystemExit("Run this from the transformers repo root.")

# Every CSV/PNG this script produces is written into this directory.
out_dir = repo.joinpath("moe_dense_analysis")
out_dir.mkdir(parents=True, exist_ok=True)
# -----------------------------------------------------------------------------
# Date window: the ~2 years ending today (both ends inclusive)
# -----------------------------------------------------------------------------
today = dt.date.today()
end_date = today
two_years_back = today.year - 2
try:
    start_date = today.replace(year=two_years_back)
except ValueError:
    # replace() only fails when today's day does not exist in the target year,
    # i.e. today is Feb 29 and two years earlier is not a leap year: use Feb 28.
    start_date = today.replace(year=two_years_back, day=28)
# -----------------------------------------------------------------------------
# Discover model directories
#
# A directory under models_root counts as a "model" when it ships a
# modeling_<dirname>.py file (e.g. src/transformers/models/llama/modeling_llama.py).
# Dunder directories such as __pycache__ are ignored. Names are kept in the
# sorted order of the directory entries.
# -----------------------------------------------------------------------------
model_names = [
    entry.name
    for entry in sorted(models_root.iterdir())
    if entry.is_dir()
    and not entry.name.startswith("__")
    and (entry / f"modeling_{entry.name}.py").exists()
]
# Set copy for O(1) membership checks while parsing git output.
model_name_set = set(model_names)
# -----------------------------------------------------------------------------
# Infer intro date per model using git:
#
# Walk `git log` restricted to *added* files under src/transformers/models and
# record, per model directory, the earliest date on which any file inside that
# directory was added.
#
# NOTE: This is a heuristic, not a perfect "model introduced" definition.
# -----------------------------------------------------------------------------
git_cmd = [
    "git",
    "log",
    "--diff-filter=A",  # keep only "added file" entries
    "--name-only",  # emit the affected paths, one per line
    "--format=DATE %ad",  # one "DATE <author date>" marker line per commit
    "--date=short",  # render %ad as YYYY-MM-DD
    "--",
    "src/transformers/models",
]
git_result = subprocess.run(git_cmd, cwd=repo, check=True, text=True, capture_output=True)
git_out = git_result.stdout

models_prefix = "src/transformers/models/"
intro_dates = {}  # model_name -> earliest YYYY-MM-DD string seen
current_date = None  # date of the commit whose path list is being read
for raw_line in git_out.splitlines():
    line = raw_line.strip()
    if line.startswith("DATE "):
        # Marker line, e.g. "DATE 2024-01-10": a new commit's path list follows.
        current_date = line.partition(" ")[2]
        continue
    # Skip blank separators, paths seen before any marker, and anything that is
    # not inside the models tree.
    if not line or current_date is None or not line.startswith(models_prefix):
        continue
    # Paths look like src/transformers/models/<model_name>/...
    segments = line.split("/")
    candidate = segments[3] if len(segments) >= 4 else None
    if candidate not in model_name_set:
        continue
    # Author dates need not follow git log's output order, so keep the minimum
    # explicitly. YYYY-MM-DD strings compare correctly as plain strings.
    seen = intro_dates.get(candidate)
    if seen is None or current_date < seen:
        intro_dates[candidate] = current_date
# -----------------------------------------------------------------------------
# MoE heuristic:
#
# A model is labelled "moe" when its modeling_<name>.py defines, at column 0, a
# class whose name contains MoE/MOE/Moe or Expert/Experts AND whose sole base
# is nn.Module or torch.nn.Module. Otherwise it is labelled "dense".
# -----------------------------------------------------------------------------
moe_class_re = re.compile(
    r"^class\s+([A-Za-z0-9_]*(?:MoE|MOE|Moe|Expert|Experts)[A-Za-z0-9_]*)"
    r"\s*\(\s*(?:nn|torch\.nn)\.Module\s*\)\s*:",
    re.MULTILINE,
)

records = []
for model_name in model_names:
    if model_name not in intro_dates:
        # No "added file" commit was attributed to this model, so it cannot be
        # placed on the timeline; leave it out.
        continue
    source_path = models_root / model_name / f"modeling_{model_name}.py"
    source_text = source_path.read_text(encoding="utf-8", errors="ignore")
    # De-duplicated, sorted class names that triggered the MoE heuristic.
    moe_classes = sorted({hit.group(1) for hit in moe_class_re.finditer(source_text)})
    records.append(
        {
            "model": model_name,
            "introduced_date": intro_dates[model_name],  # YYYY-MM-DD string
            "is_moe": "moe" if moe_classes else "dense",
            "moe_class_matches": ";".join(moe_classes),
            "modeling_file": str(source_path.relative_to(repo)),
        }
    )

# Stable output order: by introduction date, ties broken by model name.
records.sort(key=lambda rec: (rec["introduced_date"], rec["model"]))
# -----------------------------------------------------------------------------
# Restrict to 2-year window
#
# Each kept row is a copy of the record with an extra internal "intro_obj" key
# holding the parsed datetime.date, used for range checks and timeline counts.
# -----------------------------------------------------------------------------
window_records = []
for rec in records:
    intro_day = dt.datetime.strptime(rec["introduced_date"], "%Y-%m-%d").date()
    if not (start_date <= intro_day <= end_date):
        continue
    window_records.append({**rec, "intro_obj": intro_day})
window_records.sort(key=lambda rec: (rec["intro_obj"], rec["model"]))
# -----------------------------------------------------------------------------
# Build monthly timeline points: start_date, then each next month, ending at end_date
#
# Every point is anchored to start_date's day-of-month (e.g., the 19th of each
# month). When a month is too short for that day, the point is clamped to the
# month's last day, but the following point goes back to the anchor day rather
# than staying clamped (e.g., Jan 31 -> Feb 28 -> Mar 31, not Mar 28).
# -----------------------------------------------------------------------------
anchor_day = start_date.day
points = [start_date]
while points[-1] < end_date:
    last = points[-1]
    # Advance one calendar month, rolling December over into next January.
    year = last.year + (last.month // 12)
    month = 1 if last.month == 12 else last.month + 1
    # Clamp the *anchor* day (not last.day) to the target month's length, so a
    # clamp in a short month does not drift every later point.
    day = min(anchor_day, calendar.monthrange(year, month)[1])
    next_month = dt.date(year, month, day)
    if next_month > end_date:
        break
    points.append(next_month)
# Ensure the last point is exactly end_date
if points[-1] != end_date:
    points.append(end_date)
# -----------------------------------------------------------------------------
# Compute cumulative counts at each timeline point
# -----------------------------------------------------------------------------
def _timeline_row(point):
    """Return cumulative moe/dense/total counts of window models introduced on or before *point*."""
    labels = [rec["is_moe"] for rec in window_records if rec["intro_obj"] <= point]
    moe_count = labels.count("moe")
    dense_count = labels.count("dense")
    return {
        "date": point.isoformat(),
        "moe_cumulative": moe_count,
        "dense_cumulative": dense_count,
        "total_cumulative": moe_count + dense_count,
    }


timeline_rows = [_timeline_row(point) for point in points]
# -----------------------------------------------------------------------------
# Write CSV outputs
# -----------------------------------------------------------------------------
# Column order shared by the two per-model CSVs.
model_fieldnames = ["model", "introduced_date", "is_moe", "moe_class_matches", "modeling_file"]


def _write_csv(path, fieldnames, rows, extrasaction="raise"):
    """Write the dict *rows* to *path* as a UTF-8 CSV with a header row.

    ``extrasaction`` is passed to ``csv.DictWriter``: "raise" rejects rows with
    keys outside ``fieldnames``, while "ignore" silently leaves them out.
    """
    with path.open("w", newline="", encoding="utf-8") as handle:
        writer = csv.DictWriter(handle, fieldnames=fieldnames, extrasaction=extrasaction)
        writer.writeheader()
        writer.writerows(rows)


raw_csv = out_dir / "moe_dense_models_raw.csv"
_write_csv(raw_csv, model_fieldnames, records)

window_csv = out_dir / "moe_dense_models_2y_window.csv"
# window_records rows also carry the internal-only "intro_obj" date; let the
# writer drop it rather than copying and popping each row.
_write_csv(window_csv, model_fieldnames, window_records, extrasaction="ignore")

timeline_csv = out_dir / "moe_dense_2y_timeline.csv"
_write_csv(
    timeline_csv,
    ["date", "moe_cumulative", "dense_cumulative", "total_cumulative"],
    timeline_rows,
)
# -----------------------------------------------------------------------------
# Plot cumulative counts over time
#
# Only the MoE series is drawn; the dense and total counts remain available in
# moe_dense_2y_timeline.csv.
# -----------------------------------------------------------------------------
x = [dt.datetime.strptime(row["date"], "%Y-%m-%d").date() for row in timeline_rows]
y_moe = [row["moe_cumulative"] for row in timeline_rows]

plt.figure(figsize=(11, 6))
ax = plt.gca()
ax.plot(x, y_moe, label="MoE cumulative", linewidth=2.2)
ax.set_title(f"MoE model introductions ({start_date} to {end_date})")
ax.set_xlabel("Date")
ax.set_ylabel("Model count")
ax.grid(alpha=0.3)
ax.legend()
# One major tick every other month, labelled YYYY-MM and slanted for readability.
ax.xaxis.set_major_locator(mdates.MonthLocator(interval=2))
ax.xaxis.set_major_formatter(mdates.DateFormatter("%Y-%m"))
plt.xticks(rotation=45, ha="right")
plt.tight_layout()

plot_png = out_dir / "moe_dense_2y_timeline.png"
plt.savefig(plot_png, dpi=180)
# -----------------------------------------------------------------------------
# Print summary
# -----------------------------------------------------------------------------
window_labels = [row["is_moe"] for row in window_records]
dense_total = window_labels.count("dense")
moe_total = window_labels.count("moe")
print(f"Window: {start_date} -> {end_date}")
print(f"Introduced in window: dense={dense_total}, moe={moe_total}, total={dense_total + moe_total}")
for written_path in (raw_csv, window_csv, timeline_csv, plot_png):
    print(f"Wrote {written_path}")