"""Degree clustering and reference matching used to clean dataframe columns."""

import re
from collections import Counter

import pandas as pd
import torch
from sentence_transformers import util
from tqdm import tqdm

from src.utils import (
    clean_degree_text,
    normalize_text,
    strip_degrees_for_search,
    smart_format,
)
from src.config import TOP_K_CANDIDATES, DEFAULT_SIMILARITY_THRESHOLD

# Relative strength of the confidence labels produced in this module. Any
# label that is not listed (for example an error marker coming back from the
# LLM router) is treated as weaker than "LOW".
_CONFIDENCE_RANK = {"HIGH": 3, "MEDIUM": 2, "LOW": 1}


def _lower_confidence(current, candidate):
    """Return the weaker of two confidence labels (ties keep ``current``)."""
    if _CONFIDENCE_RANK.get(candidate, 0) < _CONFIDENCE_RANK.get(current, 0):
        return candidate
    return current


def self_cluster_degrees(raw_degrees_list, model, school_cache, threshold=0.93):
    """Cluster similar degree labels inside one institution.

    Args:
        raw_degrees_list: Raw degree labels; non-string entries are ignored.
        model: Encoder exposing ``encode(texts, convert_to_tensor=True)``.
        school_cache: Mutable dict of earlier merge decisions, keyed by
            ``"<a>|||<b>"`` where ``a``/``b`` are the two cleaned labels in
            sorted order and the value is the label that was kept. New
            decisions are written back into it.
        threshold: Minimum cosine similarity for two labels to be merged.

    Returns:
        Dict mapping every raw string label to a tuple
        ``(final_label, match_source, confidence)``.
    """
    # Clean each distinct raw label once and reuse the result for counting.
    raw_to_clean = {}
    for raw in raw_degrees_list:
        if isinstance(raw, str) and raw not in raw_to_clean:
            raw_to_clean[raw] = clean_degree_text(raw)
    cleaned_list = [raw_to_clean[raw] for raw in raw_degrees_list if isinstance(raw, str)]

    # Most frequent labels come first, so they become the cluster anchors.
    clean_counts = Counter(cleaned_list)
    unique_cleans = [deg for deg, _count in clean_counts.most_common() if deg]

    raw_to_meta = {}
    if len(unique_cleans) <= 1:
        # Nothing to compare against: every label is only formatted.
        for raw, clean in raw_to_clean.items():
            raw_to_meta[raw] = (clean, "Degree_Formatter", "HIGH")
        return raw_to_meta

    embeddings = model.encode(unique_cleans, convert_to_tensor=True)
    clean_to_clustered = {}
    merge_info = {}  # Track similarity scores for Blueprint transparency.

    for i, current_deg in enumerate(unique_cleans):
        if current_deg in clean_to_clustered:
            continue
        clean_to_clustered[current_deg] = current_deg
        if i + 1 < len(unique_cleans):
            cos_scores = util.cos_sim(embeddings[i], embeddings[i + 1:])[0]
            for j, score in enumerate(cos_scores):
                target_deg = unique_cleans[i + 1 + j]
                if score.item() >= threshold and target_deg not in clean_to_clustered:
                    pair_key = f"{min(current_deg, target_deg)}|||{max(current_deg, target_deg)}"
                    # Runtime cache avoids repeated decisions within one run only.
                    cached_action = school_cache.get(pair_key)
                    if cached_action:
                        # Provenance is stored under the label that gets
                        # merged away, because that is the key the lookup
                        # below reads while following the merge chain.
                        if cached_action == current_deg:
                            clean_to_clustered[target_deg] = current_deg
                            merge_info[target_deg] = "Cached (Runtime)"
                        elif cached_action == target_deg:
                            clean_to_clustered[current_deg] = target_deg
                            merge_info[current_deg] = "Cached (Runtime)"
                    else:
                        clean_to_clustered[target_deg] = current_deg
                        school_cache[pair_key] = current_deg
                        merge_info[target_deg] = f"{score.item()*100:.1f}%"

    for raw, clean in raw_to_clean.items():
        final_val = clean
        was_merged = False
        sim_str = ""
        # Follow the merge chain until a label that maps to itself is reached.
        while final_val in clean_to_clustered and clean_to_clustered[final_val] != final_val:
            if not sim_str:
                sim_str = merge_info.get(final_val, "")
            final_val = clean_to_clustered[final_val]
            was_merged = True

        conf = "MEDIUM" if was_merged else "HIGH"
        if was_merged:
            src = f"Auto-Merge ({sim_str})" if sim_str else "Auto-Merge"
        else:
            src = "Degree_Formatter"
        raw_to_meta[raw] = (final_val, src, conf)

    return raw_to_meta


