Mirrowel commited on
Commit
90d4836
·
1 Parent(s): 06b3f7d

feat(quota): ✨ improve aggregation with tier priorities and fix double-counting

Browse files

- Adds tier priority metadata for proper sorting of credential tiers in the UI
- Fixes double-counting when models share quota groups by using aggregated group totals
- Enhances reset time display to show expiration for low/exhausted quotas
- Implements provider-specific stats merging to preserve cache during partial updates
- Recalculates summary statistics on-demand instead of full cache replacement

src/proxy_app/launcher_tui.py CHANGED
@@ -429,7 +429,7 @@ class LauncherTUI:
429
  self.console.print(" 3. 🔑 Manage Credentials")
430
 
431
  self.console.print(" 4. 📊 View Provider & Advanced Settings")
432
- self.console.print(" 5. 📈 View Quota & Usage Stats")
433
  self.console.print(" 6. 🔄 Reload Configuration")
434
  self.console.print(" 7. ℹ️ About")
435
  self.console.print(" 8. 🚪 Exit")
 
429
  self.console.print(" 3. 🔑 Manage Credentials")
430
 
431
  self.console.print(" 4. 📊 View Provider & Advanced Settings")
432
+ self.console.print(" 5. 📈 View Quota & Usage Stats (Alpha)")
433
  self.console.print(" 6. 🔄 Reload Configuration")
434
  self.console.print(" 7. ℹ️ About")
435
  self.console.print(" 8. 🚪 Exit")
src/proxy_app/quota_viewer.py CHANGED
@@ -3,6 +3,42 @@ Lightweight Quota Stats Viewer TUI.
3
 
4
  Connects to a running proxy to display quota and usage statistics.
