docusort-api / backend /ml_utils.py
Mohib
Clean backend API push
8ddf321
Raw
History Blame Contribute Delete
3.88 kB
# File: backend/ml_utils.py
import os
import re
from typing import Optional

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
# Absolute directory of this module; the local classifier lives beneath it.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
MODEL_PATH = os.path.join(BASE_DIR, "ml_models", "docusort_bert_model")

# Lazily-populated singletons; load_colab_model() / load_zero_shot() fill
# them in on first use and they stay None until then (or if loading fails).
_colab_model = _tokenizer = _id2label = None
_zero_shot_classifier = None
def clean_filename(filename: str):
    """Turn a filename into lowercase words separated by single spaces.

    Hyphens, underscores, dots and any of ``()[]{}`` become spaces; runs of
    whitespace are then collapsed to one space and both ends are trimmed.
    Non-string input is converted with ``str()`` first.
    """
    lowered = str(filename).lower()
    spaced = re.sub(r'[-_.()\[\]{}]', ' ', lowered)
    # split() with no argument splits on any whitespace run and drops empty
    # pieces, so joining with ' ' both collapses and trims in one step.
    return ' '.join(spaced.split())
def load_colab_model():
    """Lazily load the local sequence-classification model and its tokenizer.

    On the first successful call, ``_tokenizer``, ``_colab_model`` and
    ``_id2label`` are populated from ``MODEL_PATH``; once ``_colab_model`` is
    set, later calls return immediately. Loading errors are printed, not
    raised: ``_colab_model`` stays ``None`` (which ``predict_category`` checks)
    and the load is retried on the next call.
    """
    global _colab_model, _tokenizer, _id2label
    if _colab_model is not None:
        return
    try:
        print("Loading Colab AI Model...")
        tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
        model = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH)
        model.eval()  # inference mode (e.g. dropout disabled)
        id2label = model.config.id2label
    except Exception as e:
        # Deliberate best-effort: report and leave the globals untouched so
        # callers can detect the failure via ``_colab_model is None``.
        print(f"❌ Error loading Colab model: {e}")
        return
    # Publish all three together so a failure part-way through loading can
    # never leave _colab_model set while _tokenizer/_id2label are stale/None.
    _tokenizer = tokenizer
    _colab_model = model
    _id2label = id2label
def load_zero_shot():
    """Create the shared zero-shot classification pipeline on first use.

    Returns immediately if ``_zero_shot_classifier`` already exists. If
    construction raises, the error is printed, ``_zero_shot_classifier``
    stays ``None``, and the next call tries again.
    """
    global _zero_shot_classifier
    if _zero_shot_classifier is not None:
        return
    try:
        print("Loading Zero-Shot Generalization Model...")
        _zero_shot_classifier = pipeline(
            "zero-shot-classification",
            model="cross-encoder/nli-distilroberta-base",
        )
    except Exception as err:
        print(f"❌ Error loading Zero Shot model: {err}")
