# File: backend/ml_utils.py
"""Filename-based document categorisation helpers.

``predict_category`` runs three stages in order: a keyword override driven by
caller-supplied rules, a zero-shot classifier restricted to the custom
category names, and finally the sequence-classification model stored under
``ml_models/docusort_bert_model``. Both models are loaded lazily on first use
and cached in module-level globals.
"""
import os
import re
from typing import Optional, Tuple

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline

BASE_DIR = os.path.dirname(os.path.abspath(__file__))
MODEL_PATH = os.path.join(BASE_DIR, "ml_models", "docusort_bert_model")

# Lazily populated caches; see load_colab_model() / load_zero_shot().
_colab_model = None
_tokenizer = None
_id2label = None
_zero_shot_classifier = None

# Zero-shot confidence is spread across options. 35% is a very strong
# contextual match.
_ZERO_SHOT_MIN_CONFIDENCE = 0.35
# Predictions from the fallback model below this score are reported as "Unsorted".
_COLAB_MIN_CONFIDENCE = 0.85

_SEPARATORS_RE = re.compile(r'[-_.()\[\]{}]')
_WHITESPACE_RE = re.compile(r'\s+')

# NOTE: these patterns use explicit alphanumeric lookarounds instead of ``\b``.
# ``_`` is a word character, so ``\b`` does not fire between ``_`` and a
# letter/digit, which made names such as "CS101_notes.pdf" unmatchable.
_SEMESTER_RE = re.compile(
    r'(?<![A-Za-z0-9])(?:SP|FA|SU|WI|SPRING|FALL|SUMMER|WINTER)'
    r'[-_\s]?(?:20)?\d{2}(?![A-Za-z0-9])',
    re.IGNORECASE,
)
_COURSE_IN_TEXT_RE = re.compile(
    r'(?<![A-Za-z0-9])([A-Za-z]{2,4})[-_\s]?(\d{3,4})(?![A-Za-z0-9])'
)
_COURSE_EXACT_RE = re.compile(r'([A-Za-z]{2,4})[-_\s]?(\d{3,4})')


def clean_filename(filename: str) -> str:
    """Return *filename* lower-cased with separators turned into single spaces.

    The characters ``- _ . ( ) [ ] { }`` are replaced by spaces, runs of
    whitespace are collapsed, and the result is stripped.
    """
    text = str(filename).lower()
    text = _SEPARATORS_RE.sub(' ', text)
    return _WHITESPACE_RE.sub(' ', text).strip()


def load_colab_model() -> None:
    """Load the tokenizer/model from ``MODEL_PATH`` into the module cache.

    Does nothing if the model is already cached. On failure the error is
    printed and the cache is left untouched, so a later call tries again.
    """
    global _colab_model, _tokenizer, _id2label
    if _colab_model is not None:
        return
    try:
        print("Loading Colab AI Model...")
        tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
        model = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH)
        model.eval()
        id2label = model.config.id2label
    except Exception as e:
        print(f"❌ Error loading Colab model: {e}")
        return
    # Publish all three together so a half-finished load can never leave
    # `_colab_model` set while `_tokenizer` / `_id2label` are still None.
    _tokenizer = tokenizer
    _id2label = id2label
    _colab_model = model


def load_zero_shot() -> None:
    """Create and cache the zero-shot classification pipeline.

    Does nothing if the pipeline is already cached. On failure the error is
    printed and the cache stays ``None``.
    """
    global _zero_shot_classifier
    if _zero_shot_classifier is not None:
        return
    try:
        print("Loading Zero-Shot Generalization Model...")
        _zero_shot_classifier = pipeline(
            "zero-shot-classification",
            model="cross-encoder/nli-distilroberta-base",
        )
    except Exception as e:
        print(f"❌ Error loading Zero Shot model: {e}")


def _rule_keywords(rule) -> list:
    """Return the rule's comma-separated keywords, normalised like a filename."""
    raw = rule.keywords or ""
    # Keywords go through clean_filename() so that e.g. "lab_report" can match
    # the cleaned name "lab report"; keywords that normalise to "" are dropped
    # because the empty string is a substring of every name.
    normalised = (clean_filename(part) for part in raw.split(','))
    return [keyword for keyword in normalised if keyword]


