Zeke Long commited on
Commit
0d4e1b9
·
1 Parent(s): ecec64b

Updating to adress the bottleneck in the poll for next roun

Browse files
README.md CHANGED
@@ -55,7 +55,7 @@ AGGREGATOR_URL = "https://your-username-your-space.hf.space"
55
  NODE_SECRET = "my-super-secret-123" # must match the Space secret
56
  ```
57
 
58
- Use the Space’s **direct app host** as `AGGREGATOR_URL`: **`https://<owner>-<space-name>.hf.space`** (open your Space → use the **App** tab URL, or copy from the Space card). It is **not** `huggingface.co/spaces/user/repo`, **not** `https://<owner>/<space>.hf.space` (a slash makes the client try to resolve `<owner>` as the hostname and DNS fails), and **not** a model repo id. The client accepts the host **with or without** `https://` (missing scheme defaults to `https://`) and **rewrites** the common `owner/space.hf.space` typo into `owner-space.hf.space`. API paths are `/submit`, `/status`, `/health`, `/reset`.
59
 
60
  Then call the aggregator client:
61
 
@@ -80,17 +80,9 @@ When the last node completes a round, the client may receive **`merge_failed`**
80
 
81
  Nodes do **not** need `MODEL_REPO_ID` — only the aggregator uses it to download/upload adapter weights.
82
 
83
- #### 3. Local testing
84
 
85
- Export the variables before running the server:
86
-
87
- ```bash
88
- export HF_TOKEN="hf_..."
89
- export MODEL_REPO_ID="your-username/your-model-repo"
90
- export NODE_SECRET="local_test_secret"
91
- ```
92
-
93
- Without these, the app defaults to empty strings (merge is skipped) and `"local_test_secret"` for the node secret.
94
 
95
  ## Operator notes
96
 
@@ -98,6 +90,7 @@ Without these, the app defaults to empty strings (merge is skipped) and `"local_
98
  - **Secrets:** Never commit `HF_TOKEN`, `NODE_SECRET`, or tokens in git remotes. Use Space **Repository secrets** and a local env or credential helper.
99
  - **Reset:** `POST /reset` with JSON `{"secret_key": "<ADMIN_SECRET or NODE_SECRET>"}` clears round state to 1. If `ADMIN_SECRET` is set on the Space, use that; otherwise use `NODE_SECRET`.
100
  - **Protected status:** When `STATUS_READ_SECRET` is set, pass the same value as header `X-Status-Secret` (see `aggregator_client.check_aggregator` / `poll_for_next_round` argument `status_secret`, and notebook `CONFIG["status_read_secret"]`).
 
101
  - **Rate limits:** `POST /submit` is limited per client IP (first `X-Forwarded-For` hop when present). Override with **`RATE_LIMIT_SUBMIT_MAX`** (default 120) and **`RATE_LIMIT_SUBMIT_WINDOW_SEC`** (default 60).
102
  - **Logs:** Set **`LOG_LEVEL`** (e.g. `DEBUG`, `INFO`, `WARNING`) for the `peft.aggregator` logger on stdout.
103
  - **Public Space:** `GET /status` is world-readable on a public Space; use a private Space if round visibility matters.
@@ -109,14 +102,118 @@ Without these, the app defaults to empty strings (merge is skipped) and `"local_
109
  - **Nodes:** 3 x Google Colab free T4 GPU
110
  - **Aggregation:** FedAvg over adapter states
111
 
112
- ## Local Testing
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
113
 
114
  ```bash
 
 
 
 
 
115
  pip install -r requirements.txt
116
  pip install torch --index-url https://download.pytorch.org/whl/cpu
 
117
  uvicorn app:app --host 0.0.0.0 --port 7860 --reload
118
  ```
119
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
  **Note:** `requirements.txt` caps **FastAPI** and **Starlette** below versions that ship **Starlette 1.x**. Gradio **4.44.x** is incompatible with that stack (the Space would return **500** on `GET /` with Jinja `unhashable type: 'dict'`). Upgrade Gradio before raising those caps.
121
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
122
  Deadline: April 20, 2026 · Lead target April 13
 
55
  NODE_SECRET = "my-super-secret-123" # must match the Space secret
56
  ```
57
 
58
+ Use the Space’s **direct app host** as `AGGREGATOR_URL`: **`https://<owner>-<space-name>.hf.space`** (open your Space → use the **App** tab URL, or copy from the Space card). It is **not** `huggingface.co/spaces/user/repo`, **not** `https://<owner>/<space>.hf.space` (a slash makes the client try to resolve `<owner>` as the hostname and DNS fails), and **not** a model repo id. The client accepts the host **with or without** `https://` (missing scheme defaults to `https://`), fixes **`https:host`** typos (missing slashes after the scheme), and **rewrites** the common `owner/space.hf.space` mistake into `owner-space.hf.space`. API paths are `/submit`, `/status`, `/health`, `/reset`.
59
 
60
  Then call the aggregator client:
61
 
 
80
 
81
  Nodes do **not** need `MODEL_REPO_ID` — only the aggregator uses it to download/upload adapter weights.
82
 
83
+ #### 3. Local or Space runtime env
84
 
85
+ For **local** runs, export `HF_TOKEN`, `MODEL_REPO_ID`, and `NODE_SECRET` (or rely on defaults described in **Testing** below). On the **Hugging Face Space**, set the same keys as **Repository secrets**.
 
 
 
 
 
 
 
 
86
 
