File size: 7,313 Bytes
064e771 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 | from __future__ import annotations
import json
from pathlib import Path
# Repository root: two levels up from this script's resolved path.
ROOT = Path(__file__).resolve().parents[1]
# Directory that main() joins with each PATCHERS filename.
NOTEBOOK_DIR = ROOT.joinpath("notebooks")

# Identity written into notebook metadata (write_nb) and noop() comments.
STUDENT_NAME = "Sami Chellia"
USERNAME = "Sami94"
def source(text: str) -> list[str]:
    """Split *text* into notebook-style source lines.

    Line terminators are kept on every element, so joining the result with
    ``""`` reproduces *text* exactly.
    """
    pieces = text.splitlines(True)
    return pieces
def noop(reason: str) -> list[str]:
    """Return replacement source for a cell that is intentionally disabled.

    The text is a ``# Local patch for ...`` header naming the student, followed
    by a ``# Skipped: <reason>`` line. Only the first line of a multi-line
    *reason* would be commented out, so callers pass single-line reasons.
    """
    header = f"# Local patch for {STUDENT_NAME} ({USERNAME})\n"
    detail = f"# Skipped: {reason}\n"
    return source(header + detail)
def hf_whoami() -> list[str]:
    """Return cell source that prints the account name from ``HfApi().whoami()``.

    Patchers drop this into a fixed cell position (presumably the original
    Hugging Face login cell -- TODO confirm against each notebook).
    """
    lines = (
        "from huggingface_hub import HfApi\n",
        "print('HF user:', HfApi().whoami().get('name'))\n",
    )
    return source("".join(lines))
def read_nb(path: Path) -> dict:
    """Load the UTF-8 notebook JSON stored at *path* into a plain dict."""
    with path.open(encoding="utf-8") as handle:
        return json.load(handle)
def write_nb(path: Path, nb: dict) -> None:
    """Stamp the student identity into *nb* and write it as JSON to *path*.

    Mutates *nb* in place: ``metadata`` is created if absent, then
    ``student_name`` and ``hf_username`` are set from STUDENT_NAME / USERNAME.

    The file is written as UTF-8 with ``indent=1``, non-ASCII kept verbatim,
    LF line endings and a trailing newline, matching how Jupyter saves
    notebooks. This keeps re-saved files from showing whole-file or
    end-of-file diffs.
    """
    metadata = nb.setdefault("metadata", {})
    metadata["student_name"] = STUDENT_NAME
    metadata["hf_username"] = USERNAME
    payload = json.dumps(nb, ensure_ascii=False, indent=1) + "\n"
    # newline="\n" disables text-mode translation; without it, on Windows (the
    # target of this script) every "\n" is written as "\r\n".
    with path.open("w", encoding="utf-8", newline="\n") as handle:
        handle.write(payload)
def replace_in_all_cells(nb: dict, replacements: dict[str, str]) -> None:
    """Apply literal string *replacements* to the source of every cell in *nb*.

    Replacements run in dict order, each one over the result of the previous
    one. Cells without a ``source`` key are ignored. Cells whose text is
    unchanged are left exactly as they were, so their stored representation
    (a single string or a list of lines) and line chunking are preserved.
    """
    for cell in nb.get("cells", []):
        if "source" not in cell:
            continue
        original = "".join(cell["source"])
        text = original
        for old, new in replacements.items():
            text = text.replace(old, new)
        if text != original:
            # Same split as source(): one element per line, terminators kept.
            cell["source"] = text.splitlines(keepends=True)
def set_cell(nb: dict, one_based_index: int, text: list[str]) -> None:
    """Replace the source of cell number *one_based_index* (1-based) in *nb*.

    Raises:
        IndexError: if the index is outside ``1..len(cells)``. Without this
            check an index of 0 or a negative index would silently wrap
            around and overwrite a cell counted from the end of the notebook.
    """
    cells = nb["cells"]
    if not 1 <= one_based_index <= len(cells):
        raise IndexError(
            f"cell {one_based_index} out of range: notebook has {len(cells)} cells"
        )
    cells[one_based_index - 1]["source"] = text
