RedHotTensors commited on
Commit
29de1ae
·
1 Parent(s): 2c95aa5

Synchronously flush votes before reloading stats to prevent out-of-date reloading.

Browse files
Files changed (2) hide show
  1. app.py +47 -55
  2. storage.py +39 -20
app.py CHANGED
@@ -40,70 +40,65 @@ WINRATE_LOC = _pool_df.columns.get_loc("winrate")
40
 
41
  _md5_to_idx = { md5: idx for idx, md5 in enumerate(_pool_df["md5"]) }
42
 
43
- _stats_lock = threading.Lock()
44
  _pool_lock = threading.Lock()
45
 
46
  _stats_last_loaded_at = 0.0
47
  _explorer_df = pd.DataFrame(columns=["group", "id", "md5", "rating", "sample_url", "image_url", "classifier", "classifier_score", "percentile"])
48
 
49
- def _reload_stats_if_due(force: bool = False):
50
- global _stats_last_loaded_at,_explorer_df
51
- now = time.time()
52
 
53
- if not force and (now - _stats_last_loaded_at) < STATS_RELOAD_S:
54
- return
 
55
 
56
- with _stats_lock:
57
- now = time.time()
 
 
 
 
 
 
58
 
59
- if not force and (now - _stats_last_loaded_at) < STATS_RELOAD_S:
60
- return
61
 
62
- stats_by_key = load_stats_by_md5(
63
- repo_id=POOL_REPO_ID,
64
- token=RATINGS_APP_TOKEN,
65
- )
 
 
 
 
 
 
 
 
 
 
 
 
66
 
67
- n_missing = 0
68
- with _pool_lock:
69
- for md5, stats in stats_by_key.items():
70
- if (idx := _md5_to_idx.get(md5)) is not None:
71
- _pool_df.iloc[idx, [WINS_LOC, LOSSES_LOC, TIES_LOC, VOTES_LOC, WINRATE_LOC]] = (
72
- stats.wins, stats.losses, stats.ties, stats.votes, stats.winrate
73
- )
74
- else:
75
- n_missing += 1
76
-
77
- if n_missing:
78
- print(f"{n_missing} md5s have stats but are not in the pool!", file=sys.stderr)
79
-
80
- classifier_scores_path = hf_hub_download(
81
- repo_id=POOL_REPO_ID,
82
- filename="classifier_scores.parquet",
83
- repo_type="dataset",
84
- token=RATINGS_APP_TOKEN,
85
- )
86
- validation_set_path = hf_hub_download(
87
- repo_id=POOL_REPO_ID,
88
- filename="validation_set.parquet",
89
- repo_type="dataset",
90
- token=RATINGS_APP_TOKEN,
91
- )
92
- validation_df = pd.read_parquet(
93
- validation_set_path,
94
- columns=["group", "id", "md5", "rating", "sample_url", "image_url"],
95
- )
96
- classifier_scores_df = pd.read_parquet(classifier_scores_path)
97
- assert {"classifier", "md5", "classifier_score", "percentile"}.issubset(classifier_scores_df.columns), "classifier_scores.parquet missing expected columns"
98
- classifier_scores_df = classifier_scores_df[["classifier", "md5", "classifier_score", "percentile"]]
99
- classifier_scores_df["classifier"] = classifier_scores_df["classifier"].astype(str)
100
- classifier_scores_df["md5"] = classifier_scores_df["md5"].astype(str)
101
- validation_df["md5"] = validation_df["md5"].astype(str)
102
- _explorer_df = validation_df.merge(classifier_scores_df, on="md5", how="left", validate="one_to_many")
103
- _stats_last_loaded_at = now
104
 
 
 
105
 
106
- _reload_stats_if_due(force=True)
 
 
 
 
 
 
107
 
108
  def _pick_from(df: pd.DataFrame, weights: pd.Series | None = None) -> tuple[pd.Series, pd.Series, int] | None:
109
  if len(df) < 2:
