Spaces:
Sleeping
Sleeping
| import os | |
| import numpy as np | |
| from sklearn.cluster import KMeans | |
| from typing import List, Dict, Any | |
| from google import genai | |
| import google.genai.types as types | |
| from supabase import create_client, Client | |
| from dotenv import load_dotenv | |
# Pull variables from a local .env file into os.environ before the service
# reads SUPABASE_* / GEMINI_API_KEY; values already in the environment win.
load_dotenv(override=False)
class ClusteringService:
    """Group profiles by skill embedding and tag each group with a job title.

    Pipeline (driven by ``run_clustering_pipeline``):
      1. read ``id`` / ``technical_skills`` rows from ``profile_embeddings``,
      2. cluster the vectors with scikit-learn K-Means,
      3. ask Gemini for a short title per cluster from a few sample profiles,
      4. write that title to ``profiles.cluster_label``.
    """

    # Rows requested per page when reading embeddings. Supabase caps how many
    # rows a single request returns (the project's "max rows" setting), so
    # reading every row requires explicit pagination.
    PAGE_SIZE = 1000
    # Max ids placed in a single ``in_`` filter per update request: bounds the
    # request URL while avoiding one round trip per profile.
    UPDATE_BATCH_SIZE = 100
    # Number of profiles per cluster shown to Gemini when naming the cluster.
    LABEL_SAMPLE_SIZE = 5
    # Gemini attempts per cluster before falling back to FALLBACK_LABEL.
    LABEL_MAX_RETRIES = 3
    FALLBACK_LABEL = "Unknown Group"
    # Prompt sent to Gemini; ``{sample_text}`` is filled in with str.format.
    LABEL_PROMPT_TEMPLATE = (
        "You are an expert HR Talent Acquisition Specialist.\n"
        "Analyze the following representative professional profiles from a "
        "talent pool and provide a perfect, professional job title that best "
        "encapsulates the entire group.\n"
        "CRITERIA:\n"
        "- Concise: Exactly 2-4 words.\n"
        "- Professional: Use industry-standard terminology "
        "(e.g., \"Full Stack Engineer\", \"DevOps Architect\").\n"
        "- Accurate: Reflect the common denominator in seniority and technical domain.\n"
        "- Formatting: Return ONLY the title string, no quotes, no extra text.\n"
        "REPRESENTATIVE PROFILES:\n"
        "{sample_text}\n"
        "PERFECT JOB TITLE:\n"
    )

    def __init__(self):
        """Create Supabase and Gemini clients from environment variables.

        Reads ``SUPABASE_URL``, then ``SUPABASE_SERVICE_ROLE_KEY`` (falling
        back to ``SUPABASE_KEY``), and ``GEMINI_API_KEY``.
        """
        url = os.environ.get("SUPABASE_URL")
        # Prefer the service-role key; fall back to the plain key.
        key = os.environ.get("SUPABASE_SERVICE_ROLE_KEY") or os.environ.get("SUPABASE_KEY")
        self.client: Client = create_client(url, key)
        self.gemini_client = genai.Client(api_key=os.getenv("GEMINI_API_KEY"))

    @staticmethod
    def _parse_vector(raw_vec: Any) -> Any:
        """Return ``raw_vec`` as a sequence, JSON-decoding it if it is a string.

        Raises:
            ValueError: if ``raw_vec`` is a string that is not valid JSON.
        """
        import json

        if isinstance(raw_vec, str):
            # Some versions of postgrest return vectors as strings like '[0.1, 0.2]'.
            return json.loads(raw_vec)
        return raw_vec

    @staticmethod
    def _batched(items: List[Any], size: int):
        """Yield consecutive slices of ``items`` holding at most ``size`` elements."""
        for start in range(0, len(items), size):
            yield items[start:start + size]

    def fetch_all_embeddings(self) -> List[Dict[str, Any]]:
        """Fetch ``id`` and ``technical_skills`` for every ``profile_embeddings`` row.

        Only the ``technical_skills`` embedding is used for clustering. Rows are
        read in pages ordered by ``id`` until an empty page comes back. The
        offset advances by the number of rows actually returned, so a server
        row cap smaller than PAGE_SIZE does not end the scan early.
        """
        print("🔍 Fetching profile embeddings...")
        rows: List[Dict[str, Any]] = []
        offset = 0
        while True:
            resp = (
                self.client.table("profile_embeddings")
                .select("id, technical_skills")
                .order("id")  # stable ordering so offset pages don't overlap/skip
                .range(offset, offset + self.PAGE_SIZE - 1)
                .execute()
            )
            page = resp.data or []
            if not page:
                break
            rows.extend(page)
            offset += len(page)
        return rows

    def perform_clustering(self, data: List[Dict[str, Any]], n_clusters: int = 5):
        """Run K-Means over the ``technical_skills`` vectors in ``data``.

        Rows with a missing/empty ``technical_skills`` value, or a string value
        that is not valid JSON, are skipped. K is clamped to
        ``[1, number of usable vectors]``.

        Returns:
            A list of ``{"id": <row id>, "cluster": <int label>}`` dicts, or
            ``[]`` when there is nothing usable to cluster.
        """
        if not data:
            print("⚠️ No data to cluster.")
            return []

        vectors = []
        ids = []
        for item in data:
            raw_vec = item.get("technical_skills")
            if not raw_vec:
                continue
            try:
                vec = self._parse_vector(raw_vec)
            except ValueError as e:
                print(f"⚠️ Failed to parse embedding for {item.get('id')}: {e}")
                continue
            # Append id and vector together so the two lists stay aligned.
            ids.append(item["id"])
            vectors.append(vec)

        if not vectors:
            # KMeans cannot fit zero samples; bail out like the empty-input case.
            print("⚠️ No valid embeddings to cluster.")
            return []

        X = np.asarray(vectors, dtype=float)
        n_clusters = max(1, min(n_clusters, len(ids)))

        print(f"🤖 Performing K-Means clustering (K={n_clusters})...")
        kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init=10)
        labels = kmeans.fit_predict(X)
        return [{"id": row_id, "cluster": int(label)} for row_id, label in zip(ids, labels)]

    def _request_label(self, cluster_id: int, prompt: str) -> str:
        """Ask Gemini for a cluster title, retrying with exponential backoff.

        Double quotes are stripped from the reply. Returns ``FALLBACK_LABEL``
        if every attempt raises or yields empty text.
        """
        import time

        for attempt in range(self.LABEL_MAX_RETRIES):
            try:
                response = self.gemini_client.models.generate_content(
                    model="gemini-2.5-flash-lite",
                    contents=prompt,
                    config=types.GenerateContentConfig(temperature=0),
                )
                # ``response.text`` may be None when the model returns no text
                # part; treat that, and a blank title, as a failed attempt.
                label = (response.text or "").strip().replace('"', "")
                if not label:
                    raise ValueError("empty response from model")
                return label
            except Exception as e:  # SDK/network errors are not a narrow, stable set
                if attempt < self.LABEL_MAX_RETRIES - 1:
                    wait = 2 ** (attempt + 1)
                    print(f"⚠️ Labeling failed for Cluster {cluster_id}. Retrying in {wait}s... ({e})")
                    time.sleep(wait)
                else:
                    print(
                        f"❌ Labeling failed for Cluster {cluster_id} "
                        f"after {self.LABEL_MAX_RETRIES} attempts: {e}"
                    )
        return self.FALLBACK_LABEL

    def generate_labels_for_clusters(self, clustered_data: List[Dict[str, Any]]) -> Dict[int, str]:
        """Generate a human-readable title for each cluster using Gemini.

        For each cluster, up to LABEL_SAMPLE_SIZE member profiles are read from
        ``profiles`` (``headline``, ``technical_skills``) and sent to the model.

        Returns:
            Mapping of cluster id to title (FALLBACK_LABEL on failure).
        """
        import time
        from collections import defaultdict

        cluster_groups: Dict[int, List[Any]] = defaultdict(list)
        for item in clustered_data:
            cluster_groups[item["cluster"]].append(item["id"])

        labels: Dict[int, str] = {}
        for position, (cluster_id, user_ids) in enumerate(cluster_groups.items()):
            if position:
                time.sleep(1)  # small pause between consecutive clusters' API calls

            sample_ids = user_ids[: self.LABEL_SAMPLE_SIZE]
            profiles_resp = (
                self.client.table("profiles")
                .select("headline, technical_skills")
                .in_("id", sample_ids)
                .execute()
            )
            sample_text = "\n".join(
                f"- {p.get('headline')} (Skills: {p.get('technical_skills')})"
                for p in profiles_resp.data or []
            )
            prompt = self.LABEL_PROMPT_TEMPLATE.format(sample_text=sample_text)

            label = self._request_label(cluster_id, prompt)
            labels[cluster_id] = label
            print(f"✅ Cluster {cluster_id} Label: {label}")
        return labels

    def update_database_with_labels(self, clustered_data: List[Dict[str, Any]], cluster_labels: Dict[int, str]):
        """Write each profile's cluster title to ``profiles.cluster_label``.

        Profiles sharing a title are updated together with ``in_`` filters of
        at most UPDATE_BATCH_SIZE ids, instead of one request per profile.

        Raises:
            KeyError: if a row's cluster has no entry in ``cluster_labels``
                (raised before any update is sent).
        """
        from collections import defaultdict

        print("💾 Updating database with cluster labels...")
        ids_by_label: Dict[str, List[Any]] = defaultdict(list)
        for item in clustered_data:
            ids_by_label[cluster_labels[item["cluster"]]].append(item["id"])

        for label, user_ids in ids_by_label.items():
            for batch in self._batched(user_ids, self.UPDATE_BATCH_SIZE):
                self.client.table("profiles").update({"cluster_label": label}).in_("id", batch).execute()
        print("✨ Database successfully updated.")

    def run_clustering_pipeline(self, n_clusters: int = 5):
        """Fetch embeddings, cluster them, label clusters, and persist labels.

        Returns early (without touching the database) if nothing was clustered.
        """
        data = self.fetch_all_embeddings()
        clustered_results = self.perform_clustering(data, n_clusters)
        if not clustered_results:
            return
        labels = self.generate_labels_for_clusters(clustered_results)
        self.update_database_with_labels(clustered_results, labels)
if __name__ == "__main__":
    # Script entry point: run fetch -> cluster -> label -> write-back with K=5.
    pipeline = ClusteringService()
    pipeline.run_clustering_pipeline(n_clusters=5)