| | from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline |
| | import torch |
| | import os |
| |
|
class DepartmentPredictor:
    """Department classifier backed by a Hugging Face text-classification pipeline.

    The tokenizer and model are loaded once in ``__init__`` and reused by
    every ``predict`` call.
    """

    def __init__(self, model_repo="mr-kush/sambodhan-department-classification-model",
                 cache_dir="/app/hf_cache", force_download=True):
        """Load model and tokenizer once at startup.

        Args:
            model_repo: Hugging Face Hub repository id of the classifier.
            cache_dir: Directory used as the Hugging Face cache. When ``None``,
                falls back to the ``HF_HOME`` environment variable and then to
                ``./hf_cache``.
            force_download: Forwarded to ``from_pretrained``. The default
                (``True``) re-downloads the files on every construction; pass
                ``False`` to reuse files already present in ``cache_dir``.
        """
        self.model_repo = model_repo

        # Resolve the cache location *before* touching the filesystem so that
        # ``cache_dir=None`` reaches the fallback instead of failing inside
        # ``os.makedirs(None)``.
        if cache_dir is None:
            cache_dir = os.getenv("HF_HOME", "./hf_cache")
        self.cache_dir = cache_dir
        os.makedirs(self.cache_dir, exist_ok=True)

        # transformers pipeline convention: 0 = first CUDA device, -1 = CPU.
        self.device = 0 if torch.cuda.is_available() else -1

        print(" Loading tokenizer and model...")

        self.tokenizer = AutoTokenizer.from_pretrained(
            self.model_repo,
            cache_dir=self.cache_dir,
            force_download=force_download,
        )
        self.model = AutoModelForSequenceClassification.from_pretrained(
            self.model_repo,
            cache_dir=self.cache_dir,
            force_download=force_download,
        )

        # top_k=None asks the pipeline for the score of every label rather
        # than only the best one; ``predict`` relies on that to build the
        # full ``scores`` mapping.
        self.classifier = pipeline(
            "text-classification",
            model=self.model,
            tokenizer=self.tokenizer,
            device=self.device,
            top_k=None,
        )
        print(" Model and tokenizer loaded successfully.")

    def predict(self, texts):
        """Predict departments with scores for a single text or a batch.

        Args:
            texts: A single string, or a list of strings.

        Returns:
            For each input text a dict with keys ``"label"`` (highest-scoring
            label), ``"confidence"`` (its score rounded to 4 decimals) and
            ``"scores"`` (every label mapped to its rounded score, ordered
            from highest to lowest).

            A bare dict is returned when exactly one text was classified
            (a plain string *or* a one-element batch); otherwise a list of
            dicts in input order. An empty list/tuple yields ``[]``.
        """
        if isinstance(texts, str):
            texts = [texts]

        # Nothing to classify: skip the pipeline call entirely.
        if isinstance(texts, (list, tuple)) and not texts:
            return []

        results = self.classifier(texts)
        formatted_results = []

        for preds in results:
            # Order labels best-first so index 0 is the predicted department.
            ranked = sorted(preds, key=lambda p: p["score"], reverse=True)
            best = ranked[0]
            formatted_results.append({
                "label": best["label"],
                "confidence": round(best["score"], 4),
                "scores": {p["label"]: round(p["score"], 4) for p in ranked},
            })

        # Kept for backward compatibility: a single result is unwrapped.
        if len(formatted_results) == 1:
            return formatted_results[0]
        return formatted_results

    @staticmethod
    def load_model():
        """Helper to preload the model during Docker build.

        Constructs a ``DepartmentPredictor`` with default arguments (which
        downloads the tokenizer and model into the default cache directory)
        and discards the instance.
        """
        DepartmentPredictor()
| |
|