Spaces:
Running
Running
| import pandas as pd | |
| import torch | |
| import re | |
| from collections import Counter | |
| from sentence_transformers import util | |
| from tqdm import tqdm | |
| from src.utils import ( | |
| clean_degree_text, | |
| normalize_text, | |
| strip_degrees_for_search, | |
| smart_format | |
| ) | |
| from src.config import TOP_K_CANDIDATES, DEFAULT_SIMILARITY_THRESHOLD | |
def self_cluster_degrees(raw_degrees_list, model, school_cache, threshold=0.93):
    """Cluster similar degree labels inside one institution.

    Each string in ``raw_degrees_list`` is cleaned with ``clean_degree_text``.
    Distinct non-empty cleaned labels are ordered by frequency (most common
    first) and embedded with ``model``. Walking that order, every later label
    whose cosine similarity to the current anchor is >= ``threshold`` is merged
    into it. ``school_cache`` stores pair decisions (``"a|||b" -> winning label``)
    so that a pair seen again reuses the recorded winner.

    Args:
        raw_degrees_list: raw degree values for one institution; non-strings are ignored.
        model: embedding model exposing ``encode(list, convert_to_tensor=True)``.
        school_cache: mutable dict of pair decisions for this institution; updated in place.
        threshold: minimum cosine similarity for an automatic merge.

    Returns:
        dict mapping each raw string to ``(final_label, source, confidence)``.
        Confidence is ``"MEDIUM"`` for merged labels and ``"HIGH"`` otherwise.
    """
    # Clean each distinct raw string once and reuse it for counting.
    raw_to_clean = {raw: clean_degree_text(raw) for raw in raw_degrees_list if isinstance(raw, str)}
    # Count every occurrence (duplicates included) so the most frequent spelling anchors a cluster.
    clean_counts = Counter(raw_to_clean[raw] for raw in raw_degrees_list if isinstance(raw, str))
    unique_cleans = [deg for deg, _count in clean_counts.most_common() if deg]
    raw_to_meta = {}
    if len(unique_cleans) <= 1:
        for raw, clean in raw_to_clean.items():
            raw_to_meta[raw] = (clean, "Degree_Formatter", "HIGH")
        return raw_to_meta

    embeddings = model.encode(unique_cleans, convert_to_tensor=True)
    clean_to_clustered = {}
    merge_info = {}  # merged-away label -> similarity text / cache note, for Blueprint transparency.
    for i, current_deg in enumerate(unique_cleans):
        if current_deg in clean_to_clustered:
            continue
        clean_to_clustered[current_deg] = current_deg
        if i + 1 >= len(unique_cleans):
            continue
        cos_scores = util.cos_sim(embeddings[i], embeddings[i + 1:])[0]
        for j, score in enumerate(cos_scores):
            target_deg = unique_cleans[i + 1 + j]
            similarity = score.item()
            if similarity >= threshold and target_deg not in clean_to_clustered:
                pair_key = f"{min(current_deg, target_deg)}|||{max(current_deg, target_deg)}"
                # Runtime cache avoids repeated decisions within one run only.
                cached_action = school_cache.get(pair_key)
                if cached_action:
                    # Annotate whichever label was actually merged away, so the
                    # chain walk below can surface the cache note for that row.
                    if cached_action == current_deg:
                        clean_to_clustered[target_deg] = current_deg
                        merge_info[target_deg] = "Cached (Runtime)"
                    elif cached_action == target_deg:
                        clean_to_clustered[current_deg] = target_deg
                        merge_info[current_deg] = "Cached (Runtime)"
                else:
                    clean_to_clustered[target_deg] = current_deg
                    school_cache[pair_key] = current_deg
                    merge_info[target_deg] = f"{similarity * 100:.1f}%"

    for raw, clean in raw_to_clean.items():
        final_val = clean
        was_merged = False
        sim_str = ""
        # Follow merge chains (a -> b -> c) to the final anchor; keep the first note seen.
        while final_val in clean_to_clustered and clean_to_clustered[final_val] != final_val:
            if not sim_str:
                sim_str = merge_info.get(final_val, "")
            final_val = clean_to_clustered[final_val]
            was_merged = True
        conf = "MEDIUM" if was_merged else "HIGH"
        if was_merged:
            src = f"Auto-Merge ({sim_str})" if sim_str else "Auto-Merge"
        else:
            src = "Degree_Formatter"
        raw_to_meta[raw] = (final_val, src, conf)
    return raw_to_meta
def cluster_degrees_by_institution(df, degree_col, inst_col, model, master_cache, blueprint_data, threshold=0.93):
    """Apply degree clustering separately for each institution.

    Copies ``degree_col`` into ``Cleaned_<degree_col>``. Each school's non-null
    degrees are clustered with ``self_cluster_degrees``, with pair decisions kept
    in ``master_cache["Degree_Decisions"][school]``. The clustered label is then
    written back per row. Rows whose label changed, or whose confidence is not
    "HIGH", are appended to ``blueprint_data`` for review.

    Returns:
        The same ``df``, mutated in place.
    """
    print(f"\n[INFO] Auto-Clustering '{degree_col}'. (Merges will be logged to Blueprint...)")
    cleaned_col_name = f'Cleaned_{degree_col}'
    df[cleaned_col_name] = df[degree_col].copy()
    schools = df[inst_col].dropna().unique()
    decisions = master_cache.setdefault("Degree_Decisions", {})

    # Phase 1: build every school's raw -> (label, source, confidence) mapping first.
    mapping_by_school = {}
    for school in tqdm(schools, desc=f"Mapping {degree_col}s by Institution"):
        rows_for_school = (df[inst_col] == school) & (df[degree_col].notna())
        school_degrees = df.loc[rows_for_school, degree_col].astype(str).tolist()
        if not school_degrees:
            continue
        school_cache = decisions.setdefault(school, {})
        mapping_by_school[school] = self_cluster_degrees(school_degrees, model, school_cache, threshold)

    # Phase 2: write labels back and log only changed / non-HIGH rows.
    for idx, row in tqdm(df.iterrows(), total=len(df), desc=f"Applying & Logging {degree_col}s"):
        school = row[inst_col]
        original_value = row[degree_col]
        if pd.isna(original_value) or school not in mapping_by_school:
            continue
        raw_deg = str(original_value)
        mapping_data = mapping_by_school[school].get(raw_deg)
        if not mapping_data:
            continue
        final_val, src, conf = mapping_data
        df.at[idx, cleaned_col_name] = final_val
        unchanged = raw_deg.strip() == final_val.strip()
        if unchanged and conf == "HIGH":
            continue
        blueprint_data.append({
            "Row_Index": idx + 3,
            "Column": degree_col,
            "Original_Raw_Text": raw_deg,
            "AI_Suggested_Match": final_val,
            "Human_Override": "",
            "Confidence": conf,
            "Match_Source": src,
        })
    return df
def get_deterministic_match(value, combined_valid_targets):
    """Match obvious aliases/acronyms without calling embeddings or Groq.

    Both sides are compared after ``normalize_text``. The targets are checked in
    two passes:

    1. Prefix: the first target whose normalized form starts with the normalized
       value followed by a word boundary, whitespace, ``(``, ``/``, ``\\`` or ``-``.
    2. Acronym: the first target whose normalized, upper-cased form contains
       ``"(VALUE)"``.

    Returns:
        The original (un-normalized) matching target. Returns ``None`` when
        nothing matches or when ``value`` normalizes to an empty string.
    """
    val_clean = normalize_text(value)
    if not val_clean:
        # An empty value would turn the prefix regex into "^(\b|...)", which matches
        # the first reference, and the acronym check into "()". Blank input must
        # never count as a rule match.
        return None
    prefix_pattern = re.compile(rf"^{re.escape(val_clean)}(\b|[\s\(\/\\\-])")
    # Normalize each target once and reuse it in both passes.
    normalized_targets = [(target, normalize_text(target)) for target in combined_valid_targets]
    for target, target_clean in normalized_targets:
        if prefix_pattern.search(target_clean):
            return target
    acronym = f"({val_clean.upper()})"
    for target, target_clean in normalized_targets:
        if acronym in target_clean.upper():
            return target
    return None
