renminwansui1976 committed on
Commit
c4ead6a
·
unverified ·
1 Parent(s): b2eb544

更新 main.rs

Browse files
Files changed (1) hide show
  1. src/main.rs +43 -106
src/main.rs CHANGED
@@ -14,9 +14,8 @@ use tokio::process::{Child, Command};
14
  use tokio::time::{interval, timeout};
15
  use walkdir::WalkDir;
16
 
17
- /// The OpenClaw config directory we sync to/from HuggingFace.
18
- /// Must match the path OpenClaw itself writes to (HOME=/home/node).
19
- const WORKSPACE_DIR: &str = "/home/node/.openclaw";
20
 
21
  const STATE_FILE: &str = ".hf-sync-state.json";
22
  const FINAL_WAIT_TIMEOUT: Duration = Duration::from_secs(20);
@@ -24,8 +23,6 @@ const FINAL_WAIT_TIMEOUT: Duration = Duration::from_secs(20);
24
  #[derive(Debug, Clone)]
25
  struct Config {
26
  token: String,
27
- /// HuggingFace dataset repo in the form "owner/name".
28
- /// Read from OPENCLAW_DATASET_REPO (canonical HuggingClaw variable).
29
  dataset_id: String,
30
  sync_interval: Duration,
31
  workspace: PathBuf,
@@ -115,7 +112,7 @@ async fn run() -> Result<()> {
115
  }
116
  }