87
  ## Operator notes
88
 
 
90
  - **Secrets:** Never commit `HF_TOKEN`, `NODE_SECRET`, or tokens in git remotes. Use Space **Repository secrets** and a local env or credential helper.
91
  - **Reset:** `POST /reset` with JSON `{"secret_key": "<ADMIN_SECRET or NODE_SECRET>"}` clears round state to 1. If `ADMIN_SECRET` is set on the Space, use that; otherwise use `NODE_SECRET`.
92
  - **Protected status:** When `STATUS_READ_SECRET` is set, pass the same value as header `X-Status-Secret` (see `aggregator_client.check_aggregator` / `poll_for_next_round` argument `status_secret`, and notebook `CONFIG["status_read_secret"]`).
93
+ - **401 Unauthorized:** (1) **`POST /submit`** — JSON `secret_key` must match the Space **`NODE_SECRET`** (same string in every Colab `CONFIG["node_secret"]`). (2) **`GET /status`** — if the Space defines **`STATUS_READ_SECRET`**, clients must send that value (notebook `CONFIG["status_read_secret"]`, or clear `STATUS_READ_SECRET` on the Space if you do not need it). **`GET /health`** has no secret. A **private** Hugging Face Space can also return 401 at the edge before your app runs — open the Space in the browser while logged in, or check Space visibility settings.
94
  - **Rate limits:** `POST /submit` is limited per client IP (first `X-Forwarded-For` hop when present). Override with **`RATE_LIMIT_SUBMIT_MAX`** (default 120) and **`RATE_LIMIT_SUBMIT_WINDOW_SEC`** (default 60).
95
  - **Logs:** Set **`LOG_LEVEL`** (e.g. `DEBUG`, `INFO`, `WARNING`) for the `peft.aggregator` logger on stdout.
96
  - **Public Space:** `GET /status` is world-readable on a public Space; use a private Space if round visibility matters.
 
102
  - **Nodes:** 3 x Google Colab free T4 GPU
103
  - **Aggregation:** FedAvg over adapter states
104
 
105
+ ## Testing
106
+
107
+ ### Automated tests (CI / laptop)
108
+
109
+ From the repo root, use a virtualenv with **Python 3.10+** (3.12 is fine). Install dependencies and run the suite:
110
+
111
+ ```bash
112
+ python -m venv .venv
113
+ source .venv/bin/activate # Windows: .venv\Scripts\activate
114
+ pip install -r requirements.txt
115
+ pip install pytest httpx "gradio==4.44.1"
116
+ pip install torch --index-url https://download.pytorch.org/whl/cpu
117
+ pytest tests/ -q
118
+ ```
119
+
120
+ Tests exercise **FastAPI** routes (`/health`, `/status`, `/submit`, `/reset`) via `TestClient`. They set **`HF_TOKEN`** and **`MODEL_REPO_ID`** empty so **FedAvg is skipped** and no Hub network calls are required. **`aggregator_client`** URL normalization is covered in `tests/test_aggregator_client.py`.
121
+
122
+ ### Run locally (same stack as the Space)
123
 
124
  ```bash
125
+ export NODE_SECRET="local_test_secret"
126
+ # Optional: real FedAvg on Hub (otherwise merge is skipped when the third node submits)
127
+ export HF_TOKEN=""
128
+ export MODEL_REPO_ID=""
129
+
130
  pip install -r requirements.txt
131
  pip install torch --index-url https://download.pytorch.org/whl/cpu
132
+ pip install "gradio==4.44.1"
133
  uvicorn app:app --host 0.0.0.0 --port 7860 --reload
134
  ```
135
 
136
+ Open **`http://127.0.0.1:7860/`** for the Gradio dashboard. Smoke-test JSON endpoints:
137
+
138
+ ```bash
139
+ BASE="http://127.0.0.1:7860"
140
+ curl -sS "$BASE/health"
141
+ curl -sS "$BASE/status"
142
+ curl -sS -X POST "$BASE/submit" -H "Content-Type: application/json" \
143
+ -d "{\"node_id\":\"node_a\",\"secret_key\":\"local_test_secret\"}"
144
+ ```
145
+
146
+ If the Space uses **`STATUS_READ_SECRET`**, mirror that locally:
147
+
148
+ ```bash
149
+ curl -sS "$BASE/status" -H "X-Status-Secret: your-secret"
150
+ ```
151
+
152
  **Note:** `requirements.txt` caps **FastAPI** and **Starlette** below versions that ship **Starlette 1.x**. Gradio **4.44.x** is incompatible with that stack (the Space would return **500** on `GET /` with Jinja `unhashable type: 'dict'`). Upgrade Gradio before raising those caps.
153
 