def patch_unit1(nb: dict) -> None:
    """Patch the Unit 1 notebook for a local Windows run.

    Disables Linux/Colab-only setup cells and placeholder cells, puts an HF
    identity check in cell 54, and points the push repo id at USERNAME.
    Cell numbers are 1-based positions (as set_cell takes them) for one
    revision of the notebook. NOTE(review): re-verify them if the upstream
    notebook changes.
    """
    set_cell(nb, 15, noop("apt install swig/cmake is Linux-only; install Windows dependencies outside the notebook."))
    set_cell(nb, 18, noop("Linux OpenGL/Xvfb setup is not needed in this Windows local run."))
    set_cell(nb, 20, noop("Colab kernel restart would kill local execution."))
    set_cell(nb, 21, noop("pyvirtualdisplay/Xvfb is Linux-only."))
    set_cell(nb, 39, noop("placeholder cell; solution cell below is used."))
    set_cell(nb, 43, noop("placeholder cell; solution cell below is used."))
    set_cell(nb, 47, noop("placeholder cell; solution cell below is used."))
    set_cell(nb, 54, hf_whoami())
    set_cell(nb, 58, noop("duplicate placeholder push cell; completed push cell below is used."))
    # Both upstream repo ids are redirected to the same personal repo.
    own_repo = f'repo_id = "{USERNAME}/ppo-LunarLander-v2"'
    replace_in_all_cells(
        nb,
        {
            'repo_id = "ThomasSimonini/ppo-LunarLander-v2"': own_repo,
            'repo_id = "Classroom-workshop/assignment2-omar"': own_repo,
        },
    )
def patch_unit2(nb: dict) -> None:
    """Patch the Unit 2 (Q-learning) notebook for a local Windows run.

    Disables Linux/Colab-only setup cells and placeholder cells, puts an HF
    identity check in cell 72, and fills the username/repo placeholders from
    USERNAME. Cell numbers are 1-based positions for one revision of the
    notebook. NOTE(review): re-verify them if the upstream notebook changes.
    """
    set_cell(nb, 15, noop("Linux OpenGL/Xvfb setup is not needed in this Windows local run."))
    set_cell(nb, 17, noop("Colab kernel restart would kill local execution."))
    set_cell(nb, 18, noop("pyvirtualdisplay/Xvfb is Linux-only."))
    for index in [25, 35, 36, 43, 47, 53]:
        set_cell(nb, index, noop("placeholder cell; solution cell below is used."))
    set_cell(nb, 72, hf_whoami())
    # NOTE(review): str.replace hits every cell, so *any* `repo_name = "" # FILL THIS`
    # placeholder becomes "q-Taxi-v3" -- confirm only the Taxi push cell has it.
    replace_in_all_cells(
        nb,
        {
            'username = "" # FILL THIS': f'username = "{USERNAME}"',
            'repo_name = "" # FILL THIS': 'repo_name = "q-Taxi-v3"',
            'repo_id="ThomasSimonini/q-Taxi-v3"': f'repo_id="{USERNAME}/q-Taxi-v3"',
            'repo_id="ThomasSimonini/q-FrozenLake-v1-no-slippery"': f'repo_id="{USERNAME}/q-FrozenLake-v1-4x4-noSlippery"',
        },
    )
