File size: 3,879 Bytes
8ddf321
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
# File: backend/ml_utils.py
import os
import re
from typing import Optional

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline

# Absolute directory of this module; model files are resolved relative to it.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
# Local checkpoint directory read by load_colab_model().
MODEL_PATH = os.path.join(BASE_DIR, "ml_models", "docusort_bert_model")

# Lazily-populated caches: None until load_colab_model() / load_zero_shot()
# assigns them.
_colab_model = _tokenizer = _id2label = None
_zero_shot_classifier = None

def clean_filename(filename: str):
    """Normalise a filename into lowercase words separated by single spaces.

    The input is converted with ``str()`` and lowercased. Each of the characters
    ``- _ . ( ) [ ] { }`` is replaced by a space. Runs of whitespace then
    collapse to one space, and leading/trailing whitespace is dropped.
    """
    punctuation = "-_.()[]{}"
    to_spaces = str.maketrans(punctuation, " " * len(punctuation))
    words = str(filename).lower().translate(to_spaces).split()
    return " ".join(words)

def load_colab_model():
    """Lazily load the fine-tuned classifier, its tokenizer and label map.

    Both are read from ``MODEL_PATH``. On success the module globals
    ``_tokenizer``, ``_colab_model`` and ``_id2label`` are assigned together.
    Loading is best-effort: any exception is printed and none of the globals
    are modified. ``_colab_model`` therefore stays ``None`` and the next call
    retries. If the model is already loaded, this does nothing.
    """
    global _colab_model, _tokenizer, _id2label
    if _colab_model is not None:
        return
    try:
        print("Loading Colab AI Model...")
        tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
        model = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH)
        model.eval()  # evaluation mode for inference
        id2label = model.config.id2label
    except Exception as e:
        # Deliberately best-effort: callers check `_colab_model is None`.
        print(f"❌ Error loading Colab model: {e}")
        return
    # Commit only after every step succeeded. Otherwise a failure midway could
    # leave e.g. a tokenizer without a model, or a model without id2label.
    _tokenizer, _colab_model, _id2label = tokenizer, model, id2label

def load_zero_shot():
    """Create the zero-shot classification pipeline on first use.

    The pipeline is cached in the module global ``_zero_shot_classifier``. If
    construction raises, the error is printed and the global keeps its
    ``None`` value, so a later call tries again.
    """
    global _zero_shot_classifier
    if _zero_shot_classifier is not None:
        return  # already built
    try:
        print("Loading Zero-Shot Generalization Model...")
        _zero_shot_classifier = pipeline(
            "zero-shot-classification",
            model="cross-encoder/nli-distilroberta-base",
        )
    except Exception as err:
        print(f"❌ Error loading Zero Shot model: {err}")

def _match_keyword_rule(cleaned_name: str, rules: list):
    """Return the first rule with a keyword contained in *cleaned_name*, else None."""
    for rule in rules:
        # Normalise keywords exactly like filenames. cleaned_name contains no
        # "-", "_", "." or brackets, so a raw keyword such as "lab_report"
        # or "CS-101" could otherwise never match.
        keywords = {clean_filename(k) for k in (rule.keywords or "").split(",")}
        keywords.discard("")
        if any(k in cleaned_name for k in keywords):
            return rule
    return None


def _zero_shot_predict(cleaned_name: str, rules: list):
    """Return ``(label, score)`` for the best-scoring custom category, or None.

    None means there are no usable category names or the zero-shot pipeline
    could not be loaded.
    """
    # De-duplicate (keeping first-seen order) and drop empty names, so the
    # confidence is not spread across repeated copies of the same label.
    categories = list(dict.fromkeys(r.category_name for r in rules if r.category_name))
    if not categories:
        return None
    load_zero_shot()
    if _zero_shot_classifier is None:
        return None
    result = _zero_shot_classifier(cleaned_name, candidate_labels=categories)
    return result["labels"][0], result["scores"][0]