154
+ ### Docker (parity with the HF Space)
155
+
156
+ ```bash
157
+ docker build -t peft-aggregator .
158
+ docker run --rm -p 7860:7860 \
159
+ -e NODE_SECRET="local_test_secret" \
160
+ -e HF_TOKEN="" \
161
+ -e MODEL_REPO_ID="" \
162
+ peft-aggregator
163
+ ```
164
+
165
+ Then use the same **`curl`** examples with **`BASE=http://127.0.0.1:7860`**.
166
+
167
+ ### Test the deployed Hugging Face Space
168
+
169
+ Use your Space **App** URL (hyphenated **`.hf.space`** host). Set **`BASE`** and **`SECRET`** to match **Repository secrets** on the Space.
170
+
171
+ ```bash
172
+ BASE="https://YOUR_OWNER-YOUR_SPACE_NAME.hf.space"
173
+ SECRET="your-node-secret-from-space-settings"
174
+
175
+ curl -sS "$BASE/health"
176
+ ```
177
+
178
+ **`GET /status`** — if **`STATUS_READ_SECRET`** is set on the Space, add the header; otherwise a public Space returns JSON without auth:
179
+
180
+ ```bash
181
+ curl -sS "$BASE/status"
182
+ # or: curl -sS "$BASE/status" -H "X-Status-Secret: $STATUS_READ_SECRET"
183
+ ```
184
+
185
+ **`POST /submit`** — any node can submit independently; the first responses are **`"status":"submitted"`** with **`remaining`** until all three IDs have submitted for the current round. Omit **`round_num`** for a quick smoke test, or set **`round_num`** to the value returned by **`/status`** as **`current_round`** (mismatch → **409**).
186
+
187
+ ```bash
188
+ for id in node_a node_b node_c; do
189
+ curl -sS -X POST "$BASE/submit" -H "Content-Type: application/json" \
190
+ -d "{\"node_id\":\"$id\",\"secret_key\":\"$SECRET\",\"avg_loss\":1.0,\"steps_completed\":10}"
191
+ echo
192
+ done
193
+ ```
194
+
195
+ After the third submit you should see **`"status":"round_complete"`** (with **`HF_TOKEN`/`MODEL_REPO_ID`** configured and valid adapter files on the Hub) or a message that merge was **skipped** / **`merge_failed`** if Hub setup is incomplete — both outcomes confirm the Space is running the aggregation logic.
196
+
197
+ **`POST /reset`** — use **`ADMIN_SECRET`** if the Space defines it, else **`NODE_SECRET`**:
198
+
199
+ ```bash
200
+ curl -sS -X POST "$BASE/reset" -H "Content-Type: application/json" \
201
+ -d "{\"secret_key\":\"$SECRET\"}"
202
+ ```
203
+
204
+ **Python** — the same checks with **`aggregator_client`**:
205
+
206
+ ```python
207
+ from aggregator_client import check_aggregator, health_aggregator, notify_aggregator
208
+
209
+ BASE = "https://YOUR_OWNER-YOUR_SPACE_NAME.hf.space"
210
+ SECRET = "your-node-secret"
211
+
212
+ health_aggregator(BASE)
213
+ check_aggregator(BASE, status_secret=None) # or status_secret="..." if configured
214
+ notify_aggregator(BASE, "node_a", SECRET, round_num=1)
215
+ ```
216
+
217
+ **Private Space:** you may need to be logged into Hugging Face in a browser to open **`/`**; API calls from Colab or scripts still use **`NODE_SECRET`** on **`/submit`** and **`X-Status-Secret`** on **`/status`** when applicable — they do not use your HF login cookie.
218
+
219
  Deadline: April 20, 2026 · Lead target April 13
aggregator_client.py CHANGED
@@ -15,20 +15,63 @@ from urllib.parse import urlsplit, urlunsplit
15
  import requests
16
 
17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  def _normalize_aggregator_base_url(url: str) -> str:
19
  """Strip whitespace/slashes and ensure an HTTP(S) scheme for requests.
20
 
21
  Colab configs often omit ``https://``, which causes requests to raise
22
  ``MissingSchema``.
23
 
24
- Also fixes a common mistake: pasting ``https://OWNER/SPACE.hf.space`` (slash)
25
- instead of the real host ``https://OWNER-SPACE.hf.space``. Otherwise the HTTP
26
- client treats ``OWNER`` as the hostname and fails DNS.
27
  """
28
  base = url.strip().rstrip("/")
29
  if not base:
30
  raise ValueError("aggregator_url is empty")
31
- if "://" not in base:
 
 
 
 
 
 
 
 
 
32
  base = "https://" + base.lstrip("/")
33
 
34
  parts = urlsplit(base)
@@ -109,7 +152,7 @@ def notify_aggregator(
109
  "(typically https://YOUR_SPACE_NAME.hf.space), not the huggingface.co/spaces "
110
  "HTML page URL. Path must be exactly /submit."
111
  )
112
- response.raise_for_status()
113
  data = response.json()
114
  if data.get("status") == "merge_failed":
115
  raise AggregatorMergeFailed(data)
