"""Cluster profile embeddings with K-Means and label each cluster via Gemini."""

import json
import os
import time
from typing import Any, Dict, List

import numpy as np
from dotenv import load_dotenv
from google import genai
import google.genai.types as types
from sklearn.cluster import KMeans
from supabase import Client, create_client

# Load environment variables
load_dotenv()


class ClusteringService:
    """Group profiles by embedding similarity and store a readable label.

    The pipeline reads rows from the ``profile_embeddings`` table, clusters
    the ``technical_skills`` vectors with K-Means, asks Gemini for a short
    job title describing each cluster, and writes that title to the
    ``cluster_label`` column of the ``profiles`` table.
    """

    # Label stored when no title could be obtained for a cluster.
    FALLBACK_LABEL = "Unknown Group"
    # Gemini model used to name clusters.
    LABEL_MODEL = "gemini-2.5-flash-lite"
    # Number of attempts per cluster before falling back.
    MAX_LABEL_RETRIES = 3
    # Number of profiles shown to the model per cluster.
    SAMPLE_SIZE = 5
    # Number of profile ids sent per update request; keeps the filter short.
    UPDATE_BATCH_SIZE = 100

    def __init__(self):
        """Create the Supabase and Gemini clients from environment variables.

        Reads ``SUPABASE_URL``, ``SUPABASE_SERVICE_ROLE_KEY`` (falling back to
        ``SUPABASE_KEY``) and ``GEMINI_API_KEY``.
        """
        url = os.environ.get("SUPABASE_URL")
        key = os.environ.get("SUPABASE_SERVICE_ROLE_KEY") or os.environ.get("SUPABASE_KEY")
        self.client: Client = create_client(url, key)
        self.gemini_client = genai.Client(api_key=os.getenv("GEMINI_API_KEY"))

    def fetch_all_embeddings(self) -> List[Dict[str, Any]]:
        """Fetch the id and ``technical_skills`` embedding of every profile.

        Returns:
            The rows returned by Supabase, or an empty list when the response
            carries no data.
        """
        print("🔍 Fetching profile embeddings...")
        # 'technical_skills' is used as the single representative embedding
        # for clustering; other embedding columns are not combined here.
        # NOTE(review): this is a single un-paginated select; confirm the
        # project's API row limit does not truncate the result.
        resp = self.client.table("profile_embeddings").select("id, technical_skills").execute()
        return resp.data or []

    def perform_clustering(self, data: List[Dict[str, Any]], n_clusters: int = 5):
        """Run K-Means over the ``technical_skills`` vectors in ``data``.

        Rows with a missing/empty embedding are skipped silently; rows whose
        embedding string cannot be parsed are skipped with a warning.

        Args:
            data: Rows with an ``id`` key and a ``technical_skills`` embedding,
                given either as a sequence of numbers or as a JSON string.
            n_clusters: Requested number of clusters. Reduced to the number of
                usable vectors when fewer vectors than clusters are available.

        Returns:
            A list of ``{"id": ..., "cluster": int}`` dicts, one per usable
            row, or an empty list when there is nothing to cluster.
        """
        if not data:
            print("⚠️ No data to cluster.")
            return []

        vectors = []
        ids = []
        for item in data:
            raw_vec = item.get("technical_skills")
            if not raw_vec:
                continue
            # Read the id before touching either list so the two lists can
            # never get out of step with each other.
            profile_id = item["id"]
            try:
                # Some versions of postgrest return vectors as strings like
                # '[0.1, 0.2]', so parse those; anything else is used as-is.
                vec = json.loads(raw_vec) if isinstance(raw_vec, str) else raw_vec
            except ValueError as e:  # includes json.JSONDecodeError
                print(f"⚠️ Failed to parse embedding for {profile_id}: {e}")
                continue
            vectors.append(vec)
            ids.append(profile_id)

        # Every row may have been skipped above; K-Means cannot be fitted on
        # zero samples, so stop here instead of letting it raise.
        if not vectors:
            print("⚠️ No data to cluster.")
            return []

        # K-Means needs at least as many samples as clusters.
        if len(vectors) < n_clusters:
            n_clusters = len(vectors)

        print(f"🤖 Performing K-Means clustering (K={n_clusters})...")
        kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init=10)
        assignments = kmeans.fit_predict(np.asarray(vectors, dtype=float))

        return [
            {"id": profile_id, "cluster": int(cluster)}
            for profile_id, cluster in zip(ids, assignments)
        ]

    def generate_labels_for_clusters(self, clustered_data: List[Dict[str, Any]]) -> Dict[int, str]:
        """Generate human-readable labels for each cluster using Gemini.

        For every cluster, up to ``SAMPLE_SIZE`` member profiles are fetched
        from the ``profiles`` table and shown to the model.

        Args:
            clustered_data: Output of :meth:`perform_clustering`.

        Returns:
            A mapping from cluster number to label. A cluster gets
            ``FALLBACK_LABEL`` when no sample profile is found or when the
            model fails on every attempt.
        """
        cluster_groups: Dict[int, List[Any]] = {}
        for item in clustered_data:
            cluster_groups.setdefault(item["cluster"], []).append(item["id"])

        labels: Dict[int, str] = {}
        for cluster_id, user_ids in cluster_groups.items():
            # Fetch sample details for these users to describe the cluster
            sample_ids = user_ids[: self.SAMPLE_SIZE]
            profiles_resp = (
                self.client.table("profiles")
                .select("headline, technical_skills")
                .in_("id", sample_ids)
                .execute()
            )
            sample_text = "\n".join(
                f"- {p.get('headline')} (Skills: {p.get('technical_skills')})"
                for p in (profiles_resp.data or [])
            )

            if sample_text:
                prompt = self._build_label_prompt(sample_text)
                label = self._request_label(cluster_id, prompt)
            else:
                # Without any profile text the model could only invent a title.
                label = self.FALLBACK_LABEL

            labels[cluster_id] = label
            print(f"✅ Cluster {cluster_id} Label: {label}")
            time.sleep(1)  # Small pause between clusters
        return labels

    def _build_label_prompt(self, sample_text: str) -> str:
        """Return the Gemini prompt asking for one job title for ``sample_text``."""
        return "\n".join(
            [
                "You are an expert HR Talent Acquisition Specialist.",
                "Analyze the following representative professional profiles from a talent pool "
                "and provide a perfect, professional job title that best encapsulates the entire group.",
                "",
                "CRITERIA:",
                "- Concise: Exactly 2-4 words.",
                '- Professional: Use industry-standard terminology (e.g., "Full Stack Engineer", "DevOps Architect").',
                "- Accurate: Reflect the common denominator in seniority and technical domain.",
                "- Formatting: Return ONLY the title string, no quotes, no extra text.",
                "",
                "REPRESENTATIVE PROFILES:",
                sample_text,
                "",
                "PERFECT JOB TITLE:",
                "",
            ]
        )

    def _request_label(self, cluster_id: int, prompt: str) -> str:
        """Ask Gemini for a label, retrying with exponential backoff.

        Returns ``FALLBACK_LABEL`` once all ``MAX_LABEL_RETRIES`` attempts
        have failed. An empty or missing response text counts as a failure.
        """
        for attempt in range(self.MAX_LABEL_RETRIES):
            try:
                response = self.gemini_client.models.generate_content(
                    model=self.LABEL_MODEL,
                    contents=prompt,
                    config=types.GenerateContentConfig(temperature=0),
                )
                label = (response.text or "").strip().replace('"', "").strip()
                if not label:
                    raise ValueError("empty response text")
                return label
            # Labeling is best-effort: any failure of the remote call is
            # retried and finally replaced by the fallback label.
            except Exception as e:
                if attempt < self.MAX_LABEL_RETRIES - 1:
                    wait = 2 ** (attempt + 1)
                    print(f"⚠️ Labeling failed for Cluster {cluster_id}. Retrying in {wait}s... ({e})")
                    time.sleep(wait)
                else:
                    print(f"❌ Labeling failed for Cluster {cluster_id} after {self.MAX_LABEL_RETRIES} attempts.")
        return self.FALLBACK_LABEL

    def update_database_with_labels(self, clustered_data: List[Dict[str, Any]], cluster_labels: Dict[int, str]):
        """Update the profiles table with the new cluster labels.

        Profile ids are grouped by label and written in batches of
        ``UPDATE_BATCH_SIZE`` instead of one request per profile.

        Args:
            clustered_data: Output of :meth:`perform_clustering`.
            cluster_labels: Mapping from cluster number to label.

        Raises:
            KeyError: If a cluster in ``clustered_data`` has no entry in
                ``cluster_labels``. This is raised before any row is written.
        """
        print("💾 Updating database with cluster labels...")

        # Resolve every label first so a missing cluster cannot leave the
        # table half-updated.
        ids_by_label: Dict[str, List[Any]] = {}
        for item in clustered_data:
            label = cluster_labels[item["cluster"]]
            ids_by_label.setdefault(label, []).append(item["id"])

        for label, profile_ids in ids_by_label.items():
            for start in range(0, len(profile_ids), self.UPDATE_BATCH_SIZE):
                batch = profile_ids[start : start + self.UPDATE_BATCH_SIZE]
                (
                    self.client.table("profiles")
                    .update({"cluster_label": label})
                    .in_("id", batch)
                    .execute()
                )
        print("✨ Database successfully updated.")

    def run_clustering_pipeline(self, n_clusters: int = 5):
        """Orchestrate the full clustering pipeline.

        Fetches embeddings, clusters them, labels the clusters and stores the
        labels. Returns early, writing nothing, when clustering yields no
        results.

        Args:
            n_clusters: Requested number of clusters.
        """
        data = self.fetch_all_embeddings()
        clustered_results = self.perform_clustering(data, n_clusters)
        if not clustered_results:
            return

        labels = self.generate_labels_for_clusters(clustered_results)
        self.update_database_with_labels(clustered_results, labels)


if __name__ == "__main__":
    service = ClusteringService()
    service.run_clustering_pipeline(n_clusters=5)