def cluster_degrees_by_institution(df, degree_col, inst_col, model, master_cache, blueprint_data, threshold=0.93):
    """Apply degree clustering separately for each institution.

    Args:
        df: Dataframe to clean. It is modified in place: a
            ``Cleaned_<degree_col>`` column is added/overwritten.
        degree_col: Name of the column holding raw degree labels.
        inst_col: Name of the column holding the institution.
        model: Encoder passed through to ``self_cluster_degrees``.
        master_cache: Dict holding a ``"Degree_Decisions"`` sub-dict with one
            decision dict per institution; missing levels are created.
        blueprint_data: List that receives one review record per changed or
            non-HIGH-confidence row.
        threshold: Similarity threshold passed to ``self_cluster_degrees``.

    Returns:
        The same dataframe object that was passed in.
    """
    print(f"\n[INFO] Auto-Clustering '{degree_col}'. (Merges will be logged to Blueprint...)")
    cleaned_col_name = f'Cleaned_{degree_col}'
    df[cleaned_col_name] = df[degree_col].copy()
    unique_schools = df[inst_col].dropna().unique()

    if "Degree_Decisions" not in master_cache:
        master_cache["Degree_Decisions"] = {}

    school_mappings = {}
    # Build school-specific mappings before mutating the dataframe.
    for school in tqdm(unique_schools, desc=f"Mapping {degree_col}s by Institution"):
        school_mask = (df[inst_col] == school) & (df[degree_col].notna())
        raw_degs = df.loc[school_mask, degree_col].astype(str).tolist()
        if not raw_degs:
            continue
        if school not in master_cache["Degree_Decisions"]:
            master_cache["Degree_Decisions"][school] = {}
        school_mappings[school] = self_cluster_degrees(
            raw_degs, model, master_cache["Degree_Decisions"][school], threshold
        )

    # Apply mappings and log only changed/merged values for review.
    for idx, row in tqdm(df.iterrows(), total=len(df), desc=f"Applying & Logging {degree_col}s"):
        school = row[inst_col]
        raw_deg = str(row[degree_col])
        if pd.isna(row[degree_col]) or school not in school_mappings:
            continue
        mapping_data = school_mappings[school].get(raw_deg)
        if mapping_data:
            final_val, src, conf = mapping_data
            df.at[idx, cleaned_col_name] = final_val
            if str(raw_deg).strip() != final_val.strip() or conf != "HIGH":
                blueprint_data.append({
                    # NOTE(review): the +3 offset presumably converts the
                    # dataframe index into a workbook row number - confirm
                    # against the code that writes the Blueprint.
                    "Row_Index": idx + 3,
                    "Column": degree_col,
                    "Original_Raw_Text": raw_deg,
                    "AI_Suggested_Match": final_val,
                    "Human_Override": "",
                    "Confidence": conf,
                    "Match_Source": src,
                })
    return df


def get_deterministic_match(value, combined_valid_targets):
    """Match obvious aliases/acronyms without calling embeddings or Groq.

    Two passes are made over the targets, both on ``normalize_text`` output:
    first a prefix match (the value at the start of the target, followed by a
    word boundary or separator), then a parenthesised acronym match
    (``"(VALUE)"`` contained in the target).

    Returns:
        The first matching target, or ``None`` when nothing matches or the
        value normalises to an empty string.
    """
    val_clean = normalize_text(value)
    if not val_clean:
        # An empty prefix would match the start of almost every target and
        # return the first reference for blank input.
        return None

    # The pattern only depends on the value, so build it once.
    prefix_pattern = re.compile(rf"^{re.escape(val_clean)}(\b|[\s\(\/\\\-])")
    for target in combined_valid_targets:
        if prefix_pattern.search(normalize_text(target)):
            return target

    acronym = f"({val_clean.upper()})"
    for target in combined_valid_targets:
        if acronym in normalize_text(target).upper():
            return target
    return None


def get_top_candidates(model, value, combined_valid_targets, reference_embeddings, k=5):
    """Return the nearest reference candidates for one raw value.

    Args:
        model: Encoder exposing ``encode(text, convert_to_tensor=True)``.
        value: Raw text to look up.
        combined_valid_targets: Candidate strings, index-aligned with
            ``reference_embeddings``.
        reference_embeddings: Embeddings of ``combined_valid_targets``.
        k: Maximum number of candidates to return.

    Returns:
        Up to ``k`` targets ordered from most to least similar; an empty list
        when there are no targets.
    """
    if not combined_valid_targets:
        return []
    query_embedding = model.encode(value, convert_to_tensor=True)
    similarities = util.pytorch_cos_sim(query_embedding, reference_embeddings)[0]
    actual_k = min(k, len(combined_valid_targets))
    top_matches = torch.topk(similarities, actual_k)
    return [combined_valid_targets[idx] for idx in top_matches.indices.tolist()]