5
  Uses only httpx + rich (no heavy rotator_library imports).
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  """
7
 
8
  import os
@@ -257,6 +293,131 @@ class QuotaViewer:
257
  self.last_error = str(e)
258
  return None
259
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
260
  def post_action(
261
  self,
262
  action: str,
@@ -300,7 +461,14 @@ class QuotaViewer:
300
  return None
301
 
302
  result = response.json()
303
- self.cached_stats = result
 
 
 
 
 
 
 
304
  self.last_error = None
305
  return result
306
 
@@ -424,8 +592,12 @@ class QuotaViewer:
424
  tiers = group_stats.get("tiers", {})
425
 
426
  # Format tier info: "5(15)f/2s" = 5 active out of 15 free, 2 standard all active
 
427
  tier_parts = []
428
- for tier_name, tier_info in sorted(tiers.items()):
 
 
 
429
  if tier_name == "unknown":
430
  continue # Skip unknown tiers in display
431
  total_t = tier_info.get("total", 0)
@@ -546,10 +718,13 @@ class QuotaViewer:
546
  valid_choices = [str(i) for i in range(1, len(provider_list) + 1)]
547
  valid_choices.extend(["r", "R", "s", "S", "m", "M", "b", "B", "g", "G"])
548
 
549
- choice = Prompt.ask("Select option", default="B").strip()
550
 
551
  if choice.lower() == "b":
552
  self.running = False
 
 
 
553
  elif choice.lower() == "g":
554
  # Toggle view mode
555
  self.view_mode = "global" if self.view_mode == "current" else "current"
@@ -659,6 +834,7 @@ class QuotaViewer:
659
  ):
660
  self.post_action("reload", scope="all")
661
  elif choice == "F" and has_quota_groups:
 
662
  with self.console.status(
663
  f"[bold]Fetching live quota for ALL {provider} credentials...",
664
  spinner="dots",
@@ -666,16 +842,17 @@ class QuotaViewer:
666
  result = self.post_action(
667
  "force_refresh", scope="provider", provider=provider
668
  )
669
- if result and result.get("refresh_result"):
670
- rr = result["refresh_result"]
671
- self.console.print(
672
- f"\n[green]Refreshed {rr.get('credentials_refreshed', 0)} credentials "
673
- f"in {rr.get('duration_ms', 0)}ms[/green]"
674
- )
675
- if rr.get("errors"):
676
- for err in rr["errors"]:
677
- self.console.print(f"[red] Error: {err}[/red]")
678
- Prompt.ask("Press Enter to continue", default="")
 
679
  elif choice.startswith("F") and choice[1:].isdigit() and has_quota_groups:
680
  idx = int(choice[1:])
681
  credentials = (
@@ -691,6 +868,7 @@ class QuotaViewer:
691
  cred = credentials[idx - 1]
692
  cred_id = cred.get("identifier", "")
693
  email = cred.get("email", cred_id)
 
694
  with self.console.status(
695
  f"[bold]Fetching live quota for {email}...", spinner="dots"
696
  ):
@@ -700,15 +878,16 @@ class QuotaViewer:
700
  provider=provider,
701
  credential=cred_id,
702
  )
703
- if result and result.get("refresh_result"):
704
- rr = result["refresh_result"]
705
- self.console.print(
706
- f"\n[green]Refreshed in {rr.get('duration_ms', 0)}ms[/green]"
707
- )
708
- if rr.get("errors"):
709
- for err in rr["errors"]:
710
- self.console.print(f"[red] Error: {err}[/red]")
711
- Prompt.ask("Press Enter to continue", default="")
 
712
 
713
  def _render_credential_panel(self, idx: int, cred: Dict[str, Any], provider: str):
714
  """Render a single credential as a panel."""
@@ -841,16 +1020,28 @@ class QuotaViewer:
841
  display = group_stats.get("display", f"{requests_used}/?")
842
  bar = create_progress_bar(remaining_pct)
843
 
 
 
 
844
  # Color based on status
845
  if is_exhausted:
846
  color = "red"
847
- status_text = "⛔ EXHAUSTED"
 
 
 
848
  elif remaining_pct is not None and remaining_pct < 20:
849
  color = "yellow"
850
- status_text = "⚠️ LOW"
 
 
 
851
  else:
852
  color = "green"
853
- status_text = f"Resets: {reset_time}"
 
 
 
854
 
855
  # Confidence indicator
856
  conf_indicator = ""
 
3
 
4
  Connects to a running proxy to display quota and usage statistics.
5
  Uses only httpx + rich (no heavy rotator_library imports).
6
+
7
+ TODO: Missing Features & Improvements
8
+ ======================================
9
+
10
+ Display Improvements:
11
+ - [ ] Add color legend/help screen explaining status colors and symbols
12
+ - [ ] Show credential email/project ID if available (currently just filename)
13
+ - [ ] Add keyboard shortcut hints (e.g., "Press ? for help")
14
+ - [ ] Support terminal resize / responsive layout
15
+
16
+ Global Stats Fix:
17
+ - [ ] HACK: Global requests currently set to current period requests only
18
+ (see client.py get_quota_stats). This doesn't include archived stats.
19
+ Fix requires tracking archived requests per quota group in usage_manager.py
20
+ to avoid double-counting models that share quota groups.
21
+
22
+ Data & Refresh:
23
+ - [ ] Auto-refresh option (configurable interval)
24
+ - [ ] Show last refresh timestamp more prominently
25
+ - [ ] Cache invalidation when switching between current/global view
26
+ - [ ] Support for non-OAuth providers (API keys like nvapi-*, gsk_*, etc.)
27
+
28
+ Remote Management:
29
+ - [ ] Test connection before saving remote
30
+ - [ ] Import/export remote configurations
31
+ - [ ] SSH tunnel support for remote proxies
32
+
33
+ Quota Groups:
34
+ - [ ] Show which models are in each quota group (expandable)
35
+ - [ ] Historical quota usage graphs (if data available)
36
+ - [ ] Alerts/notifications when quota is low
37
+
38
+ Credential Details:
39
+ - [ ] Show per-model breakdown within quota groups
40
+ - [ ] Edit credential priority/tier manually
41
+ - [ ] Disable/enable individual credentials
42
  """
43
 
44
  import os
 
293
  self.last_error = str(e)
294
  return None
295
 
