# conv_data_gen/dedup/use_case_dedup.py
from __future__ import annotations
import json
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional, Tuple, Dict, Any
import numpy as np
import pandas as pd
from conv_data_gen.config import config
from conv_data_gen.logger import setup_logger
from conv_data_gen.llm import LLMClient
# Module-level logger shared by this file.
logger = setup_logger(__name__)
# Optional Vertex AI imports: the SDK may be absent in some environments,
# so fall back to None sentinels instead of failing at import time.
# NOTE(review): neither name appears to be used elsewhere in this module
# (embedding calls go through LLMClient) — possibly vestigial; confirm.
try:
    # google-cloud-aiplatform exposes the vertexai namespace
    from vertexai import init as vertexai_init
    from vertexai.language_models import TextEmbeddingModel
except Exception:  # pragma: no cover - import guard
    vertexai_init = None  # type: ignore
    TextEmbeddingModel = None  # type: ignore
@dataclass
class EmbeddingRunResult:
    """Summary of one embedding-dedup run.

    Mirrors the JSON report written by UseCaseEmbeddingsDeduper.run and
    records where each on-disk artifact was written.
    """
    input_count: int  # rows read from the input CSV
    kept_count: int  # rows surviving deduplication
    removed_count: int  # rows dropped as near-duplicates
    threshold: float  # cosine-similarity cutoff used for dedup
    model: str  # embedding model name used
    avg_nearest_similarity: float  # mean nearest-neighbor cosine similarity
    median_nearest_similarity: float  # median nearest-neighbor cosine similarity
    max_nearest_similarity: float  # highest nearest-neighbor cosine similarity
    duplicates_map_path: str  # CSV mapping each dropped row to its kept row
    deduped_csv_path: str  # deduplicated use-case CSV
    embeddings_npy_path: str  # raw embedding matrix saved as .npy
    report_json_path: str  # JSON report with counts, stats and paths