def predict_category(
    filename: str,
    custom_rules: Optional[list] = None,
    disabled_defaults: str = "",
    *,
    zero_shot_threshold: float = 0.35,
    min_confidence: float = 0.85,
):
    """Classify *filename* and return ``(category, confidence)``.

    Three stages are tried in order:

    1. Keyword override: the first rule in *custom_rules* whose
       comma-separated ``keywords`` occur in the cleaned filename wins, with
       a fixed confidence of 0.99.
    2. Zero-shot: the cleaned filename is scored against the rules'
       ``category_name`` values; the best label is returned if its score is
       above *zero_shot_threshold*.
    3. Local model: the bundled classifier predicts a default category. The
       result becomes ``"Unsorted"`` if that category is listed in
       *disabled_defaults*, is itself ``"Unsorted"``, or its probability is
       below *min_confidence*.

    Args:
        filename: Raw filename; normalised with ``clean_filename``.
        custom_rules: Objects exposing ``keywords`` (comma-separated str) and
            ``category_name`` attributes (presumably DB rows; confirm).
        disabled_defaults: Comma-separated default categories to suppress.
        zero_shot_threshold: Minimum zero-shot score to accept a custom
            category (default 0.35, the previously hard-coded value).
        min_confidence: Minimum local-model probability to accept a default
            category (default 0.85, the previously hard-coded value).

    Returns:
        ``(category, confidence)``; ``("System Error", 0.0)`` when the local
        model could not be loaded.
    """
    cleaned_name = clean_filename(filename)
    rules = custom_rules or []

    # --- 1. KEYWORD OVERRIDE ---
    for rule in rules:
        # Normalise keywords exactly like the filename. Previously they were
        # only lowercased, so a keyword such as "cs-101" or "lab_report" could
        # never match: the cleaned filename no longer contains '-', '_', '.'
        # or brackets. Empty keywords are skipped ("" would match anything).
        keywords = [clean_filename(k) for k in (rule.keywords or "").split(",")]
        if any(k and k in cleaned_name for k in keywords):
            return rule.category_name, 0.99

    # --- 2. ZERO-SHOT GENERALIZATION (For Custom Categories) ---
    # De-duplicate (order-preserving): repeated labels would split the
    # probability mass between identical candidates.
    custom_categories = list(dict.fromkeys(r.category_name for r in rules))
    if custom_categories:
        load_zero_shot()
        if _zero_shot_classifier is not None:
            result = _zero_shot_classifier(cleaned_name, candidate_labels=custom_categories)
            best_match = result["labels"][0]
            confidence = result["scores"][0]
            # With several labels the pipeline normalises scores across the
            # candidates, so the best score is at least 1/len(labels).
            # NOTE(review): with exactly two categories the best score is
            # always >= 0.5, so the default 0.35 threshold always accepts a
            # custom label and stage 3 is never reached. Confirm intended.
            if confidence > zero_shot_threshold:
                return best_match, confidence

    # --- 3. COLAB MODEL FALLBACK (For Default Academic Categories) ---
    load_colab_model()
    if _colab_model is None:
        return "System Error", 0.0

    inputs = _tokenizer(cleaned_name, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        outputs = _colab_model(**inputs)
    probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
    top_prob, top_idx = torch.max(probs, dim=-1)
    confidence = top_prob.item()
    predicted_category = _id2label[top_idx.item()]

    disabled = {d.strip() for d in disabled_defaults.split(",")} if disabled_defaults else set()
    if (
        predicted_category in disabled
        or predicted_category == "Unsorted"
        or confidence < min_confidence
    ):
        return "Unsorted", confidence
    return predicted_category, confidence
# Boundaries meaning "not touching another letter or digit". Unlike ``\b``,
# which counts "_" as a word character, these accept "_" as a separator, so
# codes in names like "notes_CS101_final.pdf" or "CS101_FA23.pdf" are found.
_SEMESTER_RE = re.compile(
    r'(?<![^\W_])(SP|FA|SU|WI|SPRING|FALL|SUMMER|WINTER)[-_ \s]?(20)?\d{2}(?![^\W_])',
    re.IGNORECASE,
)
_COURSE_CODE_RE = re.compile(r'(?<![^\W_])([A-Za-z]{2,4})[-_ \s]?(\d{3,4})(?![^\W_])')


def extract_course_code(filename: str) -> str:
    """Return the first course code in *filename*, formatted as ``"DEPT-NUM"``.

    A course code is 2-4 ASCII letters followed by 3-4 digits, optionally
    separated by one ``-``, ``_`` or whitespace character (``cs101``,
    ``MATH_2010``, ``bio 110``). Semester tags such as ``FA23`` or
    ``Spring 2024`` are removed first so they are not mistaken for codes.
    Returns ``"General"`` when no code is found.
    """
    clean_name = _SEMESTER_RE.sub('', filename)
    match = _COURSE_CODE_RE.search(clean_name)
    if match:
        return f"{match.group(1).upper()}-{match.group(2)}"
    return "General"
def normalize_course_code(code: str) -> Optional[str]:
    """Canonicalise a user-supplied course code.

    ``"cs 101"``, ``"CS_101"`` and ``"cs-101"`` all become ``"CS-101"``.
    Input that is not 2-4 letters plus 3-4 digits is returned stripped and
    upper-cased. Returns ``None`` for ``None``, ``""`` or whitespace-only
    input.
    """
    # Strip before the emptiness test so "   " is treated like "" (it used to
    # slip past the check and come back as "" instead of None).
    code = (code or "").strip()
    if not code:
        return None
    match = re.match(r'^([A-Za-z]{2,4})[-_ \s]?(\d{3,4})$', code)
    if match:
        return f"{match.group(1).upper()}-{match.group(2)}"
    return code.upper()