def get_dict_exact_match(value, combined_dict):
    """Exact match against alias keys first, then canonical values.

    Comparison is done on ``normalize_text`` output. Returns the canonical
    value, or ``None`` when neither an alias nor a canonical value matches.
    """
    value_clean = normalize_text(value)
    for alias, canonical in combined_dict.items():
        if normalize_text(alias) == value_clean:
            return canonical
    for canonical in combined_dict.values():
        if normalize_text(canonical) == value_clean:
            return canonical
    return None


def get_dict_rule_match(value, combined_dict):
    """Rule match dictionary-style refs while returning canonical values.

    Aliases are tried first (a hit is translated to its canonical value),
    then the de-duplicated canonical values themselves.
    """
    aliases = list(combined_dict.keys())
    canonical_values = list(dict.fromkeys(combined_dict.values()))

    alias_match = get_deterministic_match(value, aliases)
    if alias_match:
        return combined_dict[alias_match]

    value_match = get_deterministic_match(value, canonical_values)
    if value_match:
        return value_match
    return None


def as_reference_list(ref_data):
    """Convert list/dict reference data to display values.

    A list is returned unchanged, a dict yields its values de-duplicated in
    first-seen order, anything else yields an empty list.
    """
    if isinstance(ref_data, list):
        return ref_data
    if isinstance(ref_data, dict):
        return list(dict.fromkeys(ref_data.values()))
    return []


def as_reference_dict(ref_data):
    """Convert list/dict reference data to an alias-to-canonical mapping.

    A dict is returned unchanged, a list maps each string item to itself,
    anything else yields an empty dict.
    """
    if isinstance(ref_data, dict):
        return ref_data
    if isinstance(ref_data, list):
        return {item: item for item in ref_data if isinstance(item, str)}
    return {}


def update_match_postfix(progress, source_counts):
    """Expose match-source counts in tqdm without noisy per-row prints."""
    progress.set_postfix({
        "Exact_Match": source_counts["Exact_Match"],
        "Rule_Match": source_counts["Rule_Match"],
        "LLM_Judged": source_counts["LLM_Judged"],
    }, refresh=False)


def match_cache_key(column_name, value):
    """Return the single cache key format used for matching and reconstruction.

    For the ``"Institution"`` and ``"Degree"`` columns the value first goes
    through ``strip_degrees_for_search``; trailing dots are then removed and
    the result is passed through ``normalize_text``.
    """
    if column_name in ["Institution", "Degree"]:
        value = strip_degrees_for_search(value)
    return normalize_text(str(value).rstrip("."))


def append_unique_cleaned_part(cleaned_parts, value):
    """Append comma-separated cleaned parts while preserving first-seen order.

    Parts whose normalised form is already present in ``cleaned_parts`` are
    skipped. ``cleaned_parts`` is modified in place.

    Returns:
        ``True`` if at least one part was appended, otherwise ``False``.
    """
    seen = set()
    for existing_value in cleaned_parts:
        for existing_part in str(existing_value).split(","):
            key = normalize_text(existing_part.strip())
            if key:
                seen.add(key)

    added = False
    for part in str(value).split(","):
        clean_part = part.strip()
        if not clean_part:
            continue
        key = normalize_text(clean_part)
        if key in seen:
            continue
        seen.add(key)
        cleaned_parts.append(clean_part)
        added = True
    return added