296
+ def _merge_provider_stats(self, provider: str, result: Dict[str, Any]) -> None:
297
+ """
298
+ Merge provider-specific stats into the existing cache.
299
+
300
+ Updates just the specified provider's data and recalculates the
301
+ summary fields to reflect the change.
302
+
303
+ Args:
304
+ provider: Provider name that was refreshed
305
+ result: API response containing the refreshed provider data
306
+ """
307
+ if not self.cached_stats:
308
+ self.cached_stats = result
309
+ return
310
+
311
+ # Merge provider data
312
+ if "providers" in result and provider in result["providers"]:
313
+ if "providers" not in self.cached_stats:
314
+ self.cached_stats["providers"] = {}
315
+ self.cached_stats["providers"][provider] = result["providers"][provider]
316
+
317
+ # Update timestamp
318
+ if "timestamp" in result:
319
+ self.cached_stats["timestamp"] = result["timestamp"]
320
+
321
+ # Recalculate summary from all providers
322
+ self._recalculate_summary()
323
+
324
+ def _recalculate_summary(self) -> None:
325
+ """
326
+ Recalculate summary fields from all provider data in cache.
327
+
328
+ Updates both 'summary' and 'global_summary' based on current
329
+ provider stats.
330
+ """
331
+ providers = self.cached_stats.get("providers", {})
332
+ if not providers:
333
+ return
334
+
335
+ # Calculate summary from all providers
336
+ total_creds = 0
337
+ active_creds = 0
338
+ exhausted_creds = 0
339
+ total_requests = 0
340
+ total_input_cached = 0
341
+ total_input_uncached = 0
342
+ total_output = 0
343
+ total_cost = 0.0
344
+
345
+ for prov_stats in providers.values():
346
+ total_creds += prov_stats.get("credential_count", 0)
347
+ active_creds += prov_stats.get("active_count", 0)
348
+ exhausted_creds += prov_stats.get("exhausted_count", 0)
349
+ total_requests += prov_stats.get("total_requests", 0)
350
+
351
+ tokens = prov_stats.get("tokens", {})
352
+ total_input_cached += tokens.get("input_cached", 0)
353
+ total_input_uncached += tokens.get("input_uncached", 0)
354
+ total_output += tokens.get("output", 0)
355
+
356
+ cost = prov_stats.get("approx_cost")
357
+ if cost:
358
+ total_cost += cost
359
+
360
+ total_input = total_input_cached + total_input_uncached
361
+ input_cache_pct = (
362
+ round(total_input_cached / total_input * 100, 1) if total_input > 0 else 0
363
+ )
364
+
365
+ self.cached_stats["summary"] = {
366
+ "total_providers": len(providers),
367
+ "total_credentials": total_creds,
368
+ "active_credentials": active_creds,
369
+ "exhausted_credentials": exhausted_creds,
370
+ "total_requests": total_requests,
371
+ "tokens": {
372
+ "input_cached": total_input_cached,
373
+ "input_uncached": total_input_uncached,
374
+ "input_cache_pct": input_cache_pct,
375
+ "output": total_output,
376
+ },
377
+ "approx_total_cost": total_cost if total_cost > 0 else None,
378
+ }
379
+
380
+ # Also recalculate global_summary if it exists
381
+ if "global_summary" in self.cached_stats:
382
+ global_total_requests = 0
383
+ global_input_cached = 0
384
+ global_input_uncached = 0
385
+ global_output = 0
386
+ global_cost = 0.0
387
+
388
+ for prov_stats in providers.values():
389
+ global_data = prov_stats.get("global", prov_stats)
390
+ global_total_requests += global_data.get("total_requests", 0)
391
+
392
+ tokens = global_data.get("tokens", {})
393
+ global_input_cached += tokens.get("input_cached", 0)
394
+ global_input_uncached += tokens.get("input_uncached", 0)
395
+ global_output += tokens.get("output", 0)
396
+
397
+ cost = global_data.get("approx_cost")
398
+ if cost:
399
+ global_cost += cost
400
+
401
+ global_total_input = global_input_cached + global_input_uncached
402
+ global_cache_pct = (
403
+ round(global_input_cached / global_total_input * 100, 1)
404
+ if global_total_input > 0
405
+ else 0
406
+ )
407
+
408
+ self.cached_stats["global_summary"] = {
409
+ "total_providers": len(providers),
410
+ "total_credentials": total_creds,
411
+ "total_requests": global_total_requests,
412
+ "tokens": {
413
+ "input_cached": global_input_cached,
414
+ "input_uncached": global_input_uncached,
415
+ "input_cache_pct": global_cache_pct,
416
+ "output": global_output,
417
+ },
418
+ "approx_total_cost": global_cost if global_cost > 0 else None,
419
+ }
420
+
421
  def post_action(
422
  self,
423
  action: str,
 
461
  return None
462
 
463
  result = response.json()
464
+
465
+ # If scope is provider-specific, merge into existing cache
466
+ if scope == "provider" and provider and self.cached_stats:
467
+ self._merge_provider_stats(provider, result)
468
+ else:
469
+ # Full refresh - replace everything
470
+ self.cached_stats = result
471
+
472
  self.last_error = None
473
  return result
474
 
 
592
  tiers = group_stats.get("tiers", {})
593
 
594
  # Format tier info: "5(15)f/2s" = 5 active out of 15 free, 2 standard all active
595
+ # Sort by priority (lower number = higher priority, appears first)
596
  tier_parts = []
597
+ sorted_tiers = sorted(
598
+ tiers.items(), key=lambda x: x[1].get("priority", 10)
599
+ )
600
+ for tier_name, tier_info in sorted_tiers:
601
  if tier_name == "unknown":
602
  continue # Skip unknown tiers in display
603
  total_t = tier_info.get("total", 0)
 
718
  valid_choices = [str(i) for i in range(1, len(provider_list) + 1)]
719
  valid_choices.extend(["r", "R", "s", "S", "m", "M", "b", "B", "g", "G"])
720
 
721
+ choice = Prompt.ask("Select option", default="").strip()
722
 
723
  if choice.lower() == "b":
724
  self.running = False
725
+ elif choice == "":
726
+ # Empty input - just refresh the screen
727
+ pass
728
  elif choice.lower() == "g":
729
  # Toggle view mode
730
  self.view_mode = "global" if self.view_mode == "current" else "current"
 
834
  ):