def get_top_candidates(model, value, combined_valid_targets, reference_embeddings, k=5):
    """Return the nearest reference candidates for one raw value.

    Embeds ``value`` with ``model`` and ranks ``combined_valid_targets`` by cosine
    similarity against ``reference_embeddings``. The embeddings are assumed to be
    row-aligned with the targets; the caller in this file encodes them from the
    same list. At most ``k`` targets are returned, best first. An empty target
    list yields ``[]``.
    """
    if not combined_valid_targets:
        return []
    query_vector = model.encode(value, convert_to_tensor=True)
    scores = util.cos_sim(query_vector, reference_embeddings)[0]
    best = torch.topk(scores, min(k, len(combined_valid_targets)))
    return [combined_valid_targets[position] for position in best.indices.tolist()]
def get_dict_exact_match(value, combined_dict):
    """Exact (normalized) match against a dictionary reference.

    Alias keys are checked first, in dict order, and a hit returns that alias's
    canonical value. If no alias matches, the canonical values themselves are
    checked. Comparison uses ``normalize_text`` on both sides.

    Returns:
        The matched canonical value, or ``None`` if nothing matches.
    """
    wanted = normalize_text(value)

    def probes():
        # (text to compare, value to return): aliases first, then canonicals.
        yield from combined_dict.items()
        for canonical in combined_dict.values():
            yield canonical, canonical

    for probe_text, result in probes():
        if normalize_text(probe_text) == wanted:
            return result
    return None
def get_dict_rule_match(value, combined_dict):
    """Rule-match ``value`` against a dictionary reference and return a canonical value.

    The alias keys are tried first with ``get_deterministic_match``, and an alias
    hit is translated to its canonical value. Otherwise the distinct canonical
    values themselves are tried, in first-seen order.

    Returns:
        The canonical value, or ``None`` when neither pass matches.
    """
    alias_pool = [*combined_dict]
    canonical_pool = [*dict.fromkeys(combined_dict.values())]
    hit_alias = get_deterministic_match(value, alias_pool)
    if hit_alias:
        return combined_dict[hit_alias]
    hit_canonical = get_deterministic_match(value, canonical_pool)
    return hit_canonical if hit_canonical else None
def as_reference_list(ref_data):
    """Convert list/dict reference data to display values.

    A list is returned as-is (the same object). A dict yields its distinct values
    in first-seen order. Anything else yields an empty list.
    """
    if isinstance(ref_data, dict):
        # dict.fromkeys drops duplicate canonicals while keeping their order.
        return list(dict.fromkeys(ref_data.values()))
    return ref_data if isinstance(ref_data, list) else []
def as_reference_dict(ref_data):
    """Convert list/dict reference data to an alias-to-canonical mapping.

    A dict is returned as-is (the same object). A list becomes an identity
    mapping ``{item: item}`` over its non-blank string items. Non-strings and
    empty or whitespace-only strings are skipped, matching the list-mode
    filtering in ``process_column``, so blank entries can never become
    aliases or embedding candidates. Anything else yields ``{}``.
    """
    if isinstance(ref_data, dict):
        return ref_data
    if isinstance(ref_data, list):
        return {item: item for item in ref_data if isinstance(item, str) and item.strip()}
    return {}
def update_match_postfix(progress, source_counts):
    """Show running match-source tallies on the progress bar.

    Only the Exact_Match, Rule_Match and LLM_Judged counters are displayed, in
    that order. ``refresh=False`` avoids forcing a redraw on every call.
    """
    shown = ("Exact_Match", "Rule_Match", "LLM_Judged")
    postfix = {label: source_counts[label] for label in shown}
    progress.set_postfix(postfix, refresh=False)
def match_cache_key(column_name, value):
    """Return the single cache-key format shared by matching and reconstruction.

    For the "Institution" and "Degree" columns, degree text is first removed with
    ``strip_degrees_for_search``. The value is then stringified, trailing periods
    are trimmed, and the result is passed through ``normalize_text``.
    """
    if column_name not in ("Institution", "Degree"):
        searchable = value
    else:
        searchable = strip_degrees_for_search(value)
    return normalize_text(str(searchable).rstrip("."))