def _model_predict(cleaned_name: str, disabled_defaults: str, threshold: float):
    """Classify *cleaned_name* with the fine-tuned sequence-classification model.

    Returns ``("System Error", 0.0)`` if the model is unavailable. Returns
    ``("Unsorted", p)`` when the top label is "Unsorted", is listed in
    *disabled_defaults*, or its softmax probability ``p`` is below *threshold*.
    """
    load_colab_model()
    # Check all three globals: a partially failed load could leave one unset.
    if _colab_model is None or _tokenizer is None or _id2label is None:
        return "System Error", 0.0

    inputs = _tokenizer(cleaned_name, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        outputs = _colab_model(**inputs)

    probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
    top_prob, top_idx = torch.max(probs, dim=-1)
    confidence = top_prob.item()
    predicted_category = _id2label[top_idx.item()]

    disabled = {d.strip() for d in (disabled_defaults or "").split(",")}
    disabled.discard("")
    if (
        predicted_category == "Unsorted"
        or predicted_category in disabled
        or confidence < threshold
    ):
        return "Unsorted", confidence
    return predicted_category, confidence


def predict_category(
    filename: str,
    custom_rules: Optional[list] = None,
    disabled_defaults: str = "",
    *,
    zero_shot_threshold: float = 0.35,
    model_threshold: float = 0.85,
):
    """Pick a category for a file from its name.

    Three stages are tried in order:

    1. Keyword override: the first custom rule with a keyword found in the
       cleaned filename wins, with a fixed confidence of 0.99.
    2. Zero-shot classification over the custom rules' category names. The
       best label wins if its score exceeds *zero_shot_threshold*.
    3. The fine-tuned model loaded from ``MODEL_PATH``. Its top label is kept
       unless it is "Unsorted", disabled, or below *model_threshold*.

    Args:
        filename: Raw filename; normalised with ``clean_filename``.
        custom_rules: Objects exposing ``keywords`` (a comma-separated string)
            and ``category_name``. Presumably user rule records; the concrete
            type is defined elsewhere.
        disabled_defaults: Comma-separated model categories that must not be
            returned by stage 3 (they map to "Unsorted" instead).
        zero_shot_threshold: Minimum stage-2 score required to accept a label.
        model_threshold: Minimum stage-3 probability required to accept a label.

    Returns:
        ``(category_name, confidence)``. ``("System Error", 0.0)`` is returned
        when the stage-3 model cannot be loaded.
    """
    cleaned_name = clean_filename(filename)
    # Materialise once: both stage 1 and stage 2 iterate over the rules.
    rules = list(custom_rules) if custom_rules else []

    # --- 1. KEYWORD OVERRIDE ---
    rule = _match_keyword_rule(cleaned_name, rules)
    if rule is not None:
        return rule.category_name, 0.99

    # --- 2. ZERO-SHOT GENERALIZATION (For Custom Categories) ---
    # Zero-shot confidence is spread across the candidate labels, so a score
    # above ~35% already indicates a strong contextual match.
    zero_shot = _zero_shot_predict(cleaned_name, rules)
    if zero_shot is not None and zero_shot[1] > zero_shot_threshold:
        return zero_shot

    # --- 3. COLAB MODEL FALLBACK (For Default Academic Categories) ---
    return _model_predict(cleaned_name, disabled_defaults, model_threshold)


def extract_course_code(filename: str) -> str:
    """Return the first course code in *filename* formatted as ``"DEPT-NUMBER"``.

    A course code is 2-4 letters, an optional single separator (``-``, ``_``
    or whitespace), then 3-4 digits. For example "cs101" gives "CS-101" and
    "MATH_2410" gives "MATH-2410".

    Academic-term tokens such as ``FA23``, ``SP-2024`` or ``Fall 2023`` are
    removed first, so they are not mistaken for course codes. Returns
    ``"General"`` when no course code is found.

    Tokens are delimited like regex word boundaries, except that ``_`` also
    counts as a delimiter. Filenames routinely use underscores in place of
    spaces, as in ``"cs_101_notes.pdf"``.
    """
    # Token must not touch an adjacent letter/digit; "_" counts as a delimiter.
    before, after = r"(?<![^\W_])", r"(?![^\W_])"
    term = before + r"(SP|FA|SU|WI|SPRING|FALL|SUMMER|WINTER)[-_\s]?(20)?\d{2}" + after
    course = before + r"([A-Za-z]{2,4})[-_\s]?(\d{3,4})" + after

    without_terms = re.sub(term, "", filename, flags=re.IGNORECASE)
    match = re.search(course, without_terms)
    if match:
        return f"{match.group(1).upper()}-{match.group(2)}"
    return "General"

def normalize_course_code(code: Optional[str]) -> Optional[str]:
    """Canonicalise a user-supplied course code.

    A code of 2-4 letters and 3-4 digits, joined by any run of ``-``, ``_`` or
    whitespace, is formatted as ``"DEPT-NUMBER"``. For example "cs101",
    "CS_101" and "cs - 101" all become "CS-101". Any other non-empty input is
    returned stripped and upper-cased.

    Returns ``None`` for ``None``, empty or whitespace-only input.
    """
    code = (code or "").strip()
    if not code:
        return None  # whitespace-only input counts as empty, like ""
    match = re.fullmatch(r"([A-Za-z]{2,4})[-_\s]*(\d{3,4})", code)
    if match:
        return f"{match.group(1).upper()}-{match.group(2)}"
    return code.upper()