@@ -119,7 +162,7 @@ def notify_aggregator(
119
  def poll_for_next_round(
120
  aggregator_url: str,
121
  current_round: int,
122
- poll_interval: int = 30,
123
  max_wait: int = 1800,
124
  status_secret: str | None = None,
125
  ) -> dict:
@@ -129,7 +172,7 @@ def poll_for_next_round(
129
  aggregator_url: Base URL of the aggregator Space (with or without
130
  ``https://``; missing scheme defaults to ``https://``).
131
  current_round: The round we just finished.
132
- poll_interval: Seconds between status checks.
133
  max_wait: Maximum seconds to wait before raising TimeoutError.
134
  status_secret: If the Space sets STATUS_READ_SECRET, pass it here
135
  (sent as X-Status-Secret on GET /status).
@@ -139,6 +182,7 @@ def poll_for_next_round(
139
 
140
  Raises:
141
  TimeoutError: If max_wait is exceeded.
 
142
  """
143
  base = _normalize_aggregator_base_url(aggregator_url)
144
  url = f"{base}/status"
@@ -148,19 +192,32 @@ def poll_for_next_round(
148
  while elapsed < max_wait:
149
  try:
150
  resp = requests.get(url, timeout=15, headers=headers)
151
- resp.raise_for_status()
152
  status = resp.json()
 
 
 
 
 
 
153
  agg_round = status.get("current_round", 0)
154
  if agg_round > current_round:
155
  print(f"[poll] Aggregator advanced to round {agg_round}")
156
  return status
 
 
 
 
 
 
 
 
 
 
 
157
  except requests.RequestException as e:
158
  print(f"[poll] Request error: {e}")
159
 
160
- print(
161
- f"[poll] Waiting for round {current_round + 1} "
162
- f"({elapsed}/{max_wait}s elapsed)..."
163
- )
164
  time.sleep(poll_interval)
165
  elapsed += poll_interval
166
 
@@ -183,7 +240,7 @@ def check_aggregator(
183
  timeout=timeout,
184
  headers=_status_headers(status_secret),
185
  )
186
- response.raise_for_status()
187
  return response.json()
188
 
189
 
@@ -199,7 +256,7 @@ def reset_aggregator(
199
  json={"secret_key": secret_key},
200
  timeout=timeout,
201
  )
202
- response.raise_for_status()
203
  return response.json()
204
 
205
 
@@ -207,5 +264,5 @@ def health_aggregator(aggregator_url: str, timeout: int = 10) -> dict:
207
  """Liveness probe via GET /health; does not depend on training state."""
208
  url = f"{_normalize_aggregator_base_url(aggregator_url)}/health"
209
  response = requests.get(url, timeout=timeout)
210
- response.raise_for_status()
211
  return response.json()
 
15
  import requests
16
 
17
 
18
+ def _detail_from_response(response: requests.Response) -> str:
19
+ try:
20
+ data = response.json()
21
+ if isinstance(data, dict):
22
+ d = data.get("detail")
23
+ if d is not None:
24
+ return str(d) if not isinstance(d, list) else str(d)
25
+ except Exception:
26
+ pass
27
+ text = (response.text or "").strip()
28
+ if text:
29
+ return text[:800]
30
+ return response.reason or ""
31
+
32
+
33
+ def _raise_for_aggregator_response(response: requests.Response, *, what: str) -> None:
34
+ """Raise HTTPError with server ``detail`` and hints for common operator mistakes."""
35
+ if response.status_code < 400:
36
+ return
37
+ detail = _detail_from_response(response)
38
+ msg = f"{what}: HTTP {response.status_code}"
39
+ if detail:
40
+ msg += f" — {detail}"
41
+ if response.status_code == 401:
42
+ msg += (
43
+ ". Hint: POST /submit needs JSON ``secret_key`` equal to the Space "
44
+ "``NODE_SECRET``. GET /status needs header ``X-Status-Secret`` when the "
45
+ "Space sets ``STATUS_READ_SECRET`` — pass ``status_secret=...`` from "
46
+ "``poll_for_next_round`` / ``check_aggregator``, or set "
47
+ "``CONFIG['status_read_secret']`` in the notebook to match the Space."
48
+ )
49
+ raise requests.HTTPError(msg, response=response)
50
+
51
+
52
  def _normalize_aggregator_base_url(url: str) -> str:
53
  """Strip whitespace/slashes and ensure an HTTP(S) scheme for requests.
54
 
55
  Colab configs often omit ``https://``, which causes requests to raise
56
  ``MissingSchema``.
57
 
58
+ Also fixes: ``https:host`` (missing ``//`` after the scheme); pasting
59
+ ``https://OWNER/SPACE.hf.space`` (slash) instead of ``https://OWNER-SPACE.hf.space``
60
+ (otherwise the client treats ``OWNER`` as the hostname and DNS fails).
61
  """
62
  base = url.strip().rstrip("/")
63
  if not base:
64
  raise ValueError("aggregator_url is empty")
65
+
66
+ low = base.lower()
67
+ # ``https:host`` / ``http:host`` (missing ``//``) does not contain ``://``;
68
+ # blindly prefixing ``https://`` would produce ``https://https:host`` and
69
+ # break urllib3 (InvalidURL).
70
+ if low.startswith("https:") and not low.startswith("https://"):
71
+ base = "https://" + base[6:].lstrip("/")
72
+ elif low.startswith("http:") and not low.startswith("http://"):
73
+ base = "http://" + base[5:].lstrip("/")
74
+ elif "://" not in base:
75
  base = "https://" + base.lstrip("/")
76
 
77
  parts = urlsplit(base)
 
152
  "(typically https://YOUR_SPACE_NAME.hf.space), not the huggingface.co/spaces "
153
  "HTML page URL. Path must be exactly /submit."
154
  )
155
+ _raise_for_aggregator_response(response, what="POST /submit")
156
  data = response.json()
157
  if data.get("status") == "merge_failed":
158
  raise AggregatorMergeFailed(data)
 
162
  def poll_for_next_round(
163
  aggregator_url: str,
164
  current_round: int,
165
+ poll_interval: int = 10,
166
  max_wait: int = 1800,
167
  status_secret: str | None = None,
168
  ) -> dict:
 
172
  aggregator_url: Base URL of the aggregator Space (with or without
173
  ``https://``; missing scheme defaults to ``https://``).
174
  current_round: The round we just finished.
175
+ poll_interval: Seconds between status checks (default 10).
176
  max_wait: Maximum seconds to wait before raising TimeoutError.
177
  status_secret: If the Space sets STATUS_READ_SECRET, pass it here
178
  (sent as X-Status-Secret on GET /status).
 
182
 
183
  Raises:
184
  TimeoutError: If max_wait is exceeded.
185
+ AggregatorMergeFailed: If the aggregator reports a merge error.
186
  """
187
  base = _normalize_aggregator_base_url(aggregator_url)
188
  url = f"{base}/status"
 
192
  while elapsed < max_wait:
193
  try:
194
  resp = requests.get(url, timeout=15, headers=headers)
195
+ _raise_for_aggregator_response(resp, what="GET /status")
196
  status = resp.json()
197
+
198
+ # Check for merge failure
199
+ merge_error = status.get("merge_error")
200
+ if merge_error:
201
+ raise AggregatorMergeFailed({"merge_result": merge_error})
202
+
203
  agg_round = status.get("current_round", 0)
204
  if agg_round > current_round:
205
  print(f"[poll] Aggregator advanced to round {agg_round}")
206
  return status
207
+
208
+ if status.get("merging"):
209
+ print(
210
+ f"[poll] FedAvg merge in progress "
211
+ f"({elapsed}/{max_wait}s elapsed)..."
212
+ )
213
+ else:
214
+ print(
215
+ f"[poll] Waiting for round {current_round + 1} "
216
+ f"({elapsed}/{max_wait}s elapsed)..."
217
+ )
218
  except requests.RequestException as e:
219
  print(f"[poll] Request error: {e}")
220
 
 
 
 
 
221
  time.sleep(poll_interval)
222
  elapsed += poll_interval
223
 
 
240
  timeout=timeout,
241
  headers=_status_headers(status_secret),
242
  )
243
+ _raise_for_aggregator_response(response, what="GET /status")
244
  return response.json()
245
 
246
 
 
256
  json={"secret_key": secret_key},
257
  timeout=timeout,
258
  )
259
+ _raise_for_aggregator_response(response, what="POST /reset")
260
  return response.json()
261
 
262
 
 
264
  """Liveness probe via GET /health; does not depend on training state."""
265
  url = f"{_normalize_aggregator_base_url(aggregator_url)}/health"
266
  response = requests.get(url, timeout=timeout)
267
+ _raise_for_aggregator_response(response, what="GET /health")
268
  return response.json()
app.py CHANGED
@@ -15,6 +15,7 @@ import logging
15
  import math
16
  import os
17
  import json
 
18
  import time
19
  import datetime
20
 
@@ -84,8 +85,12 @@ state = {
84
  "last_update": None,
85
  "node_metrics": {}, # {node_id: {loss, step, timestamp, ...}}
86
  "activity_log": [], # recent events for the activity feed
 
 
87
  }
88
 
 
 
89
  # Per-client timestamps (monotonic) for POST /submit rate limiting
90
  _submit_rate_buckets: dict[str, list[float]] = {}
91
 
@@ -212,6 +217,45 @@ def fedavg_merge() -> tuple[str, bool]:
212
  )
213
 
214
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
215
  # ---------------------------------------------------------------------------
216
  # FastAPI endpoints
217
  # ---------------------------------------------------------------------------
@@ -294,6 +338,8 @@ def get_status(request: Request):
294
  if n not in state["submitted_nodes"]
295
  ],
296
  "last_update": state["last_update"],
 
 
297
  }
