Spaces:
Sleeping
Sleeping
b24122
commited on
Commit
Β·
9ae222c
1
Parent(s):
9844436
Add LegalBERT model loading from zip and direct files for case analysis
Browse filesImplement LegalBERT model loading from zip and directory; update `predictVerdict` in `LegalBertService` and documentation.
Replit-Commit-Author: Agent
Replit-Commit-Session-Id: 63975d62-3d3b-48af-8685-b7e915f31f2b
Replit-Commit-Screenshot-Url: https://storage.googleapis.com/screenshot-production-us-central1/a5a12774-3181-414d-89e4-a4da8e3fb1ca/63975d62-3d3b-48af-8685-b7e915f31f2b/i8A93Md
- app/api/routes.py +1 -1
- app/services/legal_bert.py +109 -22
- models/README.md +20 -4
app/api/routes.py
CHANGED
|
@@ -42,7 +42,7 @@ async def analyze_case(request: CaseAnalysisRequest):
|
|
| 42 |
logger.info(f"Analyzing case with text length: {len(request.caseText)}")
|
| 43 |
|
| 44 |
# Step 1: Get initial verdict from LegalBERT
|
| 45 |
-
initial_verdict = legal_bert_service.
|
| 46 |
confidence = legal_bert_service.getConfidence(request.caseText)
|
| 47 |
|
| 48 |
logger.info(f"Initial verdict: {initial_verdict}, confidence: {confidence}")
|
|
|
|
| 42 |
logger.info(f"Analyzing case with text length: {len(request.caseText)}")
|
| 43 |
|
| 44 |
# Step 1: Get initial verdict from LegalBERT
|
| 45 |
+
initial_verdict = legal_bert_service.predictVerdict(request.caseText)
|
| 46 |
confidence = legal_bert_service.getConfidence(request.caseText)
|
| 47 |
|
| 48 |
logger.info(f"Initial verdict: {initial_verdict}, confidence: {confidence}")
|
app/services/legal_bert.py
CHANGED
|
@@ -1,6 +1,8 @@
|
|
| 1 |
from app.core.config import settings
|
| 2 |
import logging
|
| 3 |
import os
|
|
|
|
|
|
|
| 4 |
|
| 5 |
logger = logging.getLogger(__name__)
|
| 6 |
|
|
@@ -11,45 +13,130 @@ class LegalBertService:
|
|
| 11 |
self.model = None
|
| 12 |
self._load_model()
|
| 13 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
def _load_model(self):
|
| 15 |
try:
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
else:
|
| 21 |
-
logger.warning(f"LegalBERT model
|
| 22 |
-
logger.info("
|
|
|
|
| 23 |
except Exception as e:
|
| 24 |
-
logger.error(f"Failed to
|
| 25 |
|
| 26 |
-
def
|
| 27 |
if not self.is_model_loaded():
|
| 28 |
-
# Return placeholder prediction for development
|
| 29 |
logger.info("Using placeholder verdict prediction")
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
return "guilty" if text_hash % 2 == 1 else "not guilty"
|
| 33 |
|
| 34 |
-
|
| 35 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
|
| 37 |
def getConfidence(self, inputText: str) -> float:
|
| 38 |
if not self.is_model_loaded():
|
| 39 |
-
# Return placeholder confidence for development
|
| 40 |
logger.info("Using placeholder confidence score")
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
return 0.5 + (text_hash % 100) / 200.0 # Returns 0.5-0.99
|
| 44 |
|
| 45 |
-
|
| 46 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 47 |
|
| 48 |
def is_model_loaded(self) -> bool:
|
| 49 |
-
return
|
| 50 |
|
| 51 |
def get_device(self) -> str:
|
| 52 |
return str(self.device)
|
| 53 |
|
| 54 |
def is_healthy(self) -> bool:
|
| 55 |
-
return True
|
|
|
|
| 1 |
from app.core.config import settings
|
| 2 |
import logging
|
| 3 |
import os
|
| 4 |
+
import zipfile
|
| 5 |
+
import hashlib
|
| 6 |
|
| 7 |
logger = logging.getLogger(__name__)
|
| 8 |
|
|
|
|
| 13 |
self.model = None
|
| 14 |
self._load_model()
|
| 15 |
|
| 16 |
+
def _extract_model_from_zip(self, zipPath: str, extractPath: str):
|
| 17 |
+
"""Extract LegalBERT model from zip file"""
|
| 18 |
+
try:
|
| 19 |
+
if not os.path.exists(zipPath):
|
| 20 |
+
logger.warning(f"Model zip file not found: {zipPath}")
|
| 21 |
+
return False
|
| 22 |
+
|
| 23 |
+
if not os.path.exists(extractPath):
|
| 24 |
+
os.makedirs(extractPath)
|
| 25 |
+
logger.info(f"Created model directory: {extractPath}")
|
| 26 |
+
|
| 27 |
+
# Check if model is already extracted
|
| 28 |
+
if os.path.exists(os.path.join(extractPath, "config.json")):
|
| 29 |
+
logger.info("Model already extracted")
|
| 30 |
+
return True
|
| 31 |
+
|
| 32 |
+
logger.info(f"Extracting model from {zipPath} to {extractPath}")
|
| 33 |
+
with zipfile.ZipFile(zipPath, 'r') as zipRef:
|
| 34 |
+
zipRef.extractall(extractPath)
|
| 35 |
+
|
| 36 |
+
logger.info("Model extraction completed")
|
| 37 |
+
return True
|
| 38 |
+
|
| 39 |
+
except Exception as e:
|
| 40 |
+
logger.error(f"Failed to extract model: {str(e)}")
|
| 41 |
+
return False
|
| 42 |
+
|
| 43 |
def _load_model(self):
|
| 44 |
try:
|
| 45 |
+
# Check for zip file first
|
| 46 |
+
zipPath = os.path.join("./models", "legalbert_epoch4.zip")
|
| 47 |
+
|
| 48 |
+
if os.path.exists(zipPath):
|
| 49 |
+
if self._extract_model_from_zip(zipPath, settings.legal_bert_model_path):
|
| 50 |
+
logger.info("Model zip file found and extracted")
|
| 51 |
+
|
| 52 |
+
# Try to load the actual model
|
| 53 |
+
if os.path.exists(settings.legal_bert_model_path) and os.path.exists(os.path.join(settings.legal_bert_model_path, "config.json")):
|
| 54 |
+
try:
|
| 55 |
+
import torch
|
| 56 |
+
import torch.nn.functional as F
|
| 57 |
+
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
| 58 |
+
|
| 59 |
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 60 |
+
logger.info(f"Loading LegalBERT model from {settings.legal_bert_model_path}")
|
| 61 |
+
|
| 62 |
+
self.tokenizer = AutoTokenizer.from_pretrained(settings.legal_bert_model_path)
|
| 63 |
+
self.model = AutoModelForSequenceClassification.from_pretrained(
|
| 64 |
+
settings.legal_bert_model_path
|
| 65 |
+
).to(self.device)
|
| 66 |
+
|
| 67 |
+
logger.info(f"LegalBERT model loaded successfully on {self.device}")
|
| 68 |
+
|
| 69 |
+
except ImportError:
|
| 70 |
+
logger.warning("torch/transformers not installed - using placeholder mode")
|
| 71 |
+
except Exception as e:
|
| 72 |
+
logger.error(f"Failed to load actual model: {str(e)}")
|
| 73 |
else:
|
| 74 |
+
logger.warning(f"LegalBERT model files not found in: {settings.legal_bert_model_path}")
|
| 75 |
+
logger.info("Place your legalbert_epoch4.zip in ./models/ or model files directly in ./models/legalbert_model/")
|
| 76 |
+
|
| 77 |
except Exception as e:
|
| 78 |
+
logger.error(f"Failed to initialize LegalBERT service: {str(e)}")
|
| 79 |
|
| 80 |
+
def predictVerdict(self, inputText: str) -> str:
|
| 81 |
if not self.is_model_loaded():
|
|
|
|
| 82 |
logger.info("Using placeholder verdict prediction")
|
| 83 |
+
textHash = int(hashlib.md5(inputText.encode()).hexdigest(), 16)
|
| 84 |
+
return "guilty" if textHash % 2 == 1 else "not guilty"
|
|
|
|
| 85 |
|
| 86 |
+
try:
|
| 87 |
+
import torch
|
| 88 |
+
import torch.nn.functional as F
|
| 89 |
+
|
| 90 |
+
inputs = self.tokenizer(
|
| 91 |
+
inputText,
|
| 92 |
+
return_tensors="pt",
|
| 93 |
+
truncation=True,
|
| 94 |
+
padding=True
|
| 95 |
+
).to(self.device)
|
| 96 |
+
|
| 97 |
+
with torch.no_grad():
|
| 98 |
+
logits = self.model(**inputs).logits
|
| 99 |
+
probabilities = F.softmax(logits, dim=1)
|
| 100 |
+
predictedLabel = torch.argmax(probabilities, dim=1).item()
|
| 101 |
+
|
| 102 |
+
return "guilty" if predictedLabel == 1 else "not guilty"
|
| 103 |
+
|
| 104 |
+
except Exception as e:
|
| 105 |
+
logger.error(f"Error predicting verdict: {str(e)}")
|
| 106 |
+
return "not guilty"
|
| 107 |
|
| 108 |
def getConfidence(self, inputText: str) -> float:
|
| 109 |
if not self.is_model_loaded():
|
|
|
|
| 110 |
logger.info("Using placeholder confidence score")
|
| 111 |
+
textHash = int(hashlib.md5(inputText.encode()).hexdigest(), 16)
|
| 112 |
+
return 0.5 + (textHash % 100) / 200.0
|
|
|
|
| 113 |
|
| 114 |
+
try:
|
| 115 |
+
import torch
|
| 116 |
+
import torch.nn.functional as F
|
| 117 |
+
|
| 118 |
+
inputs = self.tokenizer(
|
| 119 |
+
inputText,
|
| 120 |
+
return_tensors="pt",
|
| 121 |
+
truncation=True,
|
| 122 |
+
padding=True
|
| 123 |
+
).to(self.device)
|
| 124 |
+
|
| 125 |
+
with torch.no_grad():
|
| 126 |
+
logits = self.model(**inputs).logits
|
| 127 |
+
probabilities = F.softmax(logits, dim=1)
|
| 128 |
+
|
| 129 |
+
return float(torch.max(probabilities).item())
|
| 130 |
+
|
| 131 |
+
except Exception as e:
|
| 132 |
+
logger.error(f"Error getting confidence: {str(e)}")
|
| 133 |
+
return 0.5
|
| 134 |
|
| 135 |
def is_model_loaded(self) -> bool:
|
| 136 |
+
return self.model is not None and self.tokenizer is not None
|
| 137 |
|
| 138 |
def get_device(self) -> str:
|
| 139 |
return str(self.device)
|
| 140 |
|
| 141 |
def is_healthy(self) -> bool:
|
| 142 |
+
return True
|
models/README.md
CHANGED
|
@@ -2,8 +2,19 @@
|
|
| 2 |
|
| 3 |
## LegalBERT Model
|
| 4 |
|
| 5 |
-
|
| 6 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
```
|
| 8 |
models/
|
| 9 |
βββ legalbert_model/
|
|
@@ -14,8 +25,6 @@ models/
|
|
| 14 |
βββ vocab.txt
|
| 15 |
```
|
| 16 |
|
| 17 |
-
The model should be compatible with Hugging Face transformers library and fine-tuned for legal text classification.
|
| 18 |
-
|
| 19 |
## Installation
|
| 20 |
|
| 21 |
Once you have the model files:
|
|
@@ -32,4 +41,11 @@ Once you have the model files:
|
|
| 32 |
- Should output binary classification (guilty/not guilty)
|
| 33 |
- Compatible with AutoModelForSequenceClassification
|
| 34 |
- Supports text truncation and padding
|
| 35 |
-
- Returns logits that can be converted to probabilities
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
|
| 3 |
## LegalBERT Model
|
| 4 |
|
| 5 |
+
You can add your fine-tuned LegalBERT model in two ways:
|
| 6 |
|
| 7 |
+
### Option 1: Zip File (Recommended)
|
| 8 |
+
Place your model zip file as `legalbert_epoch4.zip` in this directory:
|
| 9 |
+
```
|
| 10 |
+
models/
|
| 11 |
+
βββ legalbert_epoch4.zip
|
| 12 |
+
```
|
| 13 |
+
|
| 14 |
+
The system will automatically extract it to `legalbert_model/` when the server starts.
|
| 15 |
+
|
| 16 |
+
### Option 2: Direct Files
|
| 17 |
+
Place your LegalBERT model files directly in the `legalbert_model/` subdirectory:
|
| 18 |
```
|
| 19 |
models/
|
| 20 |
βββ legalbert_model/
|
|
|
|
| 25 |
βββ vocab.txt
|
| 26 |
```
|
| 27 |
|
|
|
|
|
|
|
| 28 |
## Installation
|
| 29 |
|
| 30 |
Once you have the model files:
|
|
|
|
| 41 |
- Should output binary classification (guilty/not guilty)
|
| 42 |
- Compatible with AutoModelForSequenceClassification
|
| 43 |
- Supports text truncation and padding
|
| 44 |
+
- Returns logits that can be converted to probabilities
|
| 45 |
+
|
| 46 |
+
## Auto-Detection
|
| 47 |
+
|
| 48 |
+
The service checks for models in this order:
|
| 49 |
+
1. `legalbert_epoch4.zip` (extracts automatically)
|
| 50 |
+
2. `legalbert_model/` directory with model files
|
| 51 |
+
3. Falls back to placeholder mode if neither found
|