835
  self.post_action("reload", scope="all")
836
  elif choice == "F" and has_quota_groups:
837
+ result = None
838
  with self.console.status(
839
  f"[bold]Fetching live quota for ALL {provider} credentials...",
840
  spinner="dots",
 
842
  result = self.post_action(
843
  "force_refresh", scope="provider", provider=provider
844
  )
845
+ # Handle result OUTSIDE spinner
846
+ if result and result.get("refresh_result"):
847
+ rr = result["refresh_result"]
848
+ self.console.print(
849
+ f"\n[green]Refreshed {rr.get('credentials_refreshed', 0)} credentials "
850
+ f"in {rr.get('duration_ms', 0)}ms[/green]"
851
+ )
852
+ if rr.get("errors"):
853
+ for err in rr["errors"]:
854
+ self.console.print(f"[red] Error: {err}[/red]")
855
+ Prompt.ask("Press Enter to continue", default="")
856
  elif choice.startswith("F") and choice[1:].isdigit() and has_quota_groups:
857
  idx = int(choice[1:])
858
  credentials = (
 
868
  cred = credentials[idx - 1]
869
  cred_id = cred.get("identifier", "")
870
  email = cred.get("email", cred_id)
871
+ result = None
872
  with self.console.status(
873
  f"[bold]Fetching live quota for {email}...", spinner="dots"
874
  ):
 
878
  provider=provider,
879
  credential=cred_id,
880
  )
881
+ # Handle result OUTSIDE spinner
882
+ if result and result.get("refresh_result"):
883
+ rr = result["refresh_result"]
884
+ self.console.print(
885
+ f"\n[green]Refreshed in {rr.get('duration_ms', 0)}ms[/green]"
886
+ )
887
+ if rr.get("errors"):
888
+ for err in rr["errors"]:
889
+ self.console.print(f"[red] Error: {err}[/red]")
890
+ Prompt.ask("Press Enter to continue", default="")
891
 
892
  def _render_credential_panel(self, idx: int, cred: Dict[str, Any], provider: str):
893
  """Render a single credential as a panel."""
 
1020
  display = group_stats.get("display", f"{requests_used}/?")
1021
  bar = create_progress_bar(remaining_pct)
1022
 
1023
+ # Build status text - always show reset time if available
1024
+ has_reset_time = reset_time and reset_time != "-"
1025
+
1026
  # Color based on status
1027
  if is_exhausted:
1028
  color = "red"
1029
+ if has_reset_time:
1030
+ status_text = f"⛔ Resets: {reset_time}"
1031
+ else:
1032
+ status_text = "⛔ EXHAUSTED"
1033
  elif remaining_pct is not None and remaining_pct < 20:
1034
  color = "yellow"
1035
+ if has_reset_time:
1036
+ status_text = f"⚠️ Resets: {reset_time}"
1037
+ else:
1038
+ status_text = "⚠️ LOW"
1039
  else:
1040
  color = "green"
1041
+ if has_reset_time:
1042
+ status_text = f"Resets: {reset_time}"
1043
+ else:
1044
+ status_text = "" # Hide if unused/no reset time
1045
 
1046
  # Confidence indicator
1047
  conf_indicator = ""
src/rotator_library/client.py CHANGED
@@ -2678,9 +2678,18 @@ class RotatingClient:
2678
  tier = provider_instance.project_tier_cache.get(cred_path)
2679
  tier = tier or "unknown"
2680
 
2681
- # Initialize tier entry if needed
2682
  if tier not in group_stats["tiers"]:
2683
- group_stats["tiers"][tier] = {"total": 0, "active": 0}
 
 
 
 
 
 
 
 
 
2684
  group_stats["tiers"][tier]["total"] += 1
2685
 
2686
  # Find model with VALID baseline (not just any model with stats)
@@ -2745,16 +2754,28 @@ class RotatingClient:
2745
 
2746
  for group_name, group_models in quota_groups.items():
2747
  # Find model with VALID baseline (prefer over any model with stats)
 
2748
  model_stats = None
 
 
2749
  for model in group_models:
2750
  candidate = self._find_model_stats_in_data(
2751
  models_data, model, provider, provider_instance
2752
  )
2753
  if candidate:
 
 
 
 
 
 
 
 
 
2754
  baseline = candidate.get("baseline_remaining_fraction")
2755
  if baseline is not None:
2756
  model_stats = candidate
2757
- break
2758
  # Keep first found as fallback
2759
  if model_stats is None:
2760
  model_stats = candidate
@@ -2763,7 +2784,10 @@ class RotatingClient:
2763
  baseline = model_stats.get("baseline_remaining_fraction")
2764
  max_req = model_stats.get("quota_max_requests")
2765
  req_count = model_stats.get("request_count", 0)
2766
- reset_ts = model_stats.get("quota_reset_ts")
 
 
 
2767
 
2768
  remaining_pct = (
2769
  int(baseline * 100) if baseline is not None else None
@@ -2797,6 +2821,25 @@ class RotatingClient:
2797
  ),
2798
  }
