# data-gen/conv_data_gen/dedup/proxy_dedup_jsonl.py
# Uploaded by ashish-sarvam via huggingface_hub (commit fc1a684, verified).
from __future__ import annotations
import json
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
import numpy as np # type: ignore[import-untyped]
from conv_data_gen.llm import LLMClient
from conv_data_gen.logger import setup_logger
logger = setup_logger(__name__)
def _normalize_goal(text: str) -> str:
t = (text or "").strip().lower()
return " ".join(t.split()) or "__empty__"
def _load_jsonl_rows(path: Path) -> List[Dict[str, Any]]:
    """Read a JSONL file, keeping only dict rows; malformed lines are skipped silently."""
    rows: List[Dict[str, Any]] = []
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            try:
                obj = json.loads(line)
            except Exception:
                continue
            if isinstance(obj, dict):
                # Coerce keys to str for uniform downstream access.
                rows.append({str(k): v for k, v in obj.items()})
    return rows


def _extract_goal(row: Dict[str, Any]) -> str:
    """Return the stripped ``proxy.goal`` text of *row* ('' when missing or malformed)."""
    raw_proxy = row.get("proxy")
    proxy_dict = raw_proxy if isinstance(raw_proxy, dict) else {}
    return str(proxy_dict.get("goal", "") or "").strip()


def _write_jsonl(
    rows: List[Dict[str, Any]], output_jsonl_path: str, removed_count: int
) -> None:
    """Write *rows* as JSONL to *output_jsonl_path* (creating parents) and log the result."""
    op = Path(output_jsonl_path)
    op.parent.mkdir(parents=True, exist_ok=True)
    with open(op, "w", encoding="utf-8") as f:
        for r in rows:
            f.write(json.dumps(r, ensure_ascii=False) + "\n")
    logger.info(
        "[ProxyDedup] Wrote deduped proxies: %d kept (removed %d) to %s",
        len(rows),
        removed_count,
        str(op),
    )


def dedup_proxies_artifact(
    proxies_jsonl_path: str,
    output_jsonl_path: Optional[str] = None,
    *,
    similarity_threshold: float = 0.90,
    embedding_model: str = "gemini-embedding-001",
    batch_size: int = 64,
    llm_client: Optional[LLMClient] = None,
) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
    """Deduplicate a proxies JSONL artifact by semantic similarity of ``proxy.goal``.

    Rows whose goal embeddings reach ``similarity_threshold`` cosine
    similarity with an earlier kept row are dropped (greedy keep-first).
    If embedding fails, falls back to exact-match dedup on the normalized
    goal text. At most one row with an empty goal is ever kept.

    Args:
        proxies_jsonl_path: Input JSONL path (one dict per line; malformed
            lines are skipped).
        output_jsonl_path: When given, the deduplicated rows are also
            written there (parent directories are created).
        similarity_threshold: Cosine similarity at or above which two
            goals count as duplicates.
        embedding_model: Embedding model name passed to the LLM client.
        batch_size: Embedding request batch size.
        llm_client: Optional pre-built client; a fresh ``LLMClient`` is
            constructed when omitted.

    Returns:
        ``(deduped_rows, metrics)``; metrics always carries ``input_count``,
        ``kept_count`` and ``removed_count``, and on the semantic path also
        ``avg_nearest_similarity`` and ``threshold``.
    """
    rows = _load_jsonl_rows(Path(proxies_jsonl_path))
    input_count = len(rows)

    # Partition rows into non-empty goals (candidates for semantic dedup)
    # and empty-goal rows (at most one survives).
    goals: List[str] = []
    idx_map: List[int] = []  # local goal index -> original row index
    empty_goal_indices: List[int] = []
    for idx, row in enumerate(rows):
        goal_text = _extract_goal(row)
        if goal_text:
            goals.append(goal_text)
            idx_map.append(idx)
        else:
            empty_goal_indices.append(idx)

    # With <=1 non-empty goal there is nothing to compare: trivially keep
    # at most one empty-goal row plus the single non-empty row, skipping
    # the embedding call entirely.
    if len(goals) <= 1:
        kept_indices: List[int] = []
        if empty_goal_indices:
            kept_indices.append(empty_goal_indices[0])
        if idx_map:
            kept_indices.append(idx_map[0])
        deduped = [rows[i] for i in sorted(set(kept_indices))]
        metrics: Dict[str, Any] = {
            "input_count": input_count,
            "kept_count": len(deduped),
            "removed_count": input_count - len(deduped),
        }
        if output_jsonl_path:
            _write_jsonl(deduped, output_jsonl_path, metrics["removed_count"])
        return deduped, metrics

    # Semantic dedup via embeddings (mirrors the use-case dedup flow).
    try:
        client = llm_client or LLMClient()
        vecs = client.get_text_embeddings(
            goals, model_name=embedding_model, batch_size=batch_size
        )
        if vecs.shape[0] != len(goals):
            raise RuntimeError("Embeddings count mismatch")
        # Cosine similarity matrix over L2-normalized embeddings.
        norms = np.linalg.norm(vecs, axis=1, keepdims=True) + 1e-12
        norm_vecs = vecs / norms
        sims = np.matmul(norm_vecs, norm_vecs.T)

        # Greedy keep-first: each kept row suppresses every later row at or
        # above the threshold. Because `sims` is symmetric, an already-kept
        # earlier row can never be marked removed here.
        kept_local: List[int] = []
        removed_local: set[int] = set()
        n = sims.shape[0]
        for i in range(n):
            if i in removed_local:
                continue
            kept_local.append(i)
            for j in np.where(sims[i] >= similarity_threshold)[0]:
                if j != i:
                    removed_local.add(int(j))

        kept_global_indices = sorted(
            set(
                ([empty_goal_indices[0]] if empty_goal_indices else [])
                + [idx_map[i] for i in kept_local]
            )
        )
        deduped = [rows[i] for i in kept_global_indices]

        # Average nearest-neighbour similarity over non-empty goals. Mask
        # the diagonal with -inf (not by subtracting the identity, which
        # leaves 0.0 there and would win whenever all real neighbours have
        # negative similarity).
        if n > 1:
            masked = sims.copy()
            np.fill_diagonal(masked, -np.inf)
            avg_nearest = float(np.mean(np.max(masked, axis=1)))
        else:
            avg_nearest = 0.0
        metrics = {
            "input_count": input_count,
            "kept_count": len(deduped),
            "removed_count": input_count - len(deduped),
            "avg_nearest_similarity": avg_nearest,
            "threshold": similarity_threshold,
        }
    except Exception as exc:  # pragma: no cover
        logger.error("[ProxyDedup] semantic dedup failed, fallback: %s", exc)
        # Exact-match fallback on normalized goal text; all empty goals
        # normalize to the same sentinel key, so at most one is kept.
        seen_keys: set[str] = set()
        deduped = []
        for row in rows:
            key = _normalize_goal(_extract_goal(row))
            if key in seen_keys:
                continue
            seen_keys.add(key)
            deduped.append(row)
        metrics = {
            "input_count": input_count,
            "kept_count": len(deduped),
            "removed_count": input_count - len(deduped),
        }

    if output_jsonl_path:
        _write_jsonl(deduped, output_jsonl_path, metrics["removed_count"])
    return deduped, metrics