Spaces:
Sleeping
Sleeping
| # File: backend/ml_utils.py | |
| import os | |
| import re | |
| import torch | |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline | |
# Paths are anchored to the directory that contains this module so they do
# not depend on the process's current working directory.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
MODEL_PATH = os.path.join(BASE_DIR, "ml_models", "docusort_bert_model")

# Lazily populated caches. They stay None until the matching loader
# (load_colab_model / load_zero_shot) has run successfully.
_colab_model = _tokenizer = _id2label = None
_zero_shot_classifier = None
def clean_filename(filename: str):
    """Return a lowercase, whitespace-normalised version of *filename*.

    The value is coerced with ``str()``, lowercased, every one of the
    characters ``- _ . ( ) [ ] { }`` is replaced by a space, runs of
    whitespace are collapsed to a single space, and leading/trailing
    whitespace is removed.
    """
    lowered = str(filename).lower()
    # Turn common filename separators/brackets into spaces.
    spaced = re.sub(r'[-_.()\[\]{}]', ' ', lowered)
    # Collapse the whitespace runs produced above and trim the ends.
    return re.sub(r'\s+', ' ', spaced).strip()
def load_colab_model():
    """Load the local sequence-classification model into the module caches.

    Does nothing when a model is already cached. On success
    ``_tokenizer``, ``_id2label`` and ``_colab_model`` are all populated;
    on any failure the error is printed and all three are left unchanged,
    so callers can test ``_colab_model is None`` to detect the failure and
    a later call will retry the load.
    """
    global _colab_model, _tokenizer, _id2label
    if _colab_model is not None:
        return

    try:
        print("Loading Colab AI Model...")
        # Load into locals first: if any step below raises, the module
        # state must not be left half-initialised (e.g. a model without
        # its label map), because the `_colab_model is None` guard would
        # then never retry and prediction would fail later.
        tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
        model = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH)
        model.eval()
        id2label = model.config.id2label
    except Exception as e:
        # Best-effort loader: report the problem and leave the caches empty.
        print(f"❌ Error loading Colab model: {e}")
        return

    # Publish everything only after every step succeeded; the model is
    # assigned last because it is the value callers check.
    _tokenizer, _id2label, _colab_model = tokenizer, id2label, model
def load_zero_shot():
    """Create the zero-shot classification pipeline once and cache it.

    Does nothing when ``_zero_shot_classifier`` is already set. If the
    pipeline cannot be created the error is printed and the cache stays
    None.
    """
    global _zero_shot_classifier
    if _zero_shot_classifier is not None:
        return

    try:
        print("Loading Zero-Shot Generalization Model...")
        _zero_shot_classifier = pipeline(
            "zero-shot-classification",
            model="cross-encoder/nli-distilroberta-base",
        )
    except Exception as err:
        print(f"❌ Error loading Zero Shot model: {err}")
def _find_keyword_rule(cleaned_name, custom_rules):
    """Return the first rule with a keyword contained in *cleaned_name*, else None."""
    for rule in custom_rules:
        # `rule.keywords` is read as a comma-separated string; tolerate None.
        raw_keywords = (rule.keywords or "").split(',')
        # Normalise keywords exactly like the filename. Without this a
        # keyword containing a separator ("lab_report", "v.2") could never
        # match, because those characters are replaced by spaces in
        # `cleaned_name`.
        keywords = [clean_filename(k) for k in raw_keywords]
        # Skip empty keywords: "" is a substring of every string and would
        # make the rule match everything.
        if any(k and k in cleaned_name for k in keywords):
            return rule
    return None


def _classify_zero_shot(cleaned_name, custom_rules):
    """Return ``(category, confidence)`` from the zero-shot model, or None.

    None means the zero-shot stage produced no usable answer (nothing to
    classify, classifier unavailable, or confidence too low).
    """
    # De-duplicate while keeping order: several rules may share a category,
    # and repeated candidate labels would split the score between them.
    labels = list(dict.fromkeys(rule.category_name for rule in custom_rules))
    # The pipeline rejects an empty sequence, so a filename that cleans to
    # "" (e.g. "___") must bypass this stage instead of raising.
    if not labels or not cleaned_name:
        return None

    load_zero_shot()
    if _zero_shot_classifier is None:
        return None

    result = _zero_shot_classifier(cleaned_name, candidate_labels=labels)
    best_match = result['labels'][0]
    confidence = result['scores'][0]
    # Zero-shot confidence is spread across options. 35% is a very strong contextual match.
    if confidence > 0.35:
        return best_match, confidence
    return None


