from __future__ import annotations import json from pathlib import Path from typing import Any, Dict, List, Optional, Tuple import numpy as np # type: ignore[import-untyped] from conv_data_gen.llm import LLMClient from conv_data_gen.logger import setup_logger logger = setup_logger(__name__) def _normalize_goal(text: str) -> str: t = (text or "").strip().lower() return " ".join(t.split()) or "__empty__" def dedup_proxies_artifact( proxies_jsonl_path: str, output_jsonl_path: Optional[str] = None, *, similarity_threshold: float = 0.90, embedding_model: str = "gemini-embedding-001", batch_size: int = 64, llm_client: Optional[LLMClient] = None, ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]: p = Path(proxies_jsonl_path) rows: List[Dict[str, Any]] = [] with open(p, "r", encoding="utf-8") as f: for line in f: try: obj = json.loads(line) if isinstance(obj, dict): rows.append({str(k): v for k, v in obj.items()}) except Exception: continue input_count = len(rows) # Split goals for semantic dedup (allow at most one empty goal) goals: List[str] = [] idx_map: List[int] = [] empty_goal_indices: List[int] = [] for idx, r in enumerate(rows): raw_proxy = r.get("proxy") proxy_dict = raw_proxy if isinstance(raw_proxy, dict) else {} goal_text = str(proxy_dict.get("goal", "") or "").strip() if goal_text: goals.append(goal_text) idx_map.append(idx) else: empty_goal_indices.append(idx) # If <=1 non-empty, fall back to normalized text-based dedup if len(goals) <= 1: kept_indices: List[int] = [] # keep at most one empty if empty_goal_indices: kept_indices.append(empty_goal_indices[0]) # include any single non-empty if idx_map: kept_indices.append(idx_map[0]) deduped = [rows[i] for i in sorted(set(kept_indices))] kept_count = len(deduped) removed_count = input_count - kept_count metrics: Dict[str, Any] = { "input_count": input_count, "kept_count": kept_count, "removed_count": removed_count, } if output_jsonl_path: op = Path(output_jsonl_path) op.parent.mkdir(parents=True, exist_ok=True) with open(op, "w", encoding="utf-8") as f: for r in deduped: f.write(json.dumps(r, ensure_ascii=False) + "\n") logger.info( ( "[ProxyDedup] Wrote deduped proxies: %d kept (removed %d)" " to %s" ), kept_count, removed_count, str(op), ) return deduped, metrics # Semantic dedup via embeddings (like use-cases) try: client = llm_client or LLMClient() vecs = client.get_text_embeddings( goals, model_name=embedding_model, batch_size=batch_size ) if vecs.shape[0] != len(goals): raise RuntimeError("Embeddings count mismatch") norms = np.linalg.norm(vecs, axis=1, keepdims=True) + 1e-12 norm_vecs = vecs / norms sims = np.matmul(norm_vecs, norm_vecs.T) # Greedy keep-first with similarity threshold kept_local: List[int] = [] removed_local: set[int] = set() n = sims.shape[0] for i in range(n): if i in removed_local: continue kept_local.append(i) sim_row = sims[i] similar_js = np.where(sim_row >= similarity_threshold)[0] for j in similar_js: if j == i or j in removed_local: continue removed_local.add(j) kept_global_indices: List[int] = [] if empty_goal_indices: kept_global_indices.append(empty_goal_indices[0]) for local_idx in kept_local: kept_global_indices.append(idx_map[local_idx]) kept_global_indices = sorted(set(kept_global_indices)) deduped = [rows[i] for i in kept_global_indices] # Average nearest-neighbour similarity (non-empty only) if n > 1: sims_wo_diag = sims - np.eye(n) nn = np.max(sims_wo_diag, axis=1) avg_nearest = float(np.mean(nn)) else: avg_nearest = 0.0 kept_count = len(deduped) removed_count = input_count - kept_count metrics = { "input_count": input_count, "kept_count": kept_count, "removed_count": removed_count, "avg_nearest_similarity": avg_nearest, "threshold": similarity_threshold, } except Exception as exc: # pragma: no cover logger.error("[ProxyDedup] semantic dedup failed, fallback: %s", exc) # Fallback to normalized text dedup seen_keys: set[str] = set() deduped = [] for r in rows: raw_proxy = r.get("proxy") proxy_dict = raw_proxy if isinstance(raw_proxy, dict) else {} goal_text = str(proxy_dict.get("goal", "") or "") key = _normalize_goal(goal_text) if key in seen_keys: continue seen_keys.add(key) deduped.append(r) kept_count = len(deduped) removed_count = input_count - kept_count metrics = { "input_count": input_count, "kept_count": kept_count, "removed_count": removed_count, } if output_jsonl_path: op = Path(output_jsonl_path) op.parent.mkdir(parents=True, exist_ok=True) with open(op, "w", encoding="utf-8") as f: for r in deduped: f.write(json.dumps(r, ensure_ascii=False) + "\n") logger.info( "[ProxyDedup] Wrote deduped proxies: %d kept (removed %d) to %s", kept_count, removed_count, str(op), ) return deduped, metrics