Spaces:
Runtime error
Runtime error
| from __future__ import annotations | |
| import json | |
| from pathlib import Path | |
| from typing import Any, Dict, List, Optional, Tuple | |
| import numpy as np # type: ignore[import-untyped] | |
| from conv_data_gen.llm import LLMClient | |
| from conv_data_gen.logger import setup_logger | |
# Module-level logger; setup_logger is the project's logging bootstrap helper.
logger = setup_logger(__name__)
| def _normalize_goal(text: str) -> str: | |
| t = (text or "").strip().lower() | |
| return " ".join(t.split()) or "__empty__" | |
def _parse_goal(row: Dict[str, Any]) -> str:
    """Return the stripped ``proxy.goal`` string from *row* ("" when missing)."""
    raw_proxy = row.get("proxy")
    proxy_dict = raw_proxy if isinstance(raw_proxy, dict) else {}
    return str(proxy_dict.get("goal", "") or "").strip()


def _write_jsonl(path: str, rows: List[Dict[str, Any]]) -> Path:
    """Write *rows* to *path* as JSON Lines, creating parent dirs as needed."""
    op = Path(path)
    op.parent.mkdir(parents=True, exist_ok=True)
    with open(op, "w", encoding="utf-8") as f:
        for r in rows:
            f.write(json.dumps(r, ensure_ascii=False) + "\n")
    return op


def dedup_proxies_artifact(
    proxies_jsonl_path: str,
    output_jsonl_path: Optional[str] = None,
    *,
    similarity_threshold: float = 0.90,
    embedding_model: str = "gemini-embedding-001",
    batch_size: int = 64,
    llm_client: Optional[LLMClient] = None,
) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
    """Deduplicate a proxies JSONL artifact by semantic similarity of goals.

    Rows whose ``proxy.goal`` texts embed to vectors with cosine similarity
    at or above *similarity_threshold* are collapsed to the first occurrence
    (greedy keep-first). At most one row with an empty goal is kept. When
    fewer than two non-empty goals exist, or when embedding fails, the
    function falls back to exact dedup on the normalized goal text.

    Args:
        proxies_jsonl_path: Input JSONL file; malformed lines are skipped.
        output_jsonl_path: Optional path to also write the deduped rows to.
        similarity_threshold: Cosine similarity at/above which two goals
            count as duplicates.
        embedding_model: Embedding model name passed to the LLM client.
        batch_size: Embedding request batch size.
        llm_client: Optional pre-built client; a new one is created otherwise.

    Returns:
        Tuple of (deduplicated rows, metrics dict). Metrics always contain
        ``input_count``/``kept_count``/``removed_count``; the semantic path
        additionally reports ``avg_nearest_similarity`` and ``threshold``.
    """
    rows: List[Dict[str, Any]] = []
    with open(Path(proxies_jsonl_path), "r", encoding="utf-8") as f:
        for line in f:
            try:
                obj = json.loads(line)
            except Exception:
                continue  # best-effort ingest: skip malformed lines
            if isinstance(obj, dict):
                rows.append({str(k): v for k, v in obj.items()})
    input_count = len(rows)

    # Partition rows into non-empty goals (semantic-dedup candidates) and
    # empty-goal rows (at most one of the latter is ever kept).
    goals: List[str] = []
    idx_map: List[int] = []
    empty_goal_indices: List[int] = []
    for idx, r in enumerate(rows):
        goal_text = _parse_goal(r)
        if goal_text:
            goals.append(goal_text)
            idx_map.append(idx)
        else:
            empty_goal_indices.append(idx)

    if len(goals) <= 1:
        # Too few goals to embed meaningfully: keep at most one empty-goal
        # row plus the single non-empty row (if any).
        kept_indices: List[int] = []
        if empty_goal_indices:
            kept_indices.append(empty_goal_indices[0])
        if idx_map:
            kept_indices.append(idx_map[0])
        deduped = [rows[i] for i in sorted(set(kept_indices))]
        metrics: Dict[str, Any] = {
            "input_count": input_count,
            "kept_count": len(deduped),
            "removed_count": input_count - len(deduped),
        }
    else:
        # Semantic dedup via embeddings (mirrors use-case dedup elsewhere).
        try:
            client = llm_client or LLMClient()
            vecs = client.get_text_embeddings(
                goals, model_name=embedding_model, batch_size=batch_size
            )
            if vecs.shape[0] != len(goals):
                raise RuntimeError("Embeddings count mismatch")
            norms = np.linalg.norm(vecs, axis=1, keepdims=True) + 1e-12
            norm_vecs = vecs / norms
            sims = np.matmul(norm_vecs, norm_vecs.T)

            # Greedy keep-first: each kept row suppresses every other row
            # whose similarity meets the threshold.
            kept_local: List[int] = []
            removed_local: set[int] = set()
            n = sims.shape[0]
            for i in range(n):
                if i in removed_local:
                    continue
                kept_local.append(i)
                for j in np.where(sims[i] >= similarity_threshold)[0]:
                    if j != i:
                        removed_local.add(int(j))

            kept_global = {idx_map[i] for i in kept_local}
            if empty_goal_indices:
                kept_global.add(empty_goal_indices[0])
            deduped = [rows[i] for i in sorted(kept_global)]

            # Average nearest-neighbour similarity over non-empty goals.
            # The diagonal is set to -inf so a row can never be its own
            # nearest neighbour (subtracting np.eye left 0 on the diagonal,
            # which wins whenever all off-diagonal sims are negative).
            if n > 1:
                sims_nn = sims.copy()
                np.fill_diagonal(sims_nn, -np.inf)
                avg_nearest = float(np.mean(np.max(sims_nn, axis=1)))
            else:
                avg_nearest = 0.0
            metrics = {
                "input_count": input_count,
                "kept_count": len(deduped),
                "removed_count": input_count - len(deduped),
                "avg_nearest_similarity": avg_nearest,
                "threshold": similarity_threshold,
            }
        except Exception as exc:  # pragma: no cover
            logger.error("[ProxyDedup] semantic dedup failed, fallback: %s", exc)
            # Fallback: exact dedup on whitespace/case-normalized goal text.
            seen_keys: set[str] = set()
            deduped = []
            for r in rows:
                key = _normalize_goal(_parse_goal(r))
                if key in seen_keys:
                    continue
                seen_keys.add(key)
                deduped.append(r)
            metrics = {
                "input_count": input_count,
                "kept_count": len(deduped),
                "removed_count": input_count - len(deduped),
            }

    if output_jsonl_path:
        op = _write_jsonl(output_jsonl_path, deduped)
        logger.info(
            "[ProxyDedup] Wrote deduped proxies: %d kept (removed %d) to %s",
            metrics["kept_count"],
            metrics["removed_count"],
            str(op),
        )
    return deduped, metrics