298
 
299
 
@@ -326,106 +372,90 @@ def submit_node(req: SubmitRequest):
326
  ),
327
  )
328
 
329
- if req.node_id in state["submitted_nodes"]:
330
- return {
331
- "status": "already_submitted",
332
- "current_round": state["current_round"],
333
- "submitted_nodes": state["submitted_nodes"],
334
- }
 
335
 
336
- state["submitted_nodes"].append(req.node_id)
337
- state["last_update"] = _timestamp()
338
- agg_log.info(
339
- "submit accepted node_id=%s round=%s progress=%s/%s",
340
- req.node_id,
341
- state["current_round"],
342
- len(state["submitted_nodes"]),
343
- len(CONFIG["expected_nodes"]),
344
- )
345
 
346
- # Store node metrics
347
- state["node_metrics"][req.node_id] = {
348
- "avg_loss": req.avg_loss,
349
- "steps_completed": req.steps_completed,
350
- "round": state["current_round"],
351
- "submitted_at": _timestamp(),
352
- }
 
 
353
 
354
- # Log activity
355
- loss_suffix = (
356
- f" (loss: {req.avg_loss:.4f})" if req.avg_loss is not None else ""
357
- )
358
- _log_activity(
359
- f"{req.node_id} submitted round {state['current_round']}{loss_suffix}"
360
- )
361
 
362
- # Check if all nodes have submitted
363
- if set(state["submitted_nodes"]) == set(CONFIG["expected_nodes"]):
 
 
364
  _log_activity(
365
- f"All nodes submitted — starting FedAvg for round {state['current_round']}"
366
  )
367
- merge_result, merge_ok = fedavg_merge()
368
 
369
- # Capture per-node losses and steps in history
370
- round_metrics = {
371
- nid: state["node_metrics"].get(nid, {}).get("avg_loss")
372
- for nid in CONFIG["expected_nodes"]
373
- }
374
- round_steps = {
375
- nid: state["node_metrics"].get(nid, {}).get("steps_completed")
376
- for nid in CONFIG["expected_nodes"]
377
- }
378
-
379
- if merge_ok:
380
- state["history"].append({
381
- "round": state["current_round"],
382
- "completed_at": _timestamp(),
383
- "merge_result": merge_result,
384
- "node_losses": round_metrics,
385
- "node_steps": round_steps,
386
- })
387
 
388
- _log_activity(f"Round {state['current_round']} FedAvg complete")
389
- agg_log.info(
390
- "fedavg_complete round=%s new_round=%s",
391
- state["current_round"],
392
- state["current_round"] + 1,
393
  )
 
 
 
 
 
 
 
 
 
 
394
 
395
- state["current_round"] += 1
396
- state["submitted_nodes"] = []
397
- state["node_metrics"] = {}
 
 
 
398
 
399
  return {
400
- "status": "round_complete",
401
- "merge_result": merge_result,
402
- "new_round": state["current_round"],
403
  }
404
 
405
- tail = merge_result if len(merge_result) <= 400 else merge_result[:400] + "..."
406
- agg_log.warning(
407
- "fedavg_failed round=%s detail=%s",
408
- state["current_round"],
409
- tail,
410
- )
411
- _log_activity(f"FedAvg failed (round {state['current_round']}): {merge_result}")
412
- state["submitted_nodes"] = []
413
  return {
414
- "status": "merge_failed",
415
- "merge_result": merge_result,
416
  "current_round": state["current_round"],
 
 
 
 
 
417
  }
418
 
419
- return {
420
- "status": "submitted",
421
- "current_round": state["current_round"],
422
- "submitted_nodes": state["submitted_nodes"],
423
- "remaining": [
424
- n for n in CONFIG["expected_nodes"]
425
- if n not in state["submitted_nodes"]
426
- ],
427
- }
428
-
429
 
430
  @app.post("/reset")
431
  def reset_state(req: ResetRequest):
@@ -440,6 +470,8 @@ def reset_state(req: ResetRequest):
440
  state["last_update"] = _timestamp()
441
  state["node_metrics"] = {}
442
  state["activity_log"] = []
 
 
443
 
444
  _log_activity("State reset to round 1")
445
 
 
15
  import math
16
  import os
17
  import json
18
+ import threading
19
  import time
20
  import datetime
21
 
 
85
  "last_update": None,
86
  "node_metrics": {}, # {node_id: {loss, step, timestamp, ...}}
87
  "activity_log": [], # recent events for the activity feed
88
+ "merging": False, # True while FedAvg is running in background
89
+ "merge_error": None, # set if background merge failed
90
  }
91
 
92
+ _state_lock = threading.Lock()
93
+
94
  # Per-client timestamps (monotonic) for POST /submit rate limiting
95
  _submit_rate_buckets: dict[str, list[float]] = {}
96
 
 
217
  )
218
 
219
 
220
+ def _background_merge(
221
+ merge_round: int,
222
+ round_metrics: dict,
223
+ round_steps: dict,
224
+ ) -> None:
225
+ """Run FedAvg in a background thread, then advance the round or record failure."""
226
+ try:
227
+ merge_result, merge_ok = fedavg_merge()
228
+ except Exception as exc:
229
+ merge_result, merge_ok = f"FedAvg exception: {exc}", False
230
+
231
+ with _state_lock:
232
+ if merge_ok:
233
+ state["history"].append({
234
+ "round": merge_round,
235
+ "completed_at": _timestamp(),
236
+ "merge_result": merge_result,
237
+ "node_losses": round_metrics,
238
+ "node_steps": round_steps,
239
+ })
240
+ _log_activity(f"Round {merge_round} FedAvg complete")
241
+ agg_log.info(
242
+ "fedavg_complete round=%s new_round=%s",
243
+ merge_round,
244
+ merge_round + 1,
245
+ )
246
+ state["current_round"] = merge_round + 1
247
+ state["submitted_nodes"] = []
248
+ state["node_metrics"] = {}
249
+ else:
250
+ tail = merge_result if len(merge_result) <= 400 else merge_result[:400] + "..."
251
+ agg_log.warning("fedavg_failed round=%s detail=%s", merge_round, tail)
252
+ _log_activity(f"FedAvg failed (round {merge_round}): {merge_result}")
253
+ state["merge_error"] = merge_result
254
+ state["submitted_nodes"] = []
255
+
256
+ state["merging"] = False
257
+
258
+
259
  # ---------------------------------------------------------------------------
260
  # FastAPI endpoints
261
  # ---------------------------------------------------------------------------
 
338
  if n not in state["submitted_nodes"]
339
  ],
340
  "last_update": state["last_update"],
341
+ "merging": state["merging"],
342
+ "merge_error": state["merge_error"],
343
  }
344
 
345
 
 
372
  ),
373
  )
374
 
375
+ with _state_lock:
376
+ if req.node_id in state["submitted_nodes"]:
377
+ return {
378
+ "status": "already_submitted",
379
+ "current_round": state["current_round"],
380
+ "submitted_nodes": state["submitted_nodes"],
381
+ }
382
 
383
+ if state["merging"]:
384
+ return {
385
+ "status": "merging",
386
+ "current_round": state["current_round"],
387
+ "submitted_nodes": state["submitted_nodes"],
388
+ }
 
 
 
389
 
390
+ state["submitted_nodes"].append(req.node_id)
391
+ state["last_update"] = _timestamp()
392
+ agg_log.info(
393
+ "submit accepted node_id=%s round=%s progress=%s/%s",
394
+ req.node_id,
395
+ state["current_round"],
396
+ len(state["submitted_nodes"]),
397
+ len(CONFIG["expected_nodes"]),
398
+ )
399
 
400
+ # Store node metrics
401
+ state["node_metrics"][req.node_id] = {
402
+ "avg_loss": req.avg_loss,
403
+ "steps_completed": req.steps_completed,
404
+ "round": state["current_round"],
405
+ "submitted_at": _timestamp(),
406
+ }
407
 
408
+ # Log activity
409
+ loss_suffix = (
410
+ f" (loss: {req.avg_loss:.4f})" if req.avg_loss is not None else ""
411
+ )
412
  _log_activity(
413
+ f"{req.node_id} submitted round {state['current_round']}{loss_suffix}"
414
  )
 
415
 