def _classify_with_colab_model(cleaned_name, disabled_defaults):
    """Return ``(category, confidence)`` from the local classification model."""
    load_colab_model()
    # Also check the tokenizer and label map so that an incompletely
    # loaded model is reported instead of raising further down.
    if _colab_model is None or _tokenizer is None or _id2label is None:
        return "System Error", 0.0

    inputs = _tokenizer(cleaned_name, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        outputs = _colab_model(**inputs)
        probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
        top_prob, top_idx = torch.max(probs, dim=-1)

    confidence = top_prob.item()
    predicted_category = _id2label[top_idx.item()]

    # `disabled_defaults` is a comma-separated list of category names.
    disabled = {d.strip() for d in disabled_defaults.split(',')} if disabled_defaults else set()
    if (
        predicted_category in disabled
        or predicted_category == "Unsorted"
        or confidence < 0.85
    ):
        return "Unsorted", confidence
    return predicted_category, confidence


def predict_category(filename: str, custom_rules: list = None, disabled_defaults: str = ""):
    """Predict a category for *filename* and return ``(category, confidence)``.

    Stages, tried in order:

    1. Keyword override: the first custom rule with a keyword contained in
       the cleaned filename wins, with a fixed confidence of 0.99.
    2. Zero-shot classification over the custom rules' category names,
       accepted when the top score is above 0.35.
    3. The local classification model. Its prediction is replaced by
       ``"Unsorted"`` when the category is listed in *disabled_defaults*,
       is itself ``"Unsorted"``, or scores below 0.85. If the model is
       unavailable the result is ``("System Error", 0.0)``.

    Args:
        filename: Name of the file to classify.
        custom_rules: Optional rule objects; each is read via its
            ``keywords`` (comma-separated string) and ``category_name``
            attributes.
        disabled_defaults: Comma-separated category names that must not be
            returned by stage 3.
    """
    cleaned_name = clean_filename(filename)

    if custom_rules:
        # --- 1. KEYWORD OVERRIDE ---
        matched_rule = _find_keyword_rule(cleaned_name, custom_rules)
        if matched_rule is not None:
            return matched_rule.category_name, 0.99

        # --- 2. ZERO-SHOT GENERALIZATION (For Custom Categories) ---
        zero_shot_result = _classify_zero_shot(cleaned_name, custom_rules)
        if zero_shot_result is not None:
            return zero_shot_result

    # --- 3. COLAB MODEL FALLBACK (For Default Academic Categories) ---
    return _classify_with_colab_model(cleaned_name, disabled_defaults)
def extract_course_code(filename: str) -> str:
    """Extract a course code such as ``"CS-101"`` from *filename*.

    Term markers (e.g. ``SP24``, ``Fall 2023``) are removed first so they
    are not mistaken for a course code. The first remaining token made of
    2-4 letters, an optional single separator (``-``, ``_`` or
    whitespace) and 3-4 digits is returned as ``"<LETTERS>-<DIGITS>"``
    with the letters upper-cased. Returns ``"General"`` when nothing
    matches or *filename* is empty.
    """
    if not filename:
        return "General"

    # `(?<![^\W_])` and `(?![^\W_])` mean "not adjacent to a letter or
    # digit". They replace `\b`, which counts "_" as a word character and
    # therefore never matched underscore-separated names such as
    # "lecture_CS101_notes.pdf".
    without_term = re.sub(
        r'(?<![^\W_])(SP|FA|SU|WI|SPRING|FALL|SUMMER|WINTER)[-_ \s]?(20)?\d{2}(?![^\W_])',
        '',
        filename,
        flags=re.IGNORECASE,
    )
    match = re.search(
        r'(?<![^\W_])([A-Za-z]{2,4})[-_ \s]?(\d{3,4})(?![^\W_])',
        without_term,
    )
    if match:
        return f"{match.group(1).upper()}-{match.group(2)}"
    return "General"
def normalize_course_code(code: str) -> str:
    """Normalise a user-supplied course code.

    Surrounding whitespace is stripped. A value consisting of 2-4 letters,
    an optional single separator (``-``, ``_`` or whitespace) and 3-4
    digits becomes ``"<LETTERS>-<DIGITS>"`` with upper-cased letters; any
    other non-empty value is returned upper-cased.

    Returns None (despite the ``str`` annotation) when *code* is None,
    empty, or contains only whitespace.
    """
    stripped = code.strip() if code else ""
    # Treat whitespace-only input like empty input instead of returning "".
    if not stripped:
        return None

    match = re.match(r'^([A-Za-z]{2,4})[-_ \s]?(\d{3,4})$', stripped)
    if match:
        return f"{match.group(1).upper()}-{match.group(2)}"
    return stripped.upper()