@@ -247,8 +242,6 @@ def _format_rating_post_title(post_id: int, votes: int, label: str) -> str:
247
  return f"<strong>{label}</strong>: <a href=\"https://e621.net/posts/{post_id}\" target=\"_blank\" rel=\"noreferrer\">Post #{post_id}</a> | {votes} {'Vote' if votes == 1 else 'Votes'}"
248
 
249
  def _render_current(state: dict, submit_status: str = "") -> tuple:
250
- _reload_stats_if_due()
251
-
252
  votes_a = _pool_df.iloc[_md5_to_idx[state["key_a"]], VOTES_LOC]
253
  votes_b = _pool_df.iloc[_md5_to_idx[state["key_b"]], VOTES_LOC]
254
  title_a = _format_rating_post_title(state["id_a"], votes_a, "Image A")
@@ -298,7 +291,6 @@ def _load_results(rating_pref_value: str, sort_mode_value: str, classifier_filte
298
  rating_pref = _normalize_rating_pref(rating_pref_value)
299
  sort_mode = _normalize_sort_mode(sort_mode_value)
300
  classifier_name = _normalize_classifier_filter(classifier_filter_value)
301
- _reload_stats_if_due()
302
  filtered_explorer_df = _filtered_explorer_df_by_classifier(rating_pref, classifier_name)
303
  summary, score_distribution_plot, distribution_data, gallery_items, page_meta, next_offset, btn_update = build_results_data(
304
  filtered_explorer_df,
 
40
 
41
  _md5_to_idx = { md5: idx for idx, md5 in enumerate(_pool_df["md5"]) }
42
 
 
43
  _pool_lock = threading.Lock()
44
 
45
  _stats_last_loaded_at = 0.0
46
  _explorer_df = pd.DataFrame(columns=["group", "id", "md5", "rating", "sample_url", "image_url", "classifier", "classifier_score", "percentile"])
47
 
48
+ def _load_stats() -> None:
49
+ VOTE_STORAGE.sync()
50
+ load_stats_by_md5(repo_id=POOL_REPO_ID, token=RATINGS_APP_TOKEN)
51
 
52
+ n_missing = 0
53
+ with _pool_lock:
54
+ VOTE_STORAGE.sync()
55
 
56
+ stats_by_key = load_stats_by_md5(repo_id=POOL_REPO_ID, token=RATINGS_APP_TOKEN)
57
+ for md5, stats in stats_by_key.items():
58
+ if (idx := _md5_to_idx.get(md5)) is not None:
59
+ _pool_df.iloc[idx, [WINS_LOC, LOSSES_LOC, TIES_LOC, VOTES_LOC, WINRATE_LOC]] = (
60
+ stats.wins, stats.losses, stats.ties, stats.votes, stats.winrate
61
+ )
62
+ else:
63
+ n_missing += 1
64
 
65
+ if n_missing:
66
+ print(f"{n_missing} md5s have stats but are not in the pool!", file=sys.stderr)
67
 
68
+ classifier_scores_path = hf_hub_download(
69
+ repo_id=POOL_REPO_ID,
70
+ filename="classifier_scores.parquet",
71
+ repo_type="dataset",
72
+ token=RATINGS_APP_TOKEN,
73
+ )
74
+ validation_set_path = hf_hub_download(
75
+ repo_id=POOL_REPO_ID,
76
+ filename="validation_set.parquet",
77
+ repo_type="dataset",
78
+ token=RATINGS_APP_TOKEN,
79
+ )
80
+ validation_df = pd.read_parquet(
81
+ validation_set_path,
82
+ columns=["group", "id", "md5", "rating", "sample_url", "image_url"],
83
+ )
84
 
85
+ classifier_scores_df = pd.read_parquet(classifier_scores_path)
86
+ assert {"classifier", "md5", "classifier_score", "percentile"}.issubset(classifier_scores_df.columns), "classifier_scores.parquet missing expected columns"
87
+ classifier_scores_df = classifier_scores_df[["classifier", "md5", "classifier_score", "percentile"]]
88
+ classifier_scores_df["classifier"] = classifier_scores_df["classifier"].astype(str)
89
+ classifier_scores_df["md5"] = classifier_scores_df["md5"].astype(str)
90
+ validation_df["md5"] = validation_df["md5"].astype(str)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
 
92
+ global _explorer_df
93
+ _explorer_df = validation_df.merge(classifier_scores_df, on="md5", how="left", validate="one_to_many")
94
 
95
+ def _stats_reloader() -> None:
96
+ while True:
97
+ time.sleep(STATS_RELOAD_S)
98
+ _load_stats()
99
+
100
+ _load_stats()
101
+ threading.Thread(target=_stats_reloader, daemon=True).start()
102
 
103
  def _pick_from(df: pd.DataFrame, weights: pd.Series | None = None) -> tuple[pd.Series, pd.Series, int] | None:
104
  if len(df) < 2:
 
242
  return f"<strong>{label}</strong>: <a href=\"https://e621.net/posts/{post_id}\" target=\"_blank\" rel=\"noreferrer\">Post #{post_id}</a> | {votes} {'Vote' if votes == 1 else 'Votes'}"
243
 
244
  def _render_current(state: dict, submit_status: str = "") -> tuple:
 
 
245
  votes_a = _pool_df.iloc[_md5_to_idx[state["key_a"]], VOTES_LOC]
246
  votes_b = _pool_df.iloc[_md5_to_idx[state["key_b"]], VOTES_LOC]
247
  title_a = _format_rating_post_title(state["id_a"], votes_a, "Image A")
 
291
  rating_pref = _normalize_rating_pref(rating_pref_value)
292
  sort_mode = _normalize_sort_mode(sort_mode_value)
293
  classifier_name = _normalize_classifier_filter(classifier_filter_value)
 
294
  filtered_explorer_df = _filtered_explorer_df_by_classifier(rating_pref, classifier_name)
295
  summary, score_distribution_plot, distribution_data, gallery_items, page_meta, next_offset, btn_update = build_results_data(
296
  filtered_explorer_df,
storage.py CHANGED
@@ -27,49 +27,63 @@ VOTE_COLUMNS = [
27
 
28
 
29
  class VoteStorage:
30
- def __init__(self, mode: str, token: str | None = None):
31
  assert mode in ("hf", "void"), f"Unsupported storage mode: {mode}"
32
  self.mode = mode
33
  is_debug_mode = self.mode == "void"
34
 
35
  self._flush_every = 3 if is_debug_mode else 50
36
  self._flush_interval_sec = 15.0 if is_debug_mode else 300.0
 
37
 
38
- self._shutdown = False
 
 
39
  self._votes_buffer: list[dict] = []
 
40
 
41
- self._flush_condition = threading.Condition(threading.Lock())
42
  self._flush_thread = threading.Thread(target=self._flush_loop, daemon=True)
43
  self._flush_thread.start()
44
 
45
- self.hf_api = HfApi(token=token)
46
-
47
  atexit.register(self.close)
48
 
49
- def _empty_votes_df(self) -> pd.DataFrame:
50
- return pd.DataFrame(columns=VOTE_COLUMNS)
51
-
52
- def _upload_votes_batch(self, df: pd.DataFrame, commit_message: str):
53
  if self.mode == "void":
54
  return
55
 
 
 
 
 
 
 
 
 
56
  ts = int(time.time())
57
  shard = f"votes_{ts}_{uuid.uuid4().hex}.parquet"
58
 
59
- self.hf_api.upload_file(
60
  path_or_fileobj=df.to_parquet(index=False),
61
  path_in_repo=f"{VOTES_LOG_SUBDIR}/{shard}",
62
  repo_id=VOTES_REPO_ID,
63
  repo_type=VOTES_REPO_TYPE,
64
- commit_message=commit_message,
65
  )
66
 
67
  def _flush_loop(self) -> None:
68
  while True:
69
  with self._flush_condition:
70
  while True:
 
 
 
 
 
 
 
 
71
  if self._shutdown:
72
- # Flush last batch of votes.
73
  if self._votes_buffer:
74
  break
75
 
@@ -89,16 +103,21 @@ class VoteStorage:
89
  batch = self._votes_buffer
90
  self._votes_buffer = []
91
 
92
- assert batch
93
- batch_df = pd.DataFrame(batch)
94
- del batch
 
 
 
 
 
 
95
 
96
- for col in VOTE_COLUMNS:
97
- if col not in batch_df.columns:
98
- batch_df[col] = None
99
 
100
- batch_df = batch_df[VOTE_COLUMNS]
101
- self._upload_votes_batch(batch_df, commit_message=f"upload {len(batch_df)} vote rows")
102
 
103
  def close(self) -> None:
104
  with self._flush_condition:
 
27
 
28
 
29
  class VoteStorage:
30
+ def __init__(self, mode: str, token: str | None = None) -> None:
31
  assert mode in ("hf", "void"), f"Unsupported storage mode: {mode}"
32
  self.mode = mode
33
  is_debug_mode = self.mode == "void"
34
 
35
  self._flush_every = 3 if is_debug_mode else 50
36
  self._flush_interval_sec = 15.0 if is_debug_mode else 300.0
37
+ self._hf_api = HfApi(token=token)
38
 
39
+ self._flush_condition = threading.Condition(threading.Lock())
40
+ self._sync_event = threading.Event()
41
+ self._sync_lock = threading.Lock()
42
  self._votes_buffer: list[dict] = []
43
+ self._shutdown = False
44
 
 
45
  self._flush_thread = threading.Thread(target=self._flush_loop, daemon=True)
46
  self._flush_thread.start()
47
 
 
 
48
  atexit.register(self.close)
49
 
50
+ def _upload_votes_batch(self, batch: list[dict]) -> None:
51
+ assert batch
 
 
52
  if self.mode == "void":
53
  return
54
 
55
+ df = pd.DataFrame(batch)
56
+
57
+ for col in VOTE_COLUMNS:
58
+ if col not in df.columns:
59
+ df[col] = None
60
+
61
+ df = df[VOTE_COLUMNS]
62
+
63
  ts = int(time.time())
64
  shard = f"votes_{ts}_{uuid.uuid4().hex}.parquet"
65
 
66
+ self._hf_api.upload_file(
67
  path_or_fileobj=df.to_parquet(index=False),
68
  path_in_repo=f"{VOTES_LOG_SUBDIR}/{shard}",
69
  repo_id=VOTES_REPO_ID,
70
  repo_type=VOTES_REPO_TYPE,
71
+ commit_message=f"upload {len(df)} vote rows",
72
  )
73
 
74
  def _flush_loop(self) -> None:
75
  while True:
76
  with self._flush_condition:
77
  while True:
78
+ # Forced sync.
79
+ if not self._sync_event.is_set():
80
+ if self._votes_buffer:
81
+ break
82
+
83
+ self._sync_event.set()
84
+
85
+ # Shutdown wanted.
86
  if self._shutdown:
 
87
  if self._votes_buffer:
88
  break
89
 
 
103
  batch = self._votes_buffer
104
  self._votes_buffer = []
105
 
106
+ self._upload_votes_batch(batch)
107
+
108
+ def sync(self) -> None:
109
+ with self._sync_lock:
110
+ with self._flush_condition:
111
+ is_shutdown = self._shutdown
112
+ if not is_shutdown:
113
+ self._sync_event.clear()
114
+ self._flush_condition.notify()
115
 
116
+ if not is_shutdown:
117
+ self._sync_event.wait()
 
118
 
119
+ if is_shutdown:
120
+ self._flush_thread.join()
121
 
122
  def close(self) -> None:
123
  with self._flush_condition: