Spaces:
Runtime error
Runtime error
| from __future__ import annotations | |
| import json | |
| from dataclasses import dataclass | |
| from pathlib import Path | |
| from typing import List, Optional, Tuple, Dict, Any | |
| import numpy as np | |
| import pandas as pd | |
| from conv_data_gen.config import config | |
| from conv_data_gen.logger import setup_logger | |
| from conv_data_gen.llm import LLMClient | |
| logger = setup_logger(__name__) | |
# Vertex AI SDK is optional: google-cloud-aiplatform exposes the
# ``vertexai`` namespace, but this module must stay importable in
# environments where the SDK is absent.
try:
    from vertexai import init as vertexai_init
    from vertexai.language_models import TextEmbeddingModel
except Exception:  # pragma: no cover - import guard
    # Fall back to None sentinels so callers can feature-detect.
    vertexai_init = TextEmbeddingModel = None  # type: ignore
@dataclass
class EmbeddingRunResult:
    """Summary of a single embed-and-dedup run (counts, stats, artifact paths).

    BUG FIX: this class is constructed with keyword arguments
    (``EmbeddingRunResult(input_count=..., ...)``) but previously had only
    bare annotations and no generated ``__init__``; without the
    ``@dataclass`` decorator instantiation raised ``TypeError``.
    """

    input_count: int  # rows read from the input CSV
    kept_count: int  # rows surviving deduplication
    removed_count: int  # rows dropped as near-duplicates
    threshold: float  # cosine-similarity cutoff used for dedup
    model: str  # embedding model name
    avg_nearest_similarity: float  # mean nearest-neighbor cosine sim
    median_nearest_similarity: float  # median nearest-neighbor cosine sim
    max_nearest_similarity: float  # max nearest-neighbor cosine sim
    duplicates_map_path: str  # CSV mapping dropped rows -> kept rows
    deduped_csv_path: str  # deduplicated output CSV path
    embeddings_npy_path: str  # raw embeddings .npy path
    report_json_path: str  # JSON report path
class UseCaseEmbeddingsDeduper:
    """Compute Vertex embeddings and deduplicate similar use cases.

    Rows of an input CSV are embedded through the shared ``LLMClient`` and
    near-duplicates are removed greedily: a row is dropped when its cosine
    similarity to an already-kept row reaches the threshold.
    """

    def __init__(
        self,
        project_id: Optional[str] = None,
        location: Optional[str] = None,
        model_name: str = "text-embedding-004",
        batch_size: int = 64,
        llm_client: Optional[LLMClient] = None,
    ) -> None:
        """Configure project/location/model and build an LLM client if needed.

        Args:
            project_id: GCP project id; defaults to ``config.gcp.PROJECT_ID``.
            location: GCP region; defaults to ``config.gcp.LOCATION``.
            model_name: Vertex text-embedding model to use.
            batch_size: Texts per embedding request (clamped to >= 1).
            llm_client: Optional pre-configured client to reuse.
        """
        self.project_id = project_id or config.gcp.PROJECT_ID
        self.location = location or config.gcp.LOCATION
        self.model_name = model_name
        self.batch_size = max(1, batch_size)
        self.llm_client = llm_client or LLMClient(
            project_id=self.project_id,
            location=self.location,
        )

    @staticmethod
    def _compose_text_for_embedding(row: pd.Series) -> str:
        """Return the text to embed for one dataframe row.

        Only the stripped ``use_case`` column contributes today; parts are
        joined with " | " so additional fields can be appended later.

        BUG FIX: this method had no ``self`` parameter but was used as a
        bound method (``df.apply(self._compose_text_for_embedding, axis=1)``),
        so Python bound the instance to ``row`` and the real Series arrived
        as an unexpected extra argument, raising ``TypeError``.
        ``@staticmethod`` preserves every call site while fixing the binding.
        """
        use_case = str(row.get("use_case", "")).strip()
        return " | ".join(p for p in [use_case] if p)

    def _embed_texts(self, texts: List[str]) -> np.ndarray:
        """Embed texts using shared LLM client method."""
        return self.llm_client.get_text_embeddings(
            texts, model_name=self.model_name, batch_size=self.batch_size
        )

    @staticmethod
    def _cosine_normalize(matrix: np.ndarray) -> np.ndarray:
        """Scale each row to unit L2 norm; the epsilon guards zero rows.

        BUG FIX: like ``_compose_text_for_embedding``, this lacked ``self``
        yet was invoked as ``self._cosine_normalize(...)``; decorated as a
        ``@staticmethod`` so the existing call works.
        """
        norms = np.linalg.norm(matrix, axis=1, keepdims=True) + 1e-12
        return matrix / norms

    def _compute_nearest_similarities(
        self, norm_vecs: np.ndarray
    ) -> np.ndarray:
        """Return each row's highest cosine similarity to any other row.

        Input rows must already be unit-normalized (see _cosine_normalize).
        """
        # Cosine sim via dot product; set self-sim to -inf to ignore
        sims = np.matmul(norm_vecs, norm_vecs.T)
        np.fill_diagonal(sims, -np.inf)
        nearest = sims.max(axis=1)
        return nearest

    def _greedy_dedup(
        self, norm_vecs: np.ndarray, threshold: float
    ) -> Tuple[List[int], Dict[int, Tuple[int, float]]]:
        """
        Greedy dedup: keep an item, drop others whose cosine sim >= threshold.
        Returns (kept_indices, duplicates_map) where duplicates_map maps
        dropped_index -> (kept_index, similarity).
        """
        n = norm_vecs.shape[0]
        sims = np.matmul(norm_vecs, norm_vecs.T)
        kept: List[int] = []
        removed: set[int] = set()
        dup_map: Dict[int, Tuple[int, float]] = {}
        for i in range(n):
            if i in removed:
                continue
            kept.append(i)
            # mark all j similar to i as removed (excluding i)
            row = sims[i]
            similar_js = np.where(row >= threshold)[0]
            for j in similar_js:
                if j == i or j in removed:
                    continue
                removed.add(j)
                dup_map[j] = (i, float(row[j]))
        return kept, dup_map

    def run(
        self,
        input_csv: str,
        output_dir: str,
        threshold: float = 0.9,
    ) -> EmbeddingRunResult:
        """Embed, analyze, and deduplicate the use cases in ``input_csv``.

        Writes four artifacts into ``output_dir``: raw embeddings (.npy),
        the deduplicated CSV, a duplicates-map CSV, and a JSON report.

        Args:
            input_csv: Path to the CSV of use cases.
            output_dir: Directory for output artifacts (created if missing).
            threshold: Cosine similarity at/above which a row is dropped.

        Returns:
            EmbeddingRunResult summarizing counts, stats, and output paths.

        Raises:
            RuntimeError: If the CSV is empty or the embedding count does
                not match the number of input rows.
        """
        out_dir = Path(output_dir)
        out_dir.mkdir(parents=True, exist_ok=True)
        df = pd.read_csv(input_csv)
        if df.empty:
            raise RuntimeError("Input CSV is empty")
        # Build text column for embeddings
        df["embedding_text"] = df.apply(
            self._compose_text_for_embedding, axis=1
        )
        texts: List[str] = df["embedding_text"].astype(str).tolist()
        logger.info("Embedding %d rows using %s", len(texts), self.model_name)
        vecs = self._embed_texts(texts)
        if vecs.shape[0] != len(texts):
            raise RuntimeError("Mismatch between embeddings and input rows")
        # Save raw embeddings
        embeddings_npy_path = str(out_dir / "use_cases_embeddings.npy")
        np.save(embeddings_npy_path, vecs)
        # Normalize and compute nearest-neighbor similarities
        # (used for variety stats)
        norm_vecs = self._cosine_normalize(vecs)
        nearest = self._compute_nearest_similarities(norm_vecs)
        avg_near = float(np.nanmean(nearest))
        med_near = float(np.nanmedian(nearest))
        max_near = float(np.nanmax(nearest))
        # Deduplicate
        kept_indices, dup_map = self._greedy_dedup(norm_vecs, threshold)
        deduped_df = df.iloc[kept_indices].copy().reset_index(drop=True)
        deduped_csv_path = str(out_dir / "use_cases_deduped.csv")
        deduped_df.to_csv(deduped_csv_path, index=False)
        # Duplicates map for inspection
        dup_rows = [
            {
                "original_index": int(j),
                "kept_index": int(i),
                "similarity": float(sim),
                "original_use_case": df.iloc[j].get("use_case", ""),
                "kept_use_case": df.iloc[i].get("use_case", ""),
                "original_company": df.iloc[j].get("company", ""),
                "kept_company": df.iloc[i].get("company", ""),
                "original_user_type": df.iloc[j].get("user_type", ""),
                "kept_user_type": df.iloc[i].get("user_type", ""),
                "original_agent_type": df.iloc[j].get("agent_type", ""),
                "kept_agent_type": df.iloc[i].get("agent_type", ""),
            }
            for j, (i, sim) in sorted(dup_map.items(), key=lambda x: x[0])
        ]
        duplicates_map_path = str(out_dir / "duplicates_map.csv")
        pd.DataFrame(dup_rows).to_csv(duplicates_map_path, index=False)
        # Company-wise statistics: count removed/kept per company so the
        # report shows which companies lost the most rows to dedup.
        company_stats = {}
        for j, (i, sim) in dup_map.items():
            original_company = df.iloc[j].get("company", "Unknown")
            kept_company = df.iloc[i].get("company", "Unknown")
            if original_company not in company_stats:
                company_stats[original_company] = {"removed": 0, "kept": 0}
            company_stats[original_company]["removed"] += 1
            if kept_company not in company_stats:
                company_stats[kept_company] = {"removed": 0, "kept": 0}
            company_stats[kept_company]["kept"] += 1
        # Add company-wise stats to report
        company_removal_summary = {
            company: {
                "removed_count": stats["removed"],
                "kept_count": stats["kept"],
                "total_original": stats["removed"] + stats["kept"],
            }
            for company, stats in company_stats.items()
        }
        # Report
        report: Dict[str, Any] = {
            "input_count": int(len(df)),
            "kept_count": int(len(deduped_df)),
            "removed_count": int(len(df) - len(deduped_df)),
            "threshold": float(threshold),
            "model": self.model_name,
            "avg_nearest_similarity": avg_near,
            "median_nearest_similarity": med_near,
            "max_nearest_similarity": max_near,
            "embeddings_npy_path": embeddings_npy_path,
            "deduped_csv_path": deduped_csv_path,
            "duplicates_map_path": duplicates_map_path,
            "company_removal_summary": company_removal_summary,
        }
        report_json_path = str(out_dir / "dedup_report.json")
        with open(report_json_path, "w", encoding="utf-8") as f:
            json.dump(report, f, indent=2)
        # Log company-wise summary
        company_summary = []
        for company, stats in company_removal_summary.items():
            summary = (
                f"{company}: {stats['removed_count']} removed, "
                f"{stats['kept_count']} kept"
            )
            company_summary.append(summary)
        logger.info(
            "Dedup complete. Kept %d of %d (removed %d). Avg nearest sim=%.4f",
            report["kept_count"],
            report["input_count"],
            report["removed_count"],
            report["avg_nearest_similarity"],
        )
        logger.info("Company-wise summary: %s", "; ".join(company_summary))
        return EmbeddingRunResult(
            input_count=report["input_count"],
            kept_count=report["kept_count"],
            removed_count=report["removed_count"],
            threshold=threshold,
            model=self.model_name,
            avg_nearest_similarity=avg_near,
            median_nearest_similarity=med_near,
            max_nearest_similarity=max_near,
            duplicates_map_path=duplicates_map_path,
            deduped_csv_path=deduped_csv_path,
            embeddings_npy_path=embeddings_npy_path,
            report_json_path=report_json_path,
        )