def append_unique_cleaned_part(cleaned_parts, value):
    """Append the comma-separated pieces of ``value`` to ``cleaned_parts`` in place.

    Pieces are stripped. A piece is skipped when it is blank, or when its
    ``normalize_text`` key already occurs among the existing entries (themselves
    split on commas) or earlier in ``value``. First-seen order is preserved.

    Returns:
        True if at least one piece was appended, otherwise False.
    """
    known_keys = {
        key
        for existing_value in cleaned_parts
        for key in (normalize_text(piece.strip()) for piece in str(existing_value).split(","))
        if key
    }
    length_before = len(cleaned_parts)
    for piece in (chunk.strip() for chunk in str(value).split(",")):
        if not piece:
            continue
        key = normalize_text(piece)
        if key not in known_keys:
            known_keys.add(key)
            cleaned_parts.append(piece)
    return len(cleaned_parts) > length_before
# Ordering used to keep the *weakest* confidence seen while rebuilding a cell.
_CONFIDENCE_RANK = {"HIGH": 2, "MEDIUM": 1, "LOW": 0}


def _lower_confidence(current, candidate):
    """Return the weaker of two confidence labels (HIGH > MEDIUM > LOW > anything else).

    Labels outside HIGH/MEDIUM/LOW rank below LOW so unexpected statuses stay
    visible for review. On a tie the existing ``current`` label is kept.
    """
    if _CONFIDENCE_RANK.get(candidate, -1) < _CONFIDENCE_RANK.get(current, -1):
        return candidate
    return current