def patch_unit3(nb: dict) -> None:
    """Patch the Unit 3 notebook for a local Windows run.

    Disables Linux setup cells and placeholder commands, puts an HF identity
    check in cell 37, and points the ``-orga`` argument at USERNAME.
    Cell numbers are 1-based positions for one revision of the notebook.
    NOTE(review): re-verify them if the upstream notebook changes.
    """
    for index in [16, 20]:
        set_cell(nb, index, noop("Linux apt/OpenGL/Xvfb setup is not needed in this Windows local run."))
    for index in [26, 30, 41]:
        set_cell(nb, index, noop("placeholder command; completed command below is used."))
    set_cell(nb, 37, hf_whoami())
    replace_in_all_cells(nb, {'-orga ThomasSimonini': f'-orga {USERNAME}'})
def patch_unit4(nb: dict) -> None:
    """Patch the Unit 4 (REINFORCE) notebook for a local Windows run.

    Disables Linux-only display cells and placeholder cells, and puts an HF
    identity check in cell 63. Cells 66 and 83 are replaced with explicit
    ``push_to_hub`` calls whose repo ids are derived from USERNAME.
    ``push_to_hub``, the policies, hyperparameters and ``eval_env`` are
    expected to be defined by earlier notebook cells; this script does not
    check that. Cell numbers are 1-based positions for one revision of the
    notebook. NOTE(review): re-verify them if the upstream notebook changes.
    """
    set_cell(nb, 14, noop("Linux OpenGL/Xvfb setup is not needed in this Windows local run."))
    set_cell(nb, 15, noop("pyvirtualdisplay/Xvfb is Linux-only."))
    for index in [32, 45, 74]:
        set_cell(nb, index, noop("placeholder cell; solution cell below is used."))
    set_cell(nb, 63, hf_whoami())
    set_cell(
        nb,
        66,
        source(
            f'repo_id = "{USERNAME}/Reinforce-CartPole-v1"\n'
            "push_to_hub(repo_id,\n"
            " cartpole_policy,\n"
            " cartpole_hyperparameters,\n"
            " eval_env,\n"
            " video_fps=30\n"
            " )\n"
        ),
    )
    set_cell(
        nb,
        83,
        source(
            f'repo_id = "{USERNAME}/Reinforce-Pixelcopter-PLE-v0"\n'
            "push_to_hub(repo_id,\n"
            " pixelcopter_policy,\n"
            " pixelcopter_hyperparameters,\n"
            " eval_env,\n"
            " video_fps=30\n"
            " )\n"
        ),
    )
def patch_unit5(nb: dict) -> None:
    """Patch the Unit 5 (ML-Agents) notebook.

    Points the SnowballTarget push command at USERNAME, then disables cells
    44 and 61 (placeholder push commands). The replacement runs first, so it
    cannot affect those two cells, which are overwritten afterwards.
    Cell numbers are 1-based positions for one revision of the notebook.
    NOTE(review): re-verify them if the upstream notebook changes.
    """
    replace_in_all_cells(
        nb,
        {
            '--repo-id="ThomasSimonini/ppo-SnowballTarget"': f'--repo-id="{USERNAME}/ppo-SnowballTarget"',
        },
    )
    set_cell(nb, 44, noop("placeholder push command; fill only after local Unity training exists."))
    set_cell(nb, 61, noop("placeholder push command; fill only after local Unity training exists."))
def patch_unit6(nb: dict) -> None:
    """Patch the Unit 6 (A2C) notebook for a local Windows run.

    Disables Linux-only display cells and placeholder cells, puts an HF
    identity check in cell 44, and points the A2C repo id at USERNAME.
    Cell numbers are 1-based positions for one revision of the notebook.
    NOTE(review): re-verify them if the upstream notebook changes.
    """
    set_cell(nb, 13, noop("Linux OpenGL/Xvfb setup is not needed in this Windows local run."))
    set_cell(nb, 14, noop("pyvirtualdisplay/Xvfb is Linux-only."))
    for index in [29, 33]:
        set_cell(nb, index, noop("placeholder cell; solution cell below is used."))
    set_cell(nb, 44, hf_whoami())
    # The notebook code is itself an f-string, so {env_id} must survive
    # literally: it is doubled here to escape it in this f-string.
    replace_in_all_cells(nb, {'repo_id=f"ThomasSimonini/a2c-{env_id}"': f'repo_id=f"{USERNAME}/a2c-{{env_id}}"'})
