b24122 committed on
Commit
9ae222c
Β·
1 Parent(s): 9844436

Add LegalBERT model loading from zip and direct files for case analysis

Browse files

Implement LegalBERT model loading from zip and directory; update `predictVerdict` in `LegalBertService` and documentation.

Replit-Commit-Author: Agent
Replit-Commit-Session-Id: 63975d62-3d3b-48af-8685-b7e915f31f2b
Replit-Commit-Screenshot-Url: https://storage.googleapis.com/screenshot-production-us-central1/a5a12774-3181-414d-89e4-a4da8e3fb1ca/63975d62-3d3b-48af-8685-b7e915f31f2b/i8A93Md

Files changed (3) hide show
  1. app/api/routes.py +1 -1
  2. app/services/legal_bert.py +109 -22
  3. models/README.md +20 -4
app/api/routes.py CHANGED
@@ -42,7 +42,7 @@ async def analyze_case(request: CaseAnalysisRequest):
42
  logger.info(f"Analyzing case with text length: {len(request.caseText)}")
43
 
44
  # Step 1: Get initial verdict from LegalBERT
45
- initial_verdict = legal_bert_service.predict_verdict(request.caseText)
46
  confidence = legal_bert_service.getConfidence(request.caseText)
47
 
48
  logger.info(f"Initial verdict: {initial_verdict}, confidence: {confidence}")
 
42
  logger.info(f"Analyzing case with text length: {len(request.caseText)}")
43
 
44
  # Step 1: Get initial verdict from LegalBERT
45
+ initial_verdict = legal_bert_service.predictVerdict(request.caseText)
46
  confidence = legal_bert_service.getConfidence(request.caseText)
47
 
48
  logger.info(f"Initial verdict: {initial_verdict}, confidence: {confidence}")
app/services/legal_bert.py CHANGED
@@ -1,6 +1,8 @@
1
  from app.core.config import settings
2
  import logging
3
  import os
 
 
4
 
5
  logger = logging.getLogger(__name__)
6
 
@@ -11,45 +13,130 @@ class LegalBertService:
11
  self.model = None
12
  self._load_model()
13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  def _load_model(self):
15
  try:
16
- if os.path.exists(settings.legal_bert_model_path):
17
- logger.info(f"LegalBERT model path found: {settings.legal_bert_model_path}")
18
- # TODO: Load actual model when torch/transformers are available
19
- logger.info("Model loading placeholder - install torch and transformers to enable")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  else:
21
- logger.warning(f"LegalBERT model path does not exist: {settings.legal_bert_model_path}")
22
- logger.info("Model will be loaded when files are available")
 
23
  except Exception as e:
24
- logger.error(f"Failed to load LegalBERT model: {str(e)}")
25
 
26
- def predict_verdict(self, inputText: str) -> str:
27
  if not self.is_model_loaded():
28
- # Return placeholder prediction for development
29
  logger.info("Using placeholder verdict prediction")
30
- import hashlib
31
- text_hash = int(hashlib.md5(inputText.encode()).hexdigest(), 16)
32
- return "guilty" if text_hash % 2 == 1 else "not guilty"
33
 
34
- # TODO: Implement actual prediction when model is loaded
35
- return "not guilty"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
 
37
  def getConfidence(self, inputText: str) -> float:
38
  if not self.is_model_loaded():
39
- # Return placeholder confidence for development
40
  logger.info("Using placeholder confidence score")
41
- import hashlib
42
- text_hash = int(hashlib.md5(inputText.encode()).hexdigest(), 16)
43
- return 0.5 + (text_hash % 100) / 200.0 # Returns 0.5-0.99
44
 
45
- # TODO: Implement actual confidence when model is loaded
46
- return 0.75
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
 
48
  def is_model_loaded(self) -> bool:
49
- return False # Always False until actual model is loaded
50
 
51
  def get_device(self) -> str:
52
  return str(self.device)
53
 
54
  def is_healthy(self) -> bool:
55
- return True # Always healthy for placeholder implementation
 
1
  from app.core.config import settings
2
  import logging
3
  import os
4
+ import zipfile
5
+ import hashlib
6
 
7
  logger = logging.getLogger(__name__)
8
 
 
13
  self.model = None
14
  self._load_model()
15
 
16
+ def _extract_model_from_zip(self, zipPath: str, extractPath: str):
17
+ """Extract LegalBERT model from zip file"""
18
+ try:
19
+ if not os.path.exists(zipPath):
20
+ logger.warning(f"Model zip file not found: {zipPath}")
21
+ return False
22
+
23
+ if not os.path.exists(extractPath):
24
+ os.makedirs(extractPath)
25
+ logger.info(f"Created model directory: {extractPath}")
26
+
27
+ # Check if model is already extracted
28
+ if os.path.exists(os.path.join(extractPath, "config.json")):
29
+ logger.info("Model already extracted")
30
+ return True
31
+
32
+ logger.info(f"Extracting model from {zipPath} to {extractPath}")
33
+ with zipfile.ZipFile(zipPath, 'r') as zipRef:
34
+ zipRef.extractall(extractPath)
35
+
36
+ logger.info("Model extraction completed")
37
+ return True
38
+
39
+ except Exception as e:
40
+ logger.error(f"Failed to extract model: {str(e)}")
41
+ return False
42
+
43
  def _load_model(self):
44
  try:
45
+ # Check for zip file first
46
+ zipPath = os.path.join("./models", "legalbert_epoch4.zip")
47
+
48
+ if os.path.exists(zipPath):
49
+ if self._extract_model_from_zip(zipPath, settings.legal_bert_model_path):
50
+ logger.info("Model zip file found and extracted")
51
+
52
+ # Try to load the actual model
53
+ if os.path.exists(settings.legal_bert_model_path) and os.path.exists(os.path.join(settings.legal_bert_model_path, "config.json")):
54
+ try:
55
+ import torch
56
+ import torch.nn.functional as F
57
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
58
+
59
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
60
+ logger.info(f"Loading LegalBERT model from {settings.legal_bert_model_path}")
61
+
62
+ self.tokenizer = AutoTokenizer.from_pretrained(settings.legal_bert_model_path)
63
+ self.model = AutoModelForSequenceClassification.from_pretrained(
64
+ settings.legal_bert_model_path
65
+ ).to(self.device)
66
+
67
+ logger.info(f"LegalBERT model loaded successfully on {self.device}")
68
+
69
+ except ImportError:
70
+ logger.warning("torch/transformers not installed - using placeholder mode")
71
+ except Exception as e:
72
+ logger.error(f"Failed to load actual model: {str(e)}")
73
  else:
74
+ logger.warning(f"LegalBERT model files not found in: {settings.legal_bert_model_path}")
75
+ logger.info("Place your legalbert_epoch4.zip in ./models/ or model files directly in ./models/legalbert_model/")
76
+
77
  except Exception as e:
78
+ logger.error(f"Failed to initialize LegalBERT service: {str(e)}")
79
 
80
+ def predictVerdict(self, inputText: str) -> str:
81
  if not self.is_model_loaded():
 
82
  logger.info("Using placeholder verdict prediction")
83
+ textHash = int(hashlib.md5(inputText.encode()).hexdigest(), 16)
84
+ return "guilty" if textHash % 2 == 1 else "not guilty"
 
85
 
86
+ try:
87
+ import torch
88
+ import torch.nn.functional as F
89
+
90
+ inputs = self.tokenizer(
91
+ inputText,
92
+ return_tensors="pt",
93
+ truncation=True,
94
+ padding=True
95
+ ).to(self.device)
96
+
97
+ with torch.no_grad():
98
+ logits = self.model(**inputs).logits
99
+ probabilities = F.softmax(logits, dim=1)
100
+ predictedLabel = torch.argmax(probabilities, dim=1).item()
101
+
102
+ return "guilty" if predictedLabel == 1 else "not guilty"
103
+
104
+ except Exception as e:
105
+ logger.error(f"Error predicting verdict: {str(e)}")
106
+ return "not guilty"
107
 
108
  def getConfidence(self, inputText: str) -> float:
109
  if not self.is_model_loaded():
 
110
  logger.info("Using placeholder confidence score")
111
+ textHash = int(hashlib.md5(inputText.encode()).hexdigest(), 16)
112
+ return 0.5 + (textHash % 100) / 200.0
 
113
 
114
+ try:
115
+ import torch
116
+ import torch.nn.functional as F
117
+
118
+ inputs = self.tokenizer(
119
+ inputText,
120
+ return_tensors="pt",
121
+ truncation=True,
122
+ padding=True
123
+ ).to(self.device)
124
+
125
+ with torch.no_grad():
126
+ logits = self.model(**inputs).logits
127
+ probabilities = F.softmax(logits, dim=1)
128
+
129
+ return float(torch.max(probabilities).item())
130
+
131
+ except Exception as e:
132
+ logger.error(f"Error getting confidence: {str(e)}")
133
+ return 0.5
134
 
135
  def is_model_loaded(self) -> bool:
136
+ return self.model is not None and self.tokenizer is not None
137
 
138
  def get_device(self) -> str:
139
  return str(self.device)
140
 
141
  def is_healthy(self) -> bool:
142
+ return True
models/README.md CHANGED
@@ -2,8 +2,19 @@
2
 
3
  ## LegalBERT Model
4
 
5
- Place your LegalBERT model files in the `legalbert_model/` subdirectory:
6
 
 
 
 
 
 
 
 
 
 
 
 
7
  ```
8
  models/
9
  └── legalbert_model/
@@ -14,8 +25,6 @@ models/
14
  └── vocab.txt
15
  ```
16
 
17
- The model should be compatible with Hugging Face transformers library and fine-tuned for legal text classification.
18
-
19
  ## Installation
20
 
21
  Once you have the model files:
@@ -32,4 +41,11 @@ Once you have the model files:
32
  - Should output binary classification (guilty/not guilty)
33
  - Compatible with AutoModelForSequenceClassification
34
  - Supports text truncation and padding
35
- - Returns logits that can be converted to probabilities
 
 
 
 
 
 
 
 
2
 
3
  ## LegalBERT Model
4
 
5
+ You can add your fine-tuned LegalBERT model in two ways:
6
 
7
+ ### Option 1: Zip File (Recommended)
8
+ Place your model zip file as `legalbert_epoch4.zip` in this directory:
9
+ ```
10
+ models/
11
+ └── legalbert_epoch4.zip
12
+ ```
13
+
14
+ The system will automatically extract it to `legalbert_model/` when the server starts.
15
+
16
+ ### Option 2: Direct Files
17
+ Place your LegalBERT model files directly in the `legalbert_model/` subdirectory:
18
  ```
19
  models/
20
  └── legalbert_model/
 
25
  └── vocab.txt
26
  ```
27
 
 
 
28
  ## Installation
29
 
30
  Once you have the model files:
 
41
  - Should output binary classification (guilty/not guilty)
42
  - Compatible with AutoModelForSequenceClassification
43
  - Supports text truncation and padding
44
+ - Returns logits that can be converted to probabilities
45
+
46
+ ## Auto-Detection
47
+
48
+ The service checks for models in this order:
49
+ 1. `legalbert_epoch4.zip` (extracts automatically)
50
+ 2. `legalbert_model/` directory with model files
51
+ 3. Falls back to placeholder mode if neither found