def process_column(df, column_name, model, groq_router, official_refs, manual_refs, master_cache, split_pattern, blueprint_data):
    """Clean one dataframe column using refs, embeddings, then Groq fallback.

    Every unique split value goes through these steps in order:
    ``master_cache`` hit -> exact reference match -> rule (prefix/acronym) match
    -> Groq judge over the top embedding candidates. The decisions are then
    stitched back into ``Cleaned_<column_name>`` row by row. Cells that changed,
    or whose weakest part confidence is not "HIGH", are appended to
    ``blueprint_data``.

    Args:
        df: dataframe to clean; mutated in place and returned.
        column_name: column to clean; ``df`` is returned untouched if it is absent.
        model: embedding model exposing ``encode(..., convert_to_tensor=True)``.
        groq_router: object whose ``ask_judge(value, candidates, column_name)``
            returns ``(answer, source, confidence)``.
        official_refs, manual_refs: mappings from column name to reference data.
            The data is a list of values, or an alias -> canonical dict (dict mode
            is chosen from ``official_refs``).
        master_cache: run-level cache. ``master_cache[column_name]`` maps
            ``match_cache_key`` strings to decided values; it is read (and created
            if missing) here but not otherwise written.
        split_pattern: regex string used, case-insensitively, to split multi-value cells.
        blueprint_data: list that receives review rows.

    Returns:
        The same ``df`` with ``Cleaned_<column_name>`` populated.

    Raises:
        ValueError: if no references are loaded for ``column_name``.
    """
    if column_name not in df.columns:
        return df
    core_data = official_refs.get(column_name, [])
    added_data = manual_refs.get(column_name, [])
    if column_name not in master_cache:
        master_cache[column_name] = {}
    detailed_cache = {}
    is_dict_mode = isinstance(core_data, dict)

    def get_updated_embeddings():
        """Build current reference candidates after manual memory is loaded."""
        if is_dict_mode:
            c_dict = {**as_reference_dict(core_data), **as_reference_dict(added_data)}
            c_keys = list(c_dict.keys())
            # dict.fromkeys (not set) keeps canonical order, and thus top-k ties and
            # the candidate order sent to Groq, deterministic across runs.
            u_vals = list(dict.fromkeys(c_dict.values()))
            k_emb = model.encode(c_keys, convert_to_tensor=True) if c_keys else None
            v_emb = model.encode(u_vals, convert_to_tensor=True) if u_vals else None
            return c_dict, c_keys, k_emb, u_vals, v_emb
        comb = as_reference_list(core_data) + as_reference_list(added_data)
        comb = list(dict.fromkeys(item for item in comb if isinstance(item, str) and item.strip()))
        emb = model.encode(comb, convert_to_tensor=True) if comb else None
        return None, comb, emb, None, None

    def split_cell(text):
        """Split text on ``split_pattern`` (case-insensitive) into stripped, non-empty parts.

        The same rule is used for collecting unique values, verifying LLM answers
        and rebuilding cells, so all three always agree on part boundaries.
        """
        return [p.strip() for p in re.split(split_pattern, str(text), flags=re.IGNORECASE) if p.strip()]

    combined_dict, combined_valid_targets, reference_embeddings, unique_values, value_embeddings = get_updated_embeddings()
    if is_dict_mode and not combined_dict:
        raise ValueError(f"No dictionary references loaded for '{column_name}'. Refusing to call Groq for every value.")
    if not is_dict_mode and not combined_valid_targets:
        raise ValueError(f"No list references loaded for '{column_name}'. Refusing to call Groq for every value.")

    # List mode: normalized target -> first target with that form. This equals a
    # linear first-match scan, but normalizes each target only once per column.
    normalized_targets = {}
    if not is_dict_mode:
        for target in combined_valid_targets:
            normalized_targets.setdefault(normalize_text(target), target)

    def verify_llm_answer(word, ans_val, src, conf, candidates):
        """Snap Groq's comma-separated answer onto reference spellings; upgrade only if fully verified."""
        # Drop blank fragments (e.g. from a trailing comma). Matching an empty
        # string would otherwise pick an arbitrary reference.
        llm_parts = [p.strip() for p in ans_val.split(",") if p.strip()]
        if not llm_parts:
            return ans_val, src, conf
        corrected_parts = []
        all_matched = True  # Flag to track if every piece exists in our data
        for part in llm_parts:
            if is_dict_mode:
                match = get_dict_exact_match(part, combined_dict) or get_dict_rule_match(part, combined_dict)
            else:
                match = (
                    next((c for c in candidates if c.lower() == part.lower()), None)
                    or get_deterministic_match(part, candidates)
                )
            if match:
                corrected_parts.append(match)
            else:
                # Keep unverifiable LLM text, but do not upgrade confidence.
                corrected_parts.append(part)
                all_matched = False
        unique_parts = list(dict.fromkeys(corrected_parts))
        verified_val = ", ".join(unique_parts)
        raw_lookup_keys = {normalize_text(part) for part in (split_cell(word) or [word])}
        verified_lookup_keys = {normalize_text(part) for part in unique_parts}
        # Only upgrade when the LLM answer was verified against the refs and also
        # matches the original text directly. Otherwise it is still an LLM
        # judgment, even if the chosen answer exists in the references.
        if all_matched and verified_lookup_keys.issubset(raw_lookup_keys):
            return verified_val, "LLM_Rule_Verified", "HIGH"
        return verified_val, src, conf

    # Work on unique split values first so repeated cells reuse one decision.
    uniques = {part for cell in df[column_name].dropna() for part in split_cell(cell)}
    print(f"\n[INFO] Analyzing {len(uniques)} unique entities in '{column_name}'...")
    source_counts = Counter()
    progress = tqdm(sorted(uniques, key=normalize_text), desc=f"Cleaning {column_name}")
    for word in progress:
        word_clean = match_cache_key(column_name, word)
        # Fast path: reuse a decision already present in master_cache for this column.
        if word_clean in master_cache[column_name]:
            detailed_cache[word_clean] = {"val": master_cache[column_name][word_clean], "src": "Memory_Cache", "conf": "HIGH"}
            source_counts["Memory_Cache"] += 1
            update_match_postfix(progress, source_counts)
            continue

        # Exact/rule matches are trusted and avoid LLM calls.
        if is_dict_mode:
            exact = get_dict_exact_match(word, combined_dict)
        else:
            exact = normalized_targets.get(normalize_text(word_clean))
        if exact:
            detailed_cache[word_clean] = {"val": exact, "src": "Exact_Match", "conf": "HIGH"}
            source_counts["Exact_Match"] += 1
            update_match_postfix(progress, source_counts)
            continue

        if is_dict_mode:
            suggested_match = get_dict_rule_match(word, combined_dict)
        else:
            suggested_match = get_deterministic_match(word, combined_valid_targets)
        if suggested_match:
            detailed_cache[word_clean] = {"val": suggested_match, "src": "Rule_Match", "conf": "HIGH"}
            source_counts["Rule_Match"] += 1
            update_match_postfix(progress, source_counts)
            continue

        # Last resort: send only the top reference candidates to Groq.
        if is_dict_mode:
            cand_keys = get_top_candidates(model, word, combined_valid_targets, reference_embeddings)
            cand_vals = get_top_candidates(model, word, unique_values, value_embeddings)
            candidates = list(dict.fromkeys(cand_keys + cand_vals))[:TOP_K_CANDIDATES]
        else:
            candidates = get_top_candidates(model, word, combined_valid_targets, reference_embeddings)
        ans_val, src, conf = groq_router.ask_judge(word, candidates, column_name)
        source_counts[src] += 1
        update_match_postfix(progress, source_counts)
        # Re-check Groq output against refs so canonical casing/names are preserved.
        if "API_Error" not in conf and ans_val != "UNKNOWN" and ans_val != "LLM_Failed":
            ans_val, src, conf = verify_llm_answer(word, ans_val, src, conf, candidates)
        detailed_cache[word_clean] = {"val": ans_val, "src": src, "conf": conf}

    cleaned_col_name = f'Cleaned_{column_name}'
    if cleaned_col_name not in df.columns:
        # Create the target column up front as object dtype, so the string writes
        # below do not depend on `.at` enlarging the frame.
        df[cleaned_col_name] = pd.Series(index=df.index, dtype=object)

    # Reconstruct full cell values in original row order for workbook injection.
    for idx, row in df.iterrows():
        cell_val = row[column_name]
        if pd.isna(cell_val):
            continue
        raw_parts = split_cell(cell_val)
        cleaned_parts = []
        lowest_conf = "HIGH"
        cell_sources = []
        i = 0
        while i < len(raw_parts):
            curr = raw_parts[i]
            # Recover obvious accidental splits such as "University of, Manchester".
            if i + 1 < len(raw_parts):
                combo_clean = match_cache_key(column_name, f"{curr}, {raw_parts[i + 1]}")
                if combo_clean in detailed_cache:
                    combo_hit = detailed_cache[combo_clean]
                    append_unique_cleaned_part(cleaned_parts, combo_hit["val"])
                    lowest_conf = _lower_confidence(lowest_conf, combo_hit["conf"])
                    cell_sources.append(combo_hit["src"])
                    i += 2
                    continue
            p_clean = match_cache_key(column_name, curr)
            if len(p_clean) <= 1:
                i += 1
                continue
            cache_hit = detailed_cache.get(p_clean)
            if cache_hit:
                append_unique_cleaned_part(cleaned_parts, cache_hit["val"])
                lowest_conf = _lower_confidence(lowest_conf, cache_hit["conf"])
                cell_sources.append(cache_hit["src"])
            else:
                append_unique_cleaned_part(cleaned_parts, smart_format(curr))
                lowest_conf = _lower_confidence(lowest_conf, "LOW")
                cell_sources.append("Fallback_Format")
            i += 1
        final_stitched_val = ", ".join(cleaned_parts)
        df.at[idx, cleaned_col_name] = final_stitched_val
        # Review every changed cell and every low/medium-confidence result.
        if str(cell_val).strip() != final_stitched_val.strip() or lowest_conf != "HIGH":
            blueprint_data.append({
                "Row_Index": idx + 3,
                "Column": column_name,
                "Original_Raw_Text": cell_val,
                "AI_Suggested_Match": final_stitched_val,
                "Human_Override": "",
                # First-seen order keeps this column deterministic (set order is not).
                "Match_Source": " | ".join(dict.fromkeys(cell_sources)),
                "Confidence": lowest_conf,
            })
    return df