harsh580g committed on
Commit
876467d
·
verified ·
1 Parent(s): ed0b11f

Upload 2 files

Browse files
legal_inference/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .inference import LegalSectionRetriever
2
+
legal_inference/inference.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import BertTokenizer, BertForSequenceClassification
2
+ import torch, json, os
3
+
4
class LegalSectionRetriever:
    """Score legal sections against a free-text query with a BERT cross-encoder.

    Loads a fine-tuned ``BertForSequenceClassification`` model (single-logit
    head, scored via sigmoid) plus its tokenizer from the Hugging Face Hub,
    and a local ``sections.json`` mapping section ids to section text.
    """

    def __init__(self, repo_id: str = "harsh580g/bert-query-section"):
        """Load model, tokenizer, decision threshold, and the sections catalog.

        Args:
            repo_id: Hugging Face Hub repository to pull the model/tokenizer from.

        Raises:
            FileNotFoundError: if ``sections.json`` is not next to this module.
        """
        # Load model + tokenizer from the Hugging Face Hub.
        self.model = BertForSequenceClassification.from_pretrained(repo_id)
        self.tokenizer = BertTokenizer.from_pretrained(repo_id)
        self.model.eval()

        # Decision threshold: taken from the model config when present,
        # otherwise fall back to 0.5.
        self.threshold = getattr(self.model.config, "threshold", 0.5)

        # sections.json maps section id -> section text and must live in the
        # same directory as this module.
        sections_path = os.path.join(os.path.dirname(__file__), "sections.json")
        if not os.path.exists(sections_path):
            raise FileNotFoundError("sections.json missing. Please download from HF repo.")
        # Explicit UTF-8: legal texts routinely contain non-ASCII characters,
        # and the platform default encoding is not UTF-8 everywhere.
        with open(sections_path, "r", encoding="utf-8") as f:
            self.sections = json.load(f)

    def get_relevant_sections(self, query: str):
        """Return ``(section_id, probability)`` pairs scoring >= threshold.

        Each stored section is scored independently as a (query, section)
        pair; results are sorted by probability, highest first.

        Args:
            query: The user's free-text legal question.

        Returns:
            List of ``(section_id, probability)`` tuples, best match first.
        """
        results = []
        for sec_id, sec_text in self.sections.items():
            inputs = self.tokenizer(
                query,
                sec_text,
                return_tensors="pt",
                padding=True,
                truncation=True,
            )
            # Inference only: no gradients needed.
            with torch.no_grad():
                logits = self.model(**inputs).logits.squeeze(-1)
            # Single-logit head -> sigmoid gives a relevance probability.
            prob = torch.sigmoid(logits).item()
            if prob >= self.threshold:
                results.append((sec_id, prob))

        results.sort(key=lambda x: x[1], reverse=True)
        return results
33
+