def predict_category(
    filename: str,
    custom_rules: Optional[list] = None,
    disabled_defaults: str = "",
) -> Tuple[str, float]:
    """Predict a category for *filename* and return ``(category, confidence)``.

    Args:
        filename: Name of the file to classify.
        custom_rules: Optional rule objects exposing ``keywords`` (a
            comma-separated string) and ``category_name``.
        disabled_defaults: Comma-separated category names of the fallback
            model that must be reported as ``"Unsorted"``.

    Returns:
        ``(rule.category_name, 0.99)`` when a rule keyword occurs in the
        cleaned filename; otherwise the best zero-shot label if its score
        exceeds 0.35; otherwise the fallback model's label, replaced by
        ``"Unsorted"`` when it is disabled, already ``"Unsorted"``, or scored
        below 0.85. ``("System Error", 0.0)`` is returned when the fallback
        model could not be loaded.
    """
    cleaned_name = clean_filename(filename)

    # --- 1. KEYWORD OVERRIDE ---
    if custom_rules:
        for rule in custom_rules:
            if any(keyword in cleaned_name for keyword in _rule_keywords(rule)):
                return rule.category_name, 0.99

    # --- 2. ZERO-SHOT GENERALIZATION (For Custom Categories) ---
    # De-duplicate (keeping first-seen order) so that several rules sharing a
    # category do not split the score between identical candidate labels.
    custom_categories = (
        list(dict.fromkeys(r.category_name for r in custom_rules if r.category_name))
        if custom_rules
        else []
    )
    if custom_categories:
        load_zero_shot()
        if _zero_shot_classifier is not None:
            result = _zero_shot_classifier(cleaned_name, candidate_labels=custom_categories)
            best_match = result['labels'][0]
            confidence = result['scores'][0]
            if confidence > _ZERO_SHOT_MIN_CONFIDENCE:
                return best_match, confidence

    # --- 3. COLAB MODEL FALLBACK (For Default Academic Categories) ---
    load_colab_model()
    if _colab_model is None:
        return "System Error", 0.0

    inputs = _tokenizer(cleaned_name, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        outputs = _colab_model(**inputs)
        probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
        top_prob, top_idx = torch.max(probs, dim=-1)

    confidence = top_prob.item()
    predicted_category = _id2label[top_idx.item()]

    disabled = {d.strip() for d in disabled_defaults.split(',')} if disabled_defaults else set()
    if (
        predicted_category in disabled
        or predicted_category == "Unsorted"
        or confidence < _COLAB_MIN_CONFIDENCE
    ):
        return "Unsorted", confidence
    return predicted_category, confidence


def extract_course_code(filename: str) -> str:
    """Return the first course code found in *filename* as ``"ABC-123"``.

    Semester tokens such as ``FA23`` or ``Spring 2024`` are removed first so
    they are not mistaken for a course code. A course code is 2-4 letters,
    an optional ``-``/``_``/whitespace separator, and 3-4 digits, delimited by
    non-alphanumeric characters (underscores count as delimiters). Returns
    ``"General"`` when nothing matches.
    """
    without_semester = _SEMESTER_RE.sub('', str(filename))
    match = _COURSE_IN_TEXT_RE.search(without_semester)
    if match:
        return f"{match.group(1).upper()}-{match.group(2)}"
    return "General"


def normalize_course_code(code: str) -> Optional[str]:
    """Normalise a user-supplied course code to ``"ABC-123"`` form.

    Returns ``None`` for empty or whitespace-only input. Input that is not
    exactly "2-4 letters, optional separator, 3-4 digits" is returned
    stripped and upper-cased.
    """
    if not code:
        return None
    code = code.strip()
    if not code:
        # Whitespace-only input is treated the same as an empty string.
        return None
    match = _COURSE_EXACT_RE.fullmatch(code)
    if match:
        return f"{match.group(1).upper()}-{match.group(2)}"
    return code.upper()