2799
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2800
  # Try to get email from provider's cache
2801
  cred_path = cred.get("full_path", "")
2802
  if hasattr(provider_instance, "project_tier_cache"):
 
2678
  tier = provider_instance.project_tier_cache.get(cred_path)
2679
  tier = tier or "unknown"
2680
 
2681
+ # Initialize tier entry if needed with priority for sorting
2682
  if tier not in group_stats["tiers"]:
2683
+ priority = 10 # default
2684
+ if hasattr(provider_instance, "_resolve_tier_priority"):
2685
+ priority = provider_instance._resolve_tier_priority(
2686
+ tier
2687
+ )
2688
+ group_stats["tiers"][tier] = {
2689
+ "total": 0,
2690
+ "active": 0,
2691
+ "priority": priority,
2692
+ }
2693
  group_stats["tiers"][tier]["total"] += 1
2694
 
2695
  # Find model with VALID baseline (not just any model with stats)
 
2754
 
2755
  for group_name, group_models in quota_groups.items():
2756
  # Find model with VALID baseline (prefer over any model with stats)
2757
+ # Also track the best reset_ts across all models in the group
2758
  model_stats = None
2759
+ best_reset_ts = None
2760
+
2761
  for model in group_models:
2762
  candidate = self._find_model_stats_in_data(
2763
  models_data, model, provider, provider_instance
2764
  )
2765
  if candidate:
2766
+ # Track the best (latest) reset_ts from any model in group
2767
+ candidate_reset_ts = candidate.get("quota_reset_ts")
2768
+ if candidate_reset_ts:
2769
+ if (
2770
+ best_reset_ts is None
2771
+ or candidate_reset_ts > best_reset_ts
2772
+ ):
2773
+ best_reset_ts = candidate_reset_ts
2774
+
2775
  baseline = candidate.get("baseline_remaining_fraction")
2776
  if baseline is not None:
2777
  model_stats = candidate
2778
+ # Don't break - continue to find best reset_ts
2779
  # Keep first found as fallback
2780
  if model_stats is None:
2781
  model_stats = candidate
 
2784
  baseline = model_stats.get("baseline_remaining_fraction")
2785
  max_req = model_stats.get("quota_max_requests")
2786
  req_count = model_stats.get("request_count", 0)
2787
+ # Use best_reset_ts from any model in the group
2788
+ reset_ts = best_reset_ts or model_stats.get(
2789
+ "quota_reset_ts"
2790
+ )
2791
 
2792
  remaining_pct = (
2793
  int(baseline * 100) if baseline is not None else None
 
2821
  ),
2822
  }
2823
 
2824
+ # Recalculate credential's requests from model_groups
2825
+ # This fixes double-counting when models share quota groups
2826
+ if cred.get("model_groups"):
2827
+ group_requests = sum(
2828
+ g.get("requests_used", 0)
2829
+ for g in cred["model_groups"].values()
2830
+ )
2831
+ cred["requests"] = group_requests
2832
+
2833
+ # HACK: Fix global requests if present
2834
+ # This is a simplified fix that sets global.requests = current group_requests.
2835
+ # TODO: Properly track archived requests per quota group in usage_manager.py
2836
+ # so that global stats correctly sum: current_period + archived_periods
2837
+ # without double-counting models that share quota groups.
2838
+ # See: usage_manager.py lines 2388-2404 where global stats are built
2839
+ # by iterating all models (causing double-counting for grouped models).
2840
+ if cred.get("global"):
2841
+ cred["global"]["requests"] = group_requests
2842
+
2843
  # Try to get email from provider's cache
2844
  cred_path = cred.get("full_path", "")
2845
  if hasattr(provider_instance, "project_tier_cache"):