117
  _ = sigterm.recv() => {
118
- eprintln!("received SIGTERM, forwarding and running final sync");
119
  forward_sigterm(&mut child);
120
  if let Err(err) = push_workspace(&client, &cfg).await {
121
  eprintln!("final push failed: {err:#}");
@@ -125,15 +122,12 @@ async fn run() -> Result<()> {
125
  }
126
  status = child.wait() => {
127
  match status {
128
- Ok(status) => {
129
- eprintln!("child exited with status: {status}");
130
  if let Err(err) = push_workspace(&client, &cfg).await {
131
  eprintln!("final push after child exit failed: {err:#}");
132
  }
133
- match status.code() {
134
- Some(code) => exit(code),
135
- None => exit(1),
136
- }
137
  }
138
  Err(err) => {
139
  eprintln!("failed waiting for child: {err:#}");
@@ -150,8 +144,7 @@ async fn run() -> Result<()> {
150
  fn load_config() -> Result<Config> {
151
  let token = env::var("HF_TOKEN").context("HF_TOKEN is required")?;
152
 
153
- // Accept OPENCLAW_DATASET_REPO (canonical HuggingClaw name) with
154
- // HF_DATASET_ID as a legacy fallback so old Space configs keep working.
155
  let dataset_id = env::var("OPENCLAW_DATASET_REPO")
156
  .or_else(|_| env::var("HF_DATASET_ID"))
157
  .context("OPENCLAW_DATASET_REPO (or HF_DATASET_ID) is required")?;
@@ -193,7 +186,7 @@ fn build_client(cfg: &Config) -> Result<reqwest::Client> {
193
  AUTHORIZATION,
194
  format!("Bearer {}", cfg.token)
195
  .parse()
196
- .context("invalid HF token for auth header")?,
197
  );
198
  headers.insert(CONTENT_TYPE, "application/json".parse().unwrap());
199
 
@@ -204,34 +197,27 @@ fn build_client(cfg: &Config) -> Result<reqwest::Client> {
204
  }
205
 
206
  async fn ensure_dataset_exists(client: &reqwest::Client, cfg: &Config) -> Result<()> {
207
- let lookup_url = format!("https://huggingface.co/api/datasets/{}", cfg.dataset_id);
208
- let response = client.get(&lookup_url).send().await?;
209
 
210
  match response.status() {
211
  StatusCode::OK => {
212
  let repo: RepoLookup = response.json().await.unwrap_or(RepoLookup { id: None });
213
- eprintln!(
214
- "dataset exists: {}",
215
- repo.id.unwrap_or(cfg.dataset_id.clone())
216
- );
217
  Ok(())
218
  }
219
  StatusCode::NOT_FOUND => {
220
- // Only auto-create if the user has opted in.
221
  let auto_create = env::var("AUTO_CREATE_DATASET")
222
  .map(|v| matches!(v.to_lowercase().as_str(), "1" | "true" | "yes" | "on"))
223
  .unwrap_or(false);
224
 
225
  if auto_create {
226
- eprintln!(
227
- "dataset {} not found, AUTO_CREATE_DATASET=true — creating private dataset",
228
- cfg.dataset_id
229
- );
230
  create_private_dataset(client, cfg).await
231
  } else {
232
  Err(anyhow!(
233
- "dataset {} not found. Create it on huggingface.co/new-dataset \
234
- or set AUTO_CREATE_DATASET=true to create it automatically.",
235
  cfg.dataset_id
236
  ))
237
  }
@@ -249,9 +235,8 @@ async fn create_private_dataset(client: &reqwest::Client, cfg: &Config) -> Resul
249
  .split_once('/')
250
  .ok_or_else(|| anyhow!("OPENCLAW_DATASET_REPO must be in the form owner/name"))?;
251
 
252
- let me_url = "https://huggingface.co/api/whoami-v2";
253
  let username = client
254
- .get(me_url)
255
  .send()
256
  .await?
257
  .error_for_status()?
@@ -264,11 +249,7 @@ async fn create_private_dataset(client: &reqwest::Client, cfg: &Config) -> Resul
264
 
265
  let req = CreateRepoRequest {
266
  name: name.to_string(),
267
- organization: if owner == username {
268
- None
269
- } else {
270
- Some(owner.to_string())
271
- },
272
  private: true,
273
  repo_type: "dataset".to_string(),
274
  };
@@ -289,22 +270,15 @@ async fn create_private_dataset(client: &reqwest::Client, cfg: &Config) -> Resul
289
  }
290
 
291
  async fn pull_workspace(client: &reqwest::Client, cfg: &Config) -> Result<()> {
292
- eprintln!("pulling workspace from HuggingFace dataset: {}", cfg.dataset_id);
293
  let remote_files = list_remote_files(client, cfg).await?;
294
 
295
  for file in &remote_files {
296
- let download_url = format!(
297
  "https://huggingface.co/datasets/{}/resolve/main/{}",
298
  cfg.dataset_id, file
299
  );
300
- let bytes = client
301
- .get(download_url)
302
- .send()
303
- .await?
304
- .error_for_status()?
305
- .bytes()
306
- .await?;
307
-
308
  let target = cfg.workspace.join(file);
309
  if let Some(parent) = target.parent() {
310
  tokio::fs::create_dir_all(parent).await?;
@@ -326,18 +300,12 @@ async fn list_remote_files(client: &reqwest::Client, cfg: &Config) -> Result<Vec
326
  if response.status() == StatusCode::NOT_FOUND {
327
  return Ok(Vec::new());
328
  }
329
-
330
- let response = response.error_for_status()?;
331
- let entries: Vec<TreeEntry> = response.json().await?;
332
- Ok(entries
333
- .into_iter()
334
- .filter(|e| e.kind == "file")
335
- .map(|e| e.path)
336
- .collect())
337
  }
338
 
339
  async fn push_workspace(client: &reqwest::Client, cfg: &Config) -> Result<()> {
340
- eprintln!("pushing workspace updates to HuggingFace");
341
  let state_path = cfg.workspace.join(STATE_FILE);
342
  let mut state = load_state(&state_path).await?;
343
 
@@ -363,10 +331,6 @@ async fn push_workspace(client: &reqwest::Client, cfg: &Config) -> Result<()> {
363
  let bytes = tokio::fs::read(full_path).await?;
364
  let md5 = format!("{:x}", md5::compute(&bytes));
365
  let size = bytes.len() as u64;
366
- let file_state = FileState {
367
- md5: md5.clone(),
368
- size,
369
- };
370
 
371
  let changed = state
372
  .files
@@ -375,23 +339,20 @@ async fn push_workspace(client: &reqwest::Client, cfg: &Config) -> Result<()> {
375
  .unwrap_or(true);
376
 
377
  if changed {
378
- let content = base64::engine::general_purpose::STANDARD.encode(bytes);
379
  operations.push(CommitOperation::AddOrUpdate {
380
  path: relative.clone(),
381
  encoding: "base64".to_string(),
382
- content,
383
  });
384
  }
385
 
386
- current.insert(relative, file_state);
387
  }
388
 
389
  let old_paths: HashSet<_> = state.files.keys().cloned().collect();
390
  let new_paths: HashSet<_> = current.keys().cloned().collect();
391
  for removed in old_paths.difference(&new_paths) {
392
- operations.push(CommitOperation::Delete {
393
- path: removed.clone(),
394
- });
395
  }
396
 
397
  if operations.is_empty() {
@@ -418,7 +379,6 @@ async fn push_workspace(client: &reqwest::Client, cfg: &Config) -> Result<()> {
418
  state.files = current;
419
  save_state(&state_path, &state).await?;
420
  eprintln!("push complete");
421
-
422
  Ok(())
423
  }
424
 
@@ -435,21 +395,16 @@ async fn rebuild_sync_state(workspace: &Path) -> Result<()> {
435
  if full_path == state_path {
436
  continue;
437
  }
438
-
439
  let relative = full_path
440
  .strip_prefix(workspace)
441
  .context("failed to strip workspace prefix")?
442
  .to_string_lossy()
443
  .replace('\\', "/");
444
  let bytes = tokio::fs::read(full_path).await?;
445
-
446
- files.insert(
447
- relative,
448
- FileState {
449
- md5: format!("{:x}", md5::compute(&bytes)),
450
- size: bytes.len() as u64,
451
- },
452
- );
453
  }
454
 
455
  save_state(&state_path, &SyncState { files }).await
@@ -460,13 +415,11 @@ async fn load_state(path: &Path) -> Result<SyncState> {
460
  return Ok(SyncState::default());
461
  }
462
  let raw = tokio::fs::read(path).await?;
463
- let state = serde_json::from_slice(&raw).context("failed to parse sync state file")?;
464
- Ok(state)
465
  }
466
 
467
  async fn save_state(path: &Path, state: &SyncState) -> Result<()> {
468
- let raw = serde_json::to_vec_pretty(state)?;
469
- tokio::fs::write(path, raw).await?;
470
  Ok(())
471
  }
472
 
@@ -474,55 +427,39 @@ fn spawn_child_from_args() -> Result<Child> {
474
  let args: Vec<String> = env::args().skip(1).collect();
475
  if args.is_empty() {
476
  return Err(anyhow!(
477
- "no command provided — pass the main process command as entrypoint arguments"
478
  ));
479
  }
480
-
481
- let mut command = Command::new(&args[0]);
482
- command
483
  .args(&args[1..])
484
  .stdin(Stdio::inherit())
485
  .stdout(Stdio::inherit())
486
- .stderr(Stdio::inherit());
487
-
488
- command.spawn().context("failed to spawn child process")
489
  }
490
 
491
- /// Forward SIGTERM to the child process using libc::kill.
492
  fn forward_sigterm(child: &mut Child) {
493
  if let Some(id) = child.id() {
494
- // SAFETY: id is a valid pid obtained from a live child process.
495
  let ret = unsafe { libc::kill(id as libc::pid_t, libc::SIGTERM) };
496
  if ret != 0 {
497
- eprintln!(
498
- "failed to forward SIGTERM: {}",
499
- std::io::Error::last_os_error()
500
- );
501
  }
502
  }
503
  }
504
 
505
  async fn wait_for_child_shutdown(child: &mut Child) {
506
  match timeout(FINAL_WAIT_TIMEOUT, child.wait()).await {
507
- Ok(Ok(status)) => {
508
- eprintln!("child exited after SIGTERM with status: {status}");
509
- }
510
- Ok(Err(err)) => {
511
- eprintln!("failed waiting for child after SIGTERM: {err:#}");
512
- }
513
  Err(_) => {
514
- eprintln!(
515
- "child did not exit within {:?}, sending SIGKILL",
516
- FINAL_WAIT_TIMEOUT
517
- );
518
  if let Some(id) = child.id() {
519
- // SAFETY: id is a valid pid obtained from a live child process.
520
  let ret = unsafe { libc::kill(id as libc::pid_t, libc::SIGKILL) };
521
  if ret != 0 {
522
- eprintln!(
523
- "failed to SIGKILL child: {}",
524
- std::io::Error::last_os_error()
525
- );
526
  }
527
  }
528
  }
 
14
  use tokio::time::{interval, timeout};
15
  use walkdir::WalkDir;
16
 
17
+ /// The OpenClaw config directory synced to/from HuggingFace.
18
+ const WORKSPACE_DIR: &str = "/home/user/.openclaw";
 
19
 
20
  const STATE_FILE: &str = ".hf-sync-state.json";
21
  const FINAL_WAIT_TIMEOUT: Duration = Duration::from_secs(20);
 
23
  #[derive(Debug, Clone)]
24
  struct Config {
25
  token: String,
 
 
26
  dataset_id: String,
27
  sync_interval: Duration,
28
  workspace: PathBuf,
 
112
  }
113
  }
114
  _ = sigterm.recv() => {
115
+ eprintln!("received SIGTERM forwarding and running final sync");
116
  forward_sigterm(&mut child);
117
  if let Err(err) = push_workspace(&client, &cfg).await {
118
  eprintln!("final push failed: {err:#}");
 
122
  }
123
  status = child.wait() => {
124
  match status {
125
+ Ok(s) => {
126
+ eprintln!("child exited with status: {s}");
127
  if let Err(err) = push_workspace(&client, &cfg).await {
128
  eprintln!("final push after child exit failed: {err:#}");
129
  }
130
+ exit(s.code().unwrap_or(1));
 
 
 
131
  }
132
  Err(err) => {
133
  eprintln!("failed waiting for child: {err:#}");
 
144
  fn load_config() -> Result<Config> {
145
  let token = env::var("HF_TOKEN").context("HF_TOKEN is required")?;
146
 
147
+ // Accept OPENCLAW_DATASET_REPO (canonical) with HF_DATASET_ID as fallback.
 
148
  let dataset_id = env::var("OPENCLAW_DATASET_REPO")
149
  .or_else(|_| env::var("HF_DATASET_ID"))
150
  .context("OPENCLAW_DATASET_REPO (or HF_DATASET_ID) is required")?;
 
186
  AUTHORIZATION,
187
  format!("Bearer {}", cfg.token)
188
  .parse()
189
+ .context("invalid HF_TOKEN for auth header")?,
190
  );
191
  headers.insert(CONTENT_TYPE, "application/json".parse().unwrap());
192
 
 
197
  }
198
 
199
  async fn ensure_dataset_exists(client: &reqwest::Client, cfg: &Config) -> Result<()> {
200
+ let url = format!("https://huggingface.co/api/datasets/{}", cfg.dataset_id);
201
+ let response = client.get(&url).send().await?;
202
 
203
  match response.status() {
204
  StatusCode::OK => {
205
  let repo: RepoLookup = response.json().await.unwrap_or(RepoLookup { id: None });
206
+ eprintln!("dataset exists: {}", repo.id.unwrap_or(cfg.dataset_id.clone()));
 
 
 
207
  Ok(())
208
  }
209
  StatusCode::NOT_FOUND => {
 
210
  let auto_create = env::var("AUTO_CREATE_DATASET")
211
  .map(|v| matches!(v.to_lowercase().as_str(), "1" | "true" | "yes" | "on"))
212
  .unwrap_or(false);
213
 
214
  if auto_create {
215
+ eprintln!("dataset {} not found — AUTO_CREATE_DATASET=true, creating…", cfg.dataset_id);
 
 
 
216
  create_private_dataset(client, cfg).await
217
  } else {
218
  Err(anyhow!(
219
+ "dataset {} not found. Create it at huggingface.co/new-dataset \
220
+ or set AUTO_CREATE_DATASET=true.",
221
  cfg.dataset_id
222
  ))
223
  }
 
235
  .split_once('/')
236
  .ok_or_else(|| anyhow!("OPENCLAW_DATASET_REPO must be in the form owner/name"))?;
237
 
 
238
  let username = client
239
+ .get("https://huggingface.co/api/whoami-v2")
240
  .send()
241
  .await?
242
  .error_for_status()?
 
249
 
250
  let req = CreateRepoRequest {
251
  name: name.to_string(),
252
+ organization: if owner == username { None } else { Some(owner.to_string()) },
 
 
 
 
253
  private: true,
254
  repo_type: "dataset".to_string(),
255
  };
 
270
  }
271
 
272
  async fn pull_workspace(client: &reqwest::Client, cfg: &Config) -> Result<()> {
273
+ eprintln!("pulling workspace from HuggingFace: {}", cfg.dataset_id);
274
  let remote_files = list_remote_files(client, cfg).await?;
275
 
276
  for file in &remote_files {
277
+ let url = format!(
278
  "https://huggingface.co/datasets/{}/resolve/main/{}",
279
  cfg.dataset_id, file
280
  );
281
+ let bytes = client.get(url).send().await?.error_for_status()?.bytes().await?;
 
 
 
 
 
 
 
282
  let target = cfg.workspace.join(file);
283
  if let Some(parent) = target.parent() {
284
  tokio::fs::create_dir_all(parent).await?;
 
300
  if response.status() == StatusCode::NOT_FOUND {
301
  return Ok(Vec::new());
302
  }
303
+ let entries: Vec<TreeEntry> = response.error_for_status()?.json().await?;
304
+ Ok(entries.into_iter().filter(|e| e.kind == "file").map(|e| e.path).collect())
 
 
 
 
 
 
305
  }
306
 
307
  async fn push_workspace(client: &reqwest::Client, cfg: &Config) -> Result<()> {
308
+ eprintln!("pushing workspace to HuggingFace");
309
  let state_path = cfg.workspace.join(STATE_FILE);
310
  let mut state = load_state(&state_path).await?;
311
 
 
331
  let bytes = tokio::fs::read(full_path).await?;
332
  let md5 = format!("{:x}", md5::compute(&bytes));
333
  let size = bytes.len() as u64;
 
 
 
 
334
 
335
  let changed = state
336
  .files
 
339
  .unwrap_or(true);
340
 
341
  if changed {
 
342
  operations.push(CommitOperation::AddOrUpdate {
343
  path: relative.clone(),
344
  encoding: "base64".to_string(),
345
+ content: base64::engine::general_purpose::STANDARD.encode(&bytes),
346
  });
347
  }
348
 
349
+ current.insert(relative, FileState { md5, size });
350
  }
351
 
352
  let old_paths: HashSet<_> = state.files.keys().cloned().collect();
353
  let new_paths: HashSet<_> = current.keys().cloned().collect();
354
  for removed in old_paths.difference(&new_paths) {
355
+ operations.push(CommitOperation::Delete { path: removed.clone() });
 
 
356
  }
357
 
358
  if operations.is_empty() {
 
379
  state.files = current;
380
  save_state(&state_path, &state).await?;
381
  eprintln!("push complete");
 
382
  Ok(())
383
  }
384
 
 
395
  if full_path == state_path {
396
  continue;
397
  }
 
398
  let relative = full_path
399
  .strip_prefix(workspace)
400
  .context("failed to strip workspace prefix")?
401
  .to_string_lossy()
402
  .replace('\\', "/");
403
  let bytes = tokio::fs::read(full_path).await?;
404
+ files.insert(relative, FileState {
405
+ md5: format!("{:x}", md5::compute(&bytes)),
406
+ size: bytes.len() as u64,
407
+ });
 
 
 
 
408
  }
409
 
410
  save_state(&state_path, &SyncState { files }).await
 
415
  return Ok(SyncState::default());
416
  }
417
  let raw = tokio::fs::read(path).await?;
418
+ Ok(serde_json::from_slice(&raw).context("failed to parse sync state file")?)
 
419
  }
420
 
421
  async fn save_state(path: &Path, state: &SyncState) -> Result<()> {
422
+ tokio::fs::write(path, serde_json::to_vec_pretty(state)?).await?;
 
423
  Ok(())
424
  }
425
 
 
427
  let args: Vec<String> = env::args().skip(1).collect();
428
  if args.is_empty() {
429
  return Err(anyhow!(
430
+ "no command provided — pass the child command as entrypoint arguments"
431
  ));
432
  }
433
+ Command::new(&args[0])
 
 
434
  .args(&args[1..])
435
  .stdin(Stdio::inherit())
436
  .stdout(Stdio::inherit())
437
+ .stderr(Stdio::inherit())
438
+ .spawn()
439
+ .context("failed to spawn child process")
440
  }
441
 
 
442
  fn forward_sigterm(child: &mut Child) {
443
  if let Some(id) = child.id() {
444
+ // SAFETY: id is a valid pid from a live child process.
445
  let ret = unsafe { libc::kill(id as libc::pid_t, libc::SIGTERM) };
446
  if ret != 0 {
447
+ eprintln!("failed to forward SIGTERM: {}", std::io::Error::last_os_error());
 
 
 
448
  }
449
  }
450
  }
451
 
452
  async fn wait_for_child_shutdown(child: &mut Child) {
453
  match timeout(FINAL_WAIT_TIMEOUT, child.wait()).await {
454
+ Ok(Ok(s)) => eprintln!("child exited after SIGTERM: {s}"),
455
+ Ok(Err(e)) => eprintln!("error waiting for child after SIGTERM: {e:#}"),
 
 
 
 
456
  Err(_) => {
457
+ eprintln!("child did not exit within {:?} — sending SIGKILL", FINAL_WAIT_TIMEOUT);
 
 
 
458
  if let Some(id) = child.id() {
459
+ // SAFETY: id is a valid pid from a live child process.
460
  let ret = unsafe { libc::kill(id as libc::pid_t, libc::SIGKILL) };
461
  if ret != 0 {
462
+ eprintln!("failed to SIGKILL: {}", std::io::Error::last_os_error());
 
 
 
463
  }
464
  }
465
  }