def process_column(df, column_name, model, groq_router, official_refs, manual_refs, master_cache, split_pattern, blueprint_data):
    """Clean one dataframe column using refs, embeddings, then Groq fallback.

    Each distinct entity found in the column is resolved once, in this order:
    ``master_cache`` hit, exact reference match, rule match, and finally an
    LLM judgment over the nearest reference candidates. The per-entity
    decisions are then stitched back together for every row.

    Args:
        df: Dataframe to clean; ``Cleaned_<column_name>`` is written in place.
        column_name: Column to process. If absent, ``df`` is returned as is.
        model: Encoder exposing ``encode(..., convert_to_tensor=True)``.
        groq_router: Object whose ``ask_judge(word, candidates, column_name)``
            returns ``(answer, source, confidence)``.
        official_refs: Dict of reference data per column (list or dict).
        manual_refs: Dict of additional reference data per column.
        master_cache: Dict of earlier decisions per column; a sub-dict for
            ``column_name`` is created when missing.
        split_pattern: Regex pattern string used to split a cell into
            entities. It is applied case-insensitively.
        blueprint_data: List that receives one review record per changed or
            non-HIGH-confidence cell.

    Returns:
        The same dataframe object that was passed in.

    Raises:
        ValueError: If no references are available for the column.
    """
    if column_name not in df.columns:
        return df

    core_data = official_refs.get(column_name, [])
    added_data = manual_refs.get(column_name, [])
    if column_name not in master_cache:
        master_cache[column_name] = {}
    detailed_cache = {}
    is_dict_mode = isinstance(core_data, dict)

    def get_updated_embeddings():
        """Build current reference candidates after manual memory is loaded."""
        if is_dict_mode:
            c_dict = {**as_reference_dict(core_data), **as_reference_dict(added_data)}
            c_keys = list(c_dict.keys())
            # dict.fromkeys keeps first-seen order, so the candidate list is
            # the same on every run (a set would not guarantee that).
            u_vals = list(dict.fromkeys(c_dict.values()))
            k_emb = model.encode(c_keys, convert_to_tensor=True) if c_keys else None
            v_emb = model.encode(u_vals, convert_to_tensor=True) if u_vals else None
            return c_dict, c_keys, k_emb, u_vals, v_emb
        comb = as_reference_list(core_data) + as_reference_list(added_data)
        comb = list(dict.fromkeys(item for item in comb if isinstance(item, str) and item.strip()))
        emb = model.encode(comb, convert_to_tensor=True) if comb else None
        return None, comb, emb, None, None

    (combined_dict, combined_valid_targets, reference_embeddings,
     unique_values, value_embeddings) = get_updated_embeddings()

    if is_dict_mode and not combined_dict:
        raise ValueError(f"No dictionary references loaded for '{column_name}'. Refusing to call Groq for every value.")
    if not is_dict_mode and not combined_valid_targets:
        raise ValueError(f"No list references loaded for '{column_name}'. Refusing to call Groq for every value.")

    # One compiled splitter is shared by entity extraction, LLM verification
    # and row reconstruction so all three split cells identically; otherwise
    # a case-variant separator produces parts that miss the decision cache.
    splitter = re.compile(split_pattern, flags=re.IGNORECASE)

    # List mode: normalised reference text -> first reference carrying it.
    exact_lookup = {}
    if not is_dict_mode:
        for target in combined_valid_targets:
            exact_lookup.setdefault(normalize_text(target), target)

    # Work on unique split values first so repeated cells reuse one decision.
    uniques = set()
    for cell in df[column_name].dropna():
        for p in splitter.split(str(cell)):
            if p.strip():
                uniques.add(p.strip())

    print(f"\n[INFO] Analyzing {len(uniques)} unique entities in '{column_name}'...")
    source_counts = Counter()
    # The raw text is a tie-breaker so the processing order is deterministic.
    ordered_uniques = sorted(uniques, key=lambda w: (normalize_text(w), w))
    progress = tqdm(ordered_uniques, desc=f"Cleaning {column_name}")

    for word in progress:
        word_clean = match_cache_key(column_name, word)

        # Fast path: reuse a decision made earlier in this run.
        if word_clean in master_cache[column_name]:
            detailed_cache[word_clean] = {
                "val": master_cache[column_name][word_clean],
                "src": "Memory_Cache",
                "conf": "HIGH",
            }
            source_counts["Memory_Cache"] += 1
            update_match_postfix(progress, source_counts)
            continue

        # Exact/rule matches are trusted and avoid LLM calls.
        if is_dict_mode:
            exact = get_dict_exact_match(word, combined_dict)
        else:
            exact = exact_lookup.get(normalize_text(word_clean))
        if exact:
            detailed_cache[word_clean] = {"val": exact, "src": "Exact_Match", "conf": "HIGH"}
            source_counts["Exact_Match"] += 1
            update_match_postfix(progress, source_counts)
            continue

        if is_dict_mode:
            suggested_match = get_dict_rule_match(word, combined_dict)
        else:
            suggested_match = get_deterministic_match(word, combined_valid_targets)
        if suggested_match:
            detailed_cache[word_clean] = {"val": suggested_match, "src": "Rule_Match", "conf": "HIGH"}
            source_counts["Rule_Match"] += 1
            update_match_postfix(progress, source_counts)
            continue

        # Last resort: send only the top reference candidates to Groq.
        if is_dict_mode:
            cand_keys = get_top_candidates(model, word, combined_valid_targets, reference_embeddings)
            cand_vals = get_top_candidates(model, word, unique_values, value_embeddings)
            candidates = list(dict.fromkeys(cand_keys + cand_vals))[:TOP_K_CANDIDATES]
        else:
            candidates = get_top_candidates(model, word, combined_valid_targets, reference_embeddings)

        ans_val, src, conf = groq_router.ask_judge(word, candidates, column_name)
        source_counts[src] += 1
        update_match_postfix(progress, source_counts)

        # Re-check Groq output against refs so canonical casing/names are preserved.
        if "API_Error" not in conf and ans_val not in ("UNKNOWN", "LLM_Failed"):
            # Empty fragments (e.g. from a trailing comma) carry no answer
            # and must not be fed to the matchers.
            llm_parts = [p.strip() for p in ans_val.split(",") if p.strip()]
            if llm_parts:
                corrected_parts = []
                all_matched = True  # Stays True only if every part exists in the refs.
                for part in llm_parts:
                    if is_dict_mode:
                        matched = get_dict_exact_match(part, combined_dict)
                        if not matched:
                            matched = get_dict_rule_match(part, combined_dict)
                    else:
                        matched = next((c for c in candidates if c.lower() == part.lower()), None)
                        if not matched:
                            matched = get_deterministic_match(part, candidates)

                    if matched:
                        corrected_parts.append(matched)
                    else:
                        # Keep unverifiable LLM text, but do not upgrade confidence.
                        corrected_parts.append(part)
                        all_matched = False

                unique_parts = list(dict.fromkeys(corrected_parts))
                ans_val = ", ".join(unique_parts)

                raw_parts_for_check = [
                    p.strip() for p in splitter.split(str(word)) if p.strip()
                ] or [word]
                raw_lookup_keys = {normalize_text(part) for part in raw_parts_for_check}
                verified_lookup_keys = {normalize_text(part) for part in unique_parts}

                # Only upgrade when the LLM answer was verified against the
                # refs and also matches the original text directly. Otherwise
                # it is still an LLM judgment, even if the chosen answer
                # exists in the references.
                if all_matched and verified_lookup_keys.issubset(raw_lookup_keys):
                    conf = "HIGH"
                    src = "LLM_Rule_Verified"

        detailed_cache[word_clean] = {"val": ans_val, "src": src, "conf": conf}

    # Reconstruct full cell values in original row order for workbook injection.
    for idx, row in df.iterrows():
        cell_val = row[column_name]
        if pd.isna(cell_val):
            continue

        raw_parts = [p.strip() for p in splitter.split(str(cell_val)) if p.strip()]
        cleaned_parts = []
        lowest_conf = "HIGH"
        cell_sources = []

        i = 0
        while i < len(raw_parts):
            curr = raw_parts[i]

            # Recover obvious accidental splits such as "University of, Manchester".
            if i + 1 < len(raw_parts):
                combo_clean = match_cache_key(column_name, f"{curr}, {raw_parts[i+1]}")
                if combo_clean in detailed_cache:
                    combo_hit = detailed_cache[combo_clean]
                    append_unique_cleaned_part(cleaned_parts, combo_hit["val"])
                    lowest_conf = _lower_confidence(lowest_conf, combo_hit["conf"])
                    cell_sources.append(combo_hit["src"])
                    i += 2
                    continue

            p_clean = match_cache_key(column_name, curr)
            if len(p_clean) <= 1:
                # Keys of at most one character are dropped from the output.
                i += 1
                continue

            cache_hit = detailed_cache.get(p_clean)
            if cache_hit:
                append_unique_cleaned_part(cleaned_parts, cache_hit["val"])
                lowest_conf = _lower_confidence(lowest_conf, cache_hit["conf"])
                cell_sources.append(cache_hit["src"])
            else:
                append_unique_cleaned_part(cleaned_parts, smart_format(curr))
                lowest_conf = _lower_confidence(lowest_conf, "LOW")
                cell_sources.append("Fallback_Format")
            i += 1

        final_stitched_val = ", ".join(cleaned_parts)
        df.at[idx, f'Cleaned_{column_name}'] = final_stitched_val

        # Review every changed cell and every low/medium-confidence result.
        if str(cell_val).strip() != final_stitched_val.strip() or lowest_conf != "HIGH":
            blueprint_data.append({
                # NOTE(review): the +3 offset presumably converts the
                # dataframe index into a workbook row number - confirm
                # against the code that writes the Blueprint.
                "Row_Index": idx + 3,
                "Column": column_name,
                "Original_Raw_Text": cell_val,
                "AI_Suggested_Match": final_stitched_val,
                "Human_Override": "",
                # First-seen order keeps the logged sources stable across runs.
                "Match_Source": " | ".join(dict.fromkeys(cell_sources)),
                "Confidence": lowest_conf,
            })
    return df