mr-kush committed on
Commit
b3cf3ad
·
1 Parent(s): 289081c

Add DepartmentPredictor class for department classification model

Browse files
Files changed (1) hide show
  1. predict_dept_model.py +61 -0
predict_dept_model.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
2
+ import torch
3
+ import os
4
+
5
class DepartmentPredictor:
    """Classify text into a department using a Hugging Face pipeline.

    The tokenizer and sequence-classification model are loaded once in
    ``__init__`` and wrapped in a ``text-classification`` pipeline that
    returns a score for every label. ``predict`` reshapes that raw output
    into ``{"label", "confidence", "scores"}`` dictionaries.
    """

    def __init__(self, model_repo="mr-kush/sambodhan-department-classification-model",
                 cache_dir="/app/model_cache"):
        """Load model and tokenizer once at startup.

        Args:
            model_repo: Model identifier handed to ``from_pretrained`` for
                both the tokenizer and the model.
            cache_dir: Directory passed as ``cache_dir`` to ``from_pretrained``;
                it is created if it does not exist yet.
        """
        self.model_repo = model_repo
        self.cache_dir = cache_dir

        # Ensure cache folder exists
        os.makedirs(self.cache_dir, exist_ok=True)

        # Device selection: first CUDA device when available, otherwise -1 (CPU)
        self.device = 0 if torch.cuda.is_available() else -1

        print("🔄 Loading tokenizer and model...")
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.model_repo, cache_dir=self.cache_dir
        )
        self.model = AutoModelForSequenceClassification.from_pretrained(
            self.model_repo, cache_dir=self.cache_dir
        )

        # Create classification pipeline.
        # ``top_k=None`` asks for the scores of all labels; it is the
        # replacement for the deprecated ``return_all_scores=True`` flag.
        self.classifier = pipeline(
            "text-classification",
            model=self.model,
            tokenizer=self.tokenizer,
            device=self.device,
            top_k=None,
        )
        print("✅ Model and tokenizer loaded successfully.")

    @staticmethod
    def _format_prediction(label_scores):
        """Turn one input's list of ``{"label", "score"}`` dicts into a result dict."""
        ranked = sorted(label_scores, key=lambda item: item["score"], reverse=True)
        best = ranked[0]
        return {
            "label": best["label"],
            "confidence": round(best["score"], 4),
            # Insertion order follows the ranking: highest score first.
            "scores": {item["label"]: round(item["score"], 4) for item in ranked},
        }

    def predict(self, texts):
        """Predict departments with scores for a single text or a batch.

        Args:
            texts: One string, or a collection of strings.

        Returns:
            A dict with keys ``"label"`` (top label), ``"confidence"`` (its
            score rounded to 4 decimals) and ``"scores"`` (every label mapped
            to its rounded score, highest first).

            A list of such dicts is returned when more than one text is
            given. Exactly one text -- whether passed as a bare string or as
            a one-element list -- yields a single dict, not a list. An empty
            list or tuple yields an empty list.
        """
        if isinstance(texts, str):
            texts = [texts]
        elif isinstance(texts, (list, tuple)) and not texts:
            # Nothing to classify: skip the model call entirely.
            return []

        raw_results = self.classifier(texts)
        formatted_results = [self._format_prediction(scores) for scores in raw_results]

        # Return single dict if only one input (kept for backward compatibility)
        if len(formatted_results) == 1:
            return formatted_results[0]
        return formatted_results

    @staticmethod
    def load_model():
        """Helper to preload the model during Docker build.

        Instantiates the predictor with its default arguments and discards
        the instance; the constructor call itself is the point.
        """
        DepartmentPredictor()