416
+ # Check if all nodes have submitted
417
+ all_submitted = set(state["submitted_nodes"]) == set(CONFIG["expected_nodes"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
418
 
419
+ if all_submitted:
420
+ state["merging"] = True
421
+ state["merge_error"] = None
422
+ _log_activity(
423
+ f"All nodes submitted — starting FedAvg for round {state['current_round']}"
424
  )
425
+ # Capture metrics under lock before spawning thread
426
+ round_metrics = {
427
+ nid: state["node_metrics"].get(nid, {}).get("avg_loss")
428
+ for nid in CONFIG["expected_nodes"]
429
+ }
430
+ round_steps = {
431
+ nid: state["node_metrics"].get(nid, {}).get("steps_completed")
432
+ for nid in CONFIG["expected_nodes"]
433
+ }
434
+ merge_round = state["current_round"]
435
 
436
+ thread = threading.Thread(
437
+ target=_background_merge,
438
+ args=(merge_round, round_metrics, round_steps),
439
+ daemon=True,
440
+ )
441
+ thread.start()
442
 
443
  return {
444
+ "status": "merging",
445
+ "current_round": state["current_round"],
446
+ "submitted_nodes": state["submitted_nodes"],
447
  }
448
 
 
 
 
 
 
 
 
 
449
  return {
450
+ "status": "submitted",
 
451
  "current_round": state["current_round"],
452
+ "submitted_nodes": state["submitted_nodes"],
453
+ "remaining": [
454
+ n for n in CONFIG["expected_nodes"]
455
+ if n not in state["submitted_nodes"]
456
+ ],
457
  }
458
 
 
 
 
 
 
 
 
 
 
 
459
 
460
  @app.post("/reset")
461
  def reset_state(req: ResetRequest):
 
470
  state["last_update"] = _timestamp()
471
  state["node_metrics"] = {}
472
  state["activity_log"] = []
473
+ state["merging"] = False
474
+ state["merge_error"] = None
475
 
476
  _log_activity("State reset to round 1")
477
 
notebook/distributed_finetuning_peft_async.ipynb CHANGED
@@ -102,7 +102,8 @@
102
  " # Direct Space app host: https://YOURNAME-YOURSPACENAME.hf.space (from Space \"App\" URL).\n",
103
  " # Not huggingface.co/spaces/... — use one hostname with hyphens, e.g.\n",
104
  " # https://OWNER-SPACENAME.hf.space NOT https://OWNER/SPACENAME.hf.space (slash breaks DNS).\n",
105
- " # You may omit https://; aggregator_client fixes missing scheme and the OWNER/SPACE mistake.\n",
 
106
  " \"aggregator_url\": \"https://your-username-peft-aggregator.hf.space\",\n",
107
  " \"node_secret\": \"choose_any_shared_secret_string\",\n",
108
  " # Optional: if the Space sets STATUS_READ_SECRET, put the same value here (for polling /status).\n",
 
102
  " # Direct Space app host: https://YOURNAME-YOURSPACENAME.hf.space (from Space \"App\" URL).\n",
103
  " # Not huggingface.co/spaces/... — use one hostname with hyphens, e.g.\n",
104
  " # https://OWNER-SPACENAME.hf.space NOT https://OWNER/SPACENAME.hf.space (slash breaks DNS).\n",
105
+ " # Use https:// (two slashes), not https: (typo). You may omit the scheme entirely;\n",
106
+ " # aggregator_client fixes missing scheme, https:host, and OWNER/SPACE slash mistakes.\n",
107
  " \"aggregator_url\": \"https://your-username-peft-aggregator.hf.space\",\n",
108
  " \"node_secret\": \"choose_any_shared_secret_string\",\n",
109
  " # Optional: if the Space sets STATUS_READ_SECRET, put the same value here (for polling /status).\n",
tests/test_aggregator_client.py CHANGED
@@ -12,6 +12,20 @@ def test_normalize_aggregator_base_url_adds_https():
12
  assert ac._normalize_aggregator_base_url(" example.hf.space/ ") == "https://example.hf.space"
13
 
14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  def test_normalize_aggregator_base_url_preserves_explicit_scheme():
16
  assert ac._normalize_aggregator_base_url("http://127.0.0.1:7860") == "http://127.0.0.1:7860"
17
  assert ac._normalize_aggregator_base_url("https://x.hf.space/") == "https://x.hf.space"
 
12
  assert ac._normalize_aggregator_base_url(" example.hf.space/ ") == "https://example.hf.space"
13
 
14
 
15
+ def test_normalize_aggregator_base_url_fixes_https_missing_slashes():
16
+ """https:host (typo) must not become https://https:host."""
17
+ assert ac._normalize_aggregator_base_url(
18
+ "https:instruction-fine-tuning-budget.hf.space"
19
+ ) == ("https://instruction-fine-tuning-budget.hf.space")
20
+ assert ac._normalize_aggregator_base_url(
21
+ "HTTPS:instruction-fine-tuning-budget.hf.space/"
22
+ ) == ("https://instruction-fine-tuning-budget.hf.space")
23
+
24
+
25
+ def test_normalize_aggregator_base_url_fixes_http_missing_slashes():
26
+ assert ac._normalize_aggregator_base_url("http:127.0.0.1:7860") == "http://127.0.0.1:7860"
27
+
28
+
29
  def test_normalize_aggregator_base_url_preserves_explicit_scheme():
30
  assert ac._normalize_aggregator_base_url("http://127.0.0.1:7860") == "http://127.0.0.1:7860"
31
  assert ac._normalize_aggregator_base_url("https://x.hf.space/") == "https://x.hf.space"
tests/test_app.py CHANGED
@@ -6,6 +6,8 @@ Uses empty HF credentials so FedAvg skips Hub I/O (no network required).
6
 
7
  from __future__ import annotations
8
 
 
 
9
  import pytest
10
  from fastapi.testclient import TestClient
11
 
@@ -35,6 +37,8 @@ def client(monkeypatch):
35
  app_module.state["last_update"] = None
36
  app_module.state["node_metrics"] = {}
37
  app_module.state["activity_log"] = []
 
 
38
 
39
  with TestClient(app_module.app) as c:
40
  yield c, app_module
@@ -148,6 +152,14 @@ def test_submit_unknown_node(client):
148
  assert r.status_code == 400
149
 
150
 
 
 
 
 
 
 
 
 
151
  def test_full_round_merge_skipped_no_hf(client):
152
  tc, m = client
153
  secret = m.CONFIG["node_secret"]
@@ -168,9 +180,11 @@ def test_full_round_merge_skipped_no_hf(client):
168
  if nid != "node_c":
169
  assert body["status"] == "submitted"
170
  else:
171
- assert body["status"] == "round_complete"
172
- assert "Skipped merge" in body["merge_result"]
173
- assert body["new_round"] == 2
 
 
174
 
175
  st = tc.get("/status").json()
176
  assert st["current_round"] == 2
@@ -269,6 +283,8 @@ def test_gradio_markdown_helpers_no_crash(client):
269
  for nid in ("node_a", "node_b", "node_c"):
270
  tc.post("/submit", json={"node_id": nid, "secret_key": secret})
271
 
 
 
272
  m._loss_history_md()
273
  m._merged_adapters_md()
274
  m._activity_log_md()
 
6
 
7
  from __future__ import annotations
8
 
9
+ import time
10
+
11
  import pytest
12
  from fastapi.testclient import TestClient
13
 
 
37
  app_module.state["last_update"] = None
38
  app_module.state["node_metrics"] = {}
39
  app_module.state["activity_log"] = []
40
+ app_module.state["merging"] = False
41
+ app_module.state["merge_error"] = None
42
 
43
  with TestClient(app_module.app) as c:
44
  yield c, app_module
 
152
  assert r.status_code == 400
153
 
154
 
155
+ def _wait_for_merge(app_module, timeout=5):
156
+ """Wait until the background merge thread completes."""
157
+ deadline = time.monotonic() + timeout
158
+ while app_module.state["merging"] and time.monotonic() < deadline:
159
+ time.sleep(0.05)
160
+ assert not app_module.state["merging"], "merge did not complete in time"
161
+
162
+
163
  def test_full_round_merge_skipped_no_hf(client):
164
  tc, m = client
165
  secret = m.CONFIG["node_secret"]
 
180
  if nid != "node_c":
181
  assert body["status"] == "submitted"
182
  else:
183
+ # Last node triggers async merge — returns "merging"
184
+ assert body["status"] == "merging"
185
+
186
+ # Wait for the background merge thread to finish
187
+ _wait_for_merge(m)
188
 
189
  st = tc.get("/status").json()
190
  assert st["current_round"] == 2
 
283
  for nid in ("node_a", "node_b", "node_c"):
284
  tc.post("/submit", json={"node_id": nid, "secret_key": secret})
285
 
286
+ _wait_for_merge(m)
287
+
288
  m._loss_history_md()
289
  m._merged_adapters_md()
290
  m._activity_log_md()