def patch_unit8_part1(nb: dict) -> None:
    """Patch the Unit 8 part 1 (CleanRL PPO) notebook for a local Windows run.

    Disables the Linux setup cell, puts an HF identity check in cell 34, and
    points the push repo id and the script's default repo id at USERNAME.
    Cell numbers are 1-based positions for one revision of the notebook.
    NOTE(review): re-verify them if the upstream notebook changes.
    """
    set_cell(nb, 14, noop("Linux OpenGL/Xvfb/swig setup is not needed in this Windows local run."))
    set_cell(nb, 34, hf_whoami())
    replace_in_all_cells(
        nb,
        {
            '--repo-id="YOUR_REPO_ID"': f'--repo-id="{USERNAME}/ppo-LunarLander-v2-cleanrl"',
            'default="ThomasSimonini/ppo-CartPole-v1"': f'default="{USERNAME}/ppo-CartPole-v1"',
        },
    )
def patch_unit8_part2(nb: dict) -> None:
    """Patch the Unit 8 part 2 (ViZDoom) notebook.

    Disables the Linux-only ViZDoom dependency cell, puts an HF identity
    check in cell 28, and sets ``hf_username`` to USERNAME.
    Cell numbers are 1-based positions for one revision of the notebook.
    NOTE(review): re-verify them if the upstream notebook changes.
    """
    set_cell(nb, 11, noop("ViZDoom apt dependencies are Linux-only and cannot run in this Windows notebook."))
    set_cell(nb, 28, hf_whoami())
    replace_in_all_cells(nb, {'hf_username = "ThomasSimonini"': f'hf_username = "{USERNAME}"'})
def patch_bonus(nb: dict) -> None:
    """Patch the bonus Unit 1 (Huggy) notebook: point its push repo id at USERNAME."""
    replace_in_all_cells(nb, {'--repo-id="ThomasSimonini/ppo-Huggy"': f'--repo-id="{USERNAME}/ppo-Huggy"'})
# Maps each notebook filename under NOTEBOOK_DIR to the function that patches
# it. Filenames follow the pattern "notebooks__<folder>__<stem>.ipynb"; entry
# order is preserved, and main() processes notebooks in this order.
# The bonus notebook is listed under two spellings (hyphen and underscore).
PATCHERS = {
    f"notebooks__{folder}__{stem}.ipynb": patcher
    for folder, stem, patcher in (
        ("bonus-unit1", "bonus-unit1", patch_bonus),
        ("bonus-unit1", "bonus_unit1", patch_bonus),
        ("unit1", "unit1", patch_unit1),
        ("unit2", "unit2", patch_unit2),
        ("unit3", "unit3", patch_unit3),
        ("unit4", "unit4", patch_unit4),
        ("unit5", "unit5", patch_unit5),
        ("unit6", "unit6", patch_unit6),
        ("unit8", "unit8_part1", patch_unit8_part1),
        ("unit8", "unit8_part2", patch_unit8_part2),
    )
}
def main() -> None:
    """Patch every notebook listed in PATCHERS in place.

    Each file is read from NOTEBOOK_DIR, passed to its patcher and written
    back with write_nb. Files that do not exist are reported and skipped
    instead of aborting the whole run. PATCHERS lists the bonus notebook
    under two spellings (``bonus-unit1`` / ``bonus_unit1``), and presumably
    only one of them is present on disk.
    """
    for filename, patcher in PATCHERS.items():
        path = NOTEBOOK_DIR / filename
        if not path.is_file():
            print(f"skipped {path} (not found)")
            continue
        nb = read_nb(path)
        patcher(nb)
        write_nb(path, nb)
        print(f"patched {path}")


if __name__ == "__main__":
    main()
|