# NOTE(review): removed stray paste artifacts ("Spaces:" / "Runtime error"
# banners) that were not part of the module source.
| from app.core.config import settings | |
| import logging | |
| import os | |
| import zipfile | |
| import hashlib | |
logger = logging.getLogger(__name__)


class LegalBertService:
    """Binary verdict classifier backed by a fine-tuned LegalBERT model.

    Loads a sequence-classification model (extracted from a local zip if
    necessary) and predicts "guilty" / "not guilty" for input text. When the
    model files or the torch/transformers stack are unavailable, the service
    degrades to a deterministic hash-based placeholder instead of failing.
    """

    def __init__(self):
        # Default to CPU; _load_model upgrades to CUDA when available.
        self.device = "cpu"
        self.tokenizer = None
        self.model = None
        self._load_model()

    def _extract_model_from_zip(self, zipPath: str, extractPath: str) -> bool:
        """Extract the LegalBERT model archive into ``extractPath``.

        Returns True when the model is extracted (or already was), False on
        any failure. Never raises.
        """
        try:
            if not os.path.exists(zipPath):
                logger.warning("Model zip file not found: %s", zipPath)
                return False
            if not os.path.exists(extractPath):
                # exist_ok guards the race between the check and the mkdir.
                os.makedirs(extractPath, exist_ok=True)
                logger.info("Created model directory: %s", extractPath)
            # config.json is the marker that extraction already happened.
            if os.path.exists(os.path.join(extractPath, "config.json")):
                logger.info("Model already extracted")
                return True
            logger.info("Extracting model from %s to %s", zipPath, extractPath)
            # NOTE(review): extractall trusts the archive's member names. The
            # zip ships with the deployment today, but a hostile archive could
            # escape extractPath (zip-slip) — validate member paths if the
            # archive ever comes from an untrusted source.
            with zipfile.ZipFile(zipPath, 'r') as zip_ref:
                zip_ref.extractall(extractPath)
            logger.info("Model extraction completed")
            return True
        except Exception as e:
            logger.error("Failed to extract model: %s", e)
            return False

    def _load_model(self) -> None:
        """Locate model files, extract them if zipped, and load tokenizer/model.

        Leaves the service in placeholder mode (model/tokenizer None) on any
        failure; never raises.
        """
        try:
            zip_path = os.path.join("./models", "legalbert_epoch4.zip")
            if os.path.exists(zip_path):
                if self._extract_model_from_zip(zip_path, settings.legal_bert_model_path):
                    logger.info("Model zip file found and extracted")
            config_path = os.path.join(settings.legal_bert_model_path, "config.json")
            if os.path.exists(settings.legal_bert_model_path) and os.path.exists(config_path):
                try:
                    # Imported lazily so environments without the ML stack
                    # still get a working (placeholder-mode) service.
                    import torch
                    from transformers import AutoTokenizer, AutoModelForSequenceClassification

                    self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
                    logger.info("Loading LegalBERT model from %s", settings.legal_bert_model_path)
                    self.tokenizer = AutoTokenizer.from_pretrained(settings.legal_bert_model_path)
                    self.model = AutoModelForSequenceClassification.from_pretrained(
                        settings.legal_bert_model_path
                    ).to(self.device)
                    logger.info("LegalBERT model loaded successfully on %s", self.device)
                except ImportError:
                    logger.warning("torch/transformers not installed - using placeholder mode")
                except Exception as e:
                    logger.error("Failed to load actual model: %s", e)
            else:
                logger.warning("LegalBERT model files not found in: %s", settings.legal_bert_model_path)
                logger.info("Place your legalbert_epoch4.zip in ./models/ or model files directly in ./models/legalbert_model/")
        except Exception as e:
            logger.error("Failed to initialize LegalBERT service: %s", e)

    def _class_probabilities(self, inputText: str):
        """Tokenize ``inputText``, run the model, and return softmax
        probabilities (tensor of shape (1, num_labels)).

        Requires a loaded model/tokenizer; shared by predictVerdict and
        getConfidence so the inference pipeline exists in one place.
        """
        import torch
        import torch.nn.functional as F

        inputs = self.tokenizer(
            inputText,
            return_tensors="pt",
            truncation=True,
            padding=True,
        ).to(self.device)
        with torch.no_grad():
            logits = self.model(**inputs).logits
        return F.softmax(logits, dim=1)

    def predictVerdict(self, inputText: str) -> str:
        """Return "guilty" or "not guilty" for ``inputText``.

        Placeholder mode derives a deterministic verdict from the text's MD5
        hash; inference failures fall back to "not guilty".
        """
        if not self.is_model_loaded():
            logger.info("Using placeholder verdict prediction")
            text_hash = int(hashlib.md5(inputText.encode()).hexdigest(), 16)
            return "guilty" if text_hash % 2 == 1 else "not guilty"
        try:
            import torch

            probabilities = self._class_probabilities(inputText)
            predicted_label = torch.argmax(probabilities, dim=1).item()
            # Label index 1 is the "guilty" class in the fine-tuned head.
            return "guilty" if predicted_label == 1 else "not guilty"
        except Exception as e:
            logger.error("Error predicting verdict: %s", e)
            return "not guilty"

    def getConfidence(self, inputText: str) -> float:
        """Return the model's probability for its predicted class.

        Placeholder mode returns a deterministic pseudo-confidence in
        [0.5, 0.995] derived from the text's MD5 hash; inference failures
        fall back to 0.5.
        """
        if not self.is_model_loaded():
            logger.info("Using placeholder confidence score")
            text_hash = int(hashlib.md5(inputText.encode()).hexdigest(), 16)
            return 0.5 + (text_hash % 100) / 200.0
        try:
            import torch

            probabilities = self._class_probabilities(inputText)
            return float(torch.max(probabilities).item())
        except Exception as e:
            logger.error("Error getting confidence: %s", e)
            return 0.5

    def is_model_loaded(self) -> bool:
        """True when both tokenizer and model are loaded (real-model mode)."""
        return self.model is not None and self.tokenizer is not None

    def get_device(self) -> str:
        """Return the inference device as a string (e.g. "cpu" or "cuda:0")."""
        return str(self.device)

    def is_healthy(self) -> bool:
        """Liveness probe; placeholder mode still counts as healthy."""
        return True