class UseCaseEmbeddingsDeduper:
    """Embed use cases with Vertex AI and remove near-duplicate rows.

    Workflow (see :meth:`run`):
      1. Read the input CSV and build one embedding text per row.
      2. Embed all rows through the shared ``LLMClient``.
      3. Greedily drop rows whose cosine similarity to an already-kept
         row is >= the threshold.
      4. Write the deduped CSV, a duplicates map, the raw embeddings,
         and a JSON report with similarity/variety statistics.
    """

    def __init__(
        self,
        project_id: Optional[str] = None,
        location: Optional[str] = None,
        model_name: str = "text-embedding-004",
        batch_size: int = 64,
        llm_client: Optional[LLMClient] = None,
    ) -> None:
        """Configure project/location/model and build the LLM client.

        Args:
            project_id: GCP project; defaults to ``config.gcp.PROJECT_ID``.
            location: GCP region; defaults to ``config.gcp.LOCATION``.
            model_name: Vertex embedding model name.
            batch_size: Texts per embedding request; clamped to >= 1.
            llm_client: Reusable client; a new one is created if omitted.
        """
        self.project_id = project_id or config.gcp.PROJECT_ID
        self.location = location or config.gcp.LOCATION
        self.model_name = model_name
        self.batch_size = max(1, batch_size)  # guard nonsensical sizes
        self.llm_client = llm_client or LLMClient(
            project_id=self.project_id,
            location=self.location,
        )

    @staticmethod
    def _compose_text_for_embedding(row: pd.Series) -> str:
        """Return the text to embed for one row (currently just ``use_case``).

        Kept as a dedicated seam so additional fields can be folded into
        the embedding text later without touching callers.
        """
        return str(row.get("use_case", "")).strip()

    def _embed_texts(self, texts: List[str]) -> np.ndarray:
        """Embed ``texts`` using the shared LLM client method."""
        return self.llm_client.get_text_embeddings(
            texts, model_name=self.model_name, batch_size=self.batch_size
        )

    @staticmethod
    def _cosine_normalize(matrix: np.ndarray) -> np.ndarray:
        """L2-normalize rows; the epsilon avoids division by zero."""
        norms = np.linalg.norm(matrix, axis=1, keepdims=True) + 1e-12
        return matrix / norms

    def _compute_nearest_similarities(
        self, norm_vecs: np.ndarray
    ) -> np.ndarray:
        """Cosine similarity of each row to its nearest *other* row.

        Returns NaN entries when there are fewer than two rows, so the
        -inf self-similarity sentinel never leaks into stats or JSON.
        """
        n = norm_vecs.shape[0]
        if n < 2:
            return np.full(n, np.nan)
        # Cosine sim via dot product; self-similarity masked with -inf.
        sims = np.matmul(norm_vecs, norm_vecs.T)
        np.fill_diagonal(sims, -np.inf)
        return sims.max(axis=1)

    def _greedy_dedup(
        self, norm_vecs: np.ndarray, threshold: float
    ) -> Tuple[List[int], Dict[int, Tuple[int, float]]]:
        """
        Greedy dedup: keep an item, drop others whose cosine sim >= threshold.
        Returns (kept_indices, duplicates_map) where duplicates_map maps
        dropped_index -> (kept_index, similarity).
        """
        n = norm_vecs.shape[0]
        sims = np.matmul(norm_vecs, norm_vecs.T)
        kept: List[int] = []
        removed: set[int] = set()
        dup_map: Dict[int, Tuple[int, float]] = {}
        for i in range(n):
            if i in removed:
                continue
            kept.append(i)
            # Drop every j still in play that is too similar to i.
            row = sims[i]
            for j in np.where(row >= threshold)[0]:
                if j == i or j in removed:
                    continue
                removed.add(int(j))
                dup_map[int(j)] = (i, float(row[j]))
        return kept, dup_map

    @staticmethod
    def _company_removal_summary(
        df: pd.DataFrame, dup_map: Dict[int, Tuple[int, float]]
    ) -> Dict[str, Dict[str, int]]:
        """Per-company counts of removed duplicates and kept counterparts."""
        company_stats: Dict[str, Dict[str, int]] = {}
        for j, (i, _sim) in dup_map.items():
            original_company = df.iloc[j].get("company", "Unknown")
            kept_company = df.iloc[i].get("company", "Unknown")
            company_stats.setdefault(
                original_company, {"removed": 0, "kept": 0}
            )["removed"] += 1
            company_stats.setdefault(
                kept_company, {"removed": 0, "kept": 0}
            )["kept"] += 1
        return {
            company: {
                "removed_count": stats["removed"],
                "kept_count": stats["kept"],
                "total_original": stats["removed"] + stats["kept"],
            }
            for company, stats in company_stats.items()
        }

    def run(
        self,
        input_csv: str,
        output_dir: str,
        threshold: float = 0.9,
    ) -> EmbeddingRunResult:
        """Run the full embed -> dedup -> report pipeline.

        Args:
            input_csv: CSV of use cases (expects a ``use_case`` column).
            output_dir: Directory for all artifacts (created if missing).
            threshold: Cosine similarity at/above which rows are duplicates.

        Returns:
            EmbeddingRunResult summarizing counts, stats and artifact paths.

        Raises:
            RuntimeError: If the CSV is empty or embedding counts mismatch.
        """
        out_dir = Path(output_dir)
        out_dir.mkdir(parents=True, exist_ok=True)
        df = pd.read_csv(input_csv)
        if df.empty:
            raise RuntimeError("Input CSV is empty")
        # Build text column for embeddings.
        df["embedding_text"] = df.apply(
            self._compose_text_for_embedding, axis=1
        )
        texts: List[str] = df["embedding_text"].astype(str).tolist()
        logger.info("Embedding %d rows using %s", len(texts), self.model_name)
        vecs = self._embed_texts(texts)
        if vecs.shape[0] != len(texts):
            raise RuntimeError("Mismatch between embeddings and input rows")
        # Save raw embeddings for offline inspection.
        embeddings_npy_path = str(out_dir / "use_cases_embeddings.npy")
        np.save(embeddings_npy_path, vecs)
        # Normalize and compute nearest-neighbor similarities
        # (used for variety stats).
        norm_vecs = self._cosine_normalize(vecs)
        nearest = self._compute_nearest_similarities(norm_vecs)
        finite = nearest[np.isfinite(nearest)]
        if finite.size:
            avg_near = float(np.mean(finite))
            med_near = float(np.median(finite))
            max_near = float(np.max(finite))
        else:
            # Single-row input has no neighbor; report neutral stats
            # instead of serializing non-JSON values like -Infinity.
            avg_near = med_near = max_near = 0.0
        # Deduplicate.
        kept_indices, dup_map = self._greedy_dedup(norm_vecs, threshold)
        deduped_df = df.iloc[kept_indices].copy().reset_index(drop=True)
        deduped_csv_path = str(out_dir / "use_cases_deduped.csv")
        deduped_df.to_csv(deduped_csv_path, index=False)
        # Duplicates map so reviewers can audit what was dropped and why.
        dup_rows = [
            {
                "original_index": int(j),
                "kept_index": int(i),
                "similarity": float(sim),
                "original_use_case": df.iloc[j].get("use_case", ""),
                "kept_use_case": df.iloc[i].get("use_case", ""),
                "original_company": df.iloc[j].get("company", ""),
                "kept_company": df.iloc[i].get("company", ""),
                "original_user_type": df.iloc[j].get("user_type", ""),
                "kept_user_type": df.iloc[i].get("user_type", ""),
                "original_agent_type": df.iloc[j].get("agent_type", ""),
                "kept_agent_type": df.iloc[i].get("agent_type", ""),
            }
            for j, (i, sim) in sorted(dup_map.items(), key=lambda x: x[0])
        ]
        duplicates_map_path = str(out_dir / "duplicates_map.csv")
        pd.DataFrame(dup_rows).to_csv(duplicates_map_path, index=False)
        # Company-wise statistics for the report.
        company_removal_summary = self._company_removal_summary(df, dup_map)
        # Report
        report: Dict[str, Any] = {
            "input_count": int(len(df)),
            "kept_count": int(len(deduped_df)),
            "removed_count": int(len(df) - len(deduped_df)),
            "threshold": float(threshold),
            "model": self.model_name,
            "avg_nearest_similarity": avg_near,
            "median_nearest_similarity": med_near,
            "max_nearest_similarity": max_near,
            "embeddings_npy_path": embeddings_npy_path,
            "deduped_csv_path": deduped_csv_path,
            "duplicates_map_path": duplicates_map_path,
            "company_removal_summary": company_removal_summary,
        }
        report_json_path = str(out_dir / "dedup_report.json")
        with open(report_json_path, "w", encoding="utf-8") as f:
            # ensure_ascii=False keeps non-ASCII use cases readable.
            json.dump(report, f, indent=2, ensure_ascii=False)
        # Log company-wise summary.
        company_summary = [
            f"{company}: {stats['removed_count']} removed, "
            f"{stats['kept_count']} kept"
            for company, stats in company_removal_summary.items()
        ]
        logger.info(
            "Dedup complete. Kept %d of %d (removed %d). Avg nearest sim=%.4f",
            report["kept_count"],
            report["input_count"],
            report["removed_count"],
            report["avg_nearest_similarity"],
        )
        logger.info("Company-wise summary: %s", "; ".join(company_summary))
        return EmbeddingRunResult(
            input_count=report["input_count"],
            kept_count=report["kept_count"],
            removed_count=report["removed_count"],
            threshold=threshold,
            model=self.model_name,
            avg_nearest_similarity=avg_near,
            median_nearest_similarity=med_near,
            max_nearest_similarity=max_near,
            duplicates_map_path=duplicates_map_path,
            deduped_csv_path=deduped_csv_path,
            embeddings_npy_path=embeddings_npy_path,
            report_json_path=report_json_path,
        )