Syamchand committed on
Commit
cb0602a
·
verified ·
1 Parent(s): fa0788d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +117 -73
app.py CHANGED
@@ -1,23 +1,97 @@
1
  import torch
2
- from contextlib import asynccontextmanager
3
- from fastapi import FastAPI, HTTPException
4
- from pydantic import BaseModel
5
- import torch
6
  from contextlib import asynccontextmanager
7
  from fastapi import FastAPI
8
  from pydantic import BaseModel
9
  from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModel
10
  from sentence_transformers import SentenceTransformer
 
11
  from setfit import SetFitModel
12
- import numpy as np
13
- from typing import List
14
- from sentence_transformers import SentenceTransformer
15
- from setfit import SetFitModel
16
- import numpy as np
17
  from typing import List
18
 
19
  models = {}
20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  @asynccontextmanager
22
  async def lifespan(app: FastAPI):
23
  print("Loading models...")
@@ -29,7 +103,7 @@ async def lifespan(app: FastAPI):
29
  )
30
  print("✓ contracts_clauses loaded")
31
 
32
- # 2. Contract NLI (BERT-base — AutoTokenizer works fine)
33
  print("Loading contract NLI model...")
34
  models["nli_tokenizer"] = AutoTokenizer.from_pretrained("Syamchand/contract-nli-bert")
35
  models["nli_model"] = AutoModelForSequenceClassification.from_pretrained(
@@ -41,16 +115,7 @@ async def lifespan(app: FastAPI):
41
 
42
  # 3. Clause risk classifier
43
  print("Loading clause risk classifier...")
44
- try:
45
- # After editing tokenizer_config.json this will work directly
46
- models["risk_tokenizer"] = AutoTokenizer.from_pretrained(
47
- "Syamchand/clause_risk_classifier"
48
- )
49
- except ValueError:
50
- # Fallback: direct class if config still has old tokenizer_class
51
- models["risk_tokenizer"] = ModernBertTokenizerFast.from_pretrained(
52
- "Syamchand/clause_risk_classifier"
53
- )
54
  models["risk_model"] = AutoModelForSequenceClassification.from_pretrained(
55
  "Syamchand/clause_risk_classifier"
56
  )
@@ -61,29 +126,24 @@ async def lifespan(app: FastAPI):
61
  # 4. Legal BERT embeddings
62
  print("Loading legal BERT embeddings model...")
63
  models["emb_tokenizer"] = AutoTokenizer.from_pretrained("nlpaueb/bert-base-uncased-contracts")
64
- models["emb_model"] = AutoModel.from_pretrained(
65
- "nlpaueb/bert-base-uncased-contracts"
66
- )
67
  models["emb_model"].eval()
68
  print("✓ legal BERT loaded")
69
 
70
- # 5. Semantic chunker
71
  print("Loading semantic chunker model...")
72
- models["chunker"] = SentenceTransformer(
73
- "Raubachm/sentence-transformers-semantic-chunker",
74
- device="cpu"
75
- )
76
  print("✓ semantic chunker loaded")
77
 
78
  print("All models ready!")
79
  yield
80
-
81
  models.clear()
82
- torch.cuda.empty_cache()
83
 
84
 
85
  app = FastAPI(lifespan=lifespan)
86
 
 
87
  # ---------- Schemas ----------
88
  class TextRequest(BaseModel):
89
  text: str
@@ -97,8 +157,9 @@ class EmbeddingRequest(BaseModel):
97
 
98
  class ChunkRequest(BaseModel):
99
  text: str
100
- threshold: float = 0.7
101
- max_chunk_tokens: int = 256
 
102
 
103
  class ClassificationResult(BaseModel):
104
  label: str
@@ -111,6 +172,7 @@ class ChunkResult(BaseModel):
111
  chunks: List[str]
112
 
113
 
 
114
  @app.get("/health")
115
  def health():
116
  return {"status": "ok"}
@@ -120,20 +182,18 @@ def health():
120
  def predict_contracts_clauses(req: TextRequest):
121
  preds = models["contracts_clauses"]([req.text])
122
  label_id = int(preds[0])
123
- if hasattr(models["contracts_clauses"], "labels"):
124
- label = models["contracts_clauses"].labels[label_id]
125
- else:
126
- label = f"class_{label_id}"
127
  return ClassificationResult(label=label, score=1.0)
128
 
129
 
130
  @app.post("/predict/nli", response_model=ClassificationResult)
131
  def predict_nli(req: PairRequest):
132
- tok = models["nli_tokenizer"]
133
- model = models["nli_model"]
134
- inputs = tok(req.premise, req.hypothesis, return_tensors="pt", truncation=True)
135
  with torch.no_grad():
136
- logits = model(**inputs).logits
137
  probs = torch.nn.functional.softmax(logits, dim=-1)
138
  class_id = torch.argmax(probs, dim=-1).item()
139
  return ClassificationResult(
@@ -144,11 +204,11 @@ def predict_nli(req: PairRequest):
144
 
145
  @app.post("/predict/risk", response_model=ClassificationResult)
146
  def predict_risk(req: TextRequest):
147
- tok = models["risk_tokenizer"]
148
- model = models["risk_model"]
149
- inputs = tok(req.text, return_tensors="pt", truncation=True, max_length=512)
150
  with torch.no_grad():
151
- logits = model(**inputs).logits
152
  probs = torch.nn.functional.softmax(logits, dim=-1)
153
  class_id = torch.argmax(probs, dim=-1).item()
154
  return ClassificationResult(
@@ -159,43 +219,27 @@ def predict_risk(req: TextRequest):
159
 
160
  def mean_pooling(model_output, attention_mask):
161
  token_embeddings = model_output.last_hidden_state
162
- input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
163
- return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
 
164
 
165
  @app.post("/predict/embeddings", response_model=EmbeddingResult)
166
  def get_embeddings(req: EmbeddingRequest):
167
- tok = models["emb_tokenizer"]
168
- model = models["emb_model"]
169
- encoded = tok(req.texts, padding=True, truncation=True, return_tensors="pt")
170
  with torch.no_grad():
171
- outputs = model(**encoded)
172
  embeddings = mean_pooling(outputs, encoded["attention_mask"])
173
  return EmbeddingResult(embeddings=embeddings.tolist())
174
 
175
 
176
  @app.post("/predict/semantic_chunks", response_model=ChunkResult)
177
  def semantic_chunking(req: ChunkRequest):
178
- model = models["chunker"]
179
- sentences = [s.strip() for s in req.text.replace('\n', ' ').split('.') if s.strip()]
180
- if not sentences:
181
- return ChunkResult(chunks=[req.text])
182
-
183
- sentence_embeddings = model.encode(sentences, convert_to_tensor=True)
184
-
185
- chunks = []
186
- current_chunk = [sentences[0]]
187
- current_emb = sentence_embeddings[0]
188
-
189
- for i in range(1, len(sentences)):
190
- sim = torch.nn.functional.cosine_similarity(current_emb, sentence_embeddings[i], dim=0).item()
191
- if sim >= req.threshold:
192
- current_chunk.append(sentences[i])
193
- chunk_embs = torch.stack([sentence_embeddings[j] for j in range(i - len(current_chunk) + 1, i + 1)])
194
- current_emb = torch.mean(chunk_embs, dim=0)
195
- else:
196
- chunks.append('. '.join(current_chunk) + '.')
197
- current_chunk = [sentences[i]]
198
- current_emb = sentence_embeddings[i]
199
- if current_chunk:
200
- chunks.append('. '.join(current_chunk) + '.')
201
  return ChunkResult(chunks=chunks)
 
1
  import torch
2
+ import numpy as np
 
 
 
3
  from contextlib import asynccontextmanager
4
  from fastapi import FastAPI
5
  from pydantic import BaseModel
6
  from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModel
7
  from sentence_transformers import SentenceTransformer
8
+ from sklearn.metrics.pairwise import cosine_similarity
9
  from setfit import SetFitModel
 
 
 
 
 
10
  from typing import List
11
 
12
  models = {}
13
 
14
+
15
+ # ---------- TextChunker (from Raubachm/sentence-transformers-semantic-chunker) ----------
class TextChunker:
    """Split text into runs of semantically related sentences.

    Adapted from Raubachm/sentence-transformers-semantic-chunker. Sentences
    are embedded together with their neighbours, the cosine distance between
    consecutive embeddings is measured, and a chunk boundary is placed after
    every sentence whose distance to the next one exceeds a percentile of all
    distances. Small chunks are then folded into a neighbouring chunk.
    """

    # Becomes True once the NLTK sentence-tokenizer data has been fetched in
    # this process, so later ``chunk`` calls skip the download step.
    _punkt_ready = False

    def __init__(self, st_model: "SentenceTransformer"):
        """Keep a reference to the model used for all ``encode`` calls."""
        self.model = st_model

    @classmethod
    def _ensure_punkt(cls) -> None:
        """Request the NLTK punkt data at most once per process."""
        if cls._punkt_ready:
            return
        import nltk

        # Both packages are always requested (no short-circuit).
        got_punkt = nltk.download("punkt", quiet=True)
        got_punkt_tab = nltk.download("punkt_tab", quiet=True)
        # If neither download succeeded, leave the flag unset so the next
        # call tries again instead of caching the failure.
        cls._punkt_ready = bool(got_punkt or got_punkt_tab)

    def chunk(self, text: str, context_window: int = 1,
              percentile_threshold: float = 95, min_chunk_size: int = 3) -> List[str]:
        """Return ``text`` split into semantic chunks.

        Args:
            text: Raw text to split.
            context_window: Number of neighbouring sentences on each side
                that are joined to a sentence before it is embedded.
                Negative values are treated as 0.
            percentile_threshold: Percentile of the consecutive-sentence
                distances; only distances strictly above it start a new
                chunk. Passed straight to ``np.percentile``.
            min_chunk_size: Interior chunks that split into fewer than this
                many pieces on ". " are merged into a neighbour.

        Returns:
            The list of chunk strings, or ``[text]`` when the text yields
            fewer than two sentences.
        """
        self._ensure_punkt()
        from nltk.tokenize import sent_tokenize

        sentences = sent_tokenize(text)
        if not sentences:
            return [text]

        contextualized = self._add_context(sentences, context_window)
        embeddings = self.model.encode(contextualized)

        distances = self._calculate_distances(embeddings)
        if not distances:
            # A single sentence has no neighbour to compare against.
            return [text]

        breakpoints = self._identify_breakpoints(distances, percentile_threshold)
        initial_chunks = self._create_chunks(sentences, breakpoints)

        chunk_embeddings = self.model.encode(initial_chunks)
        return self._merge_small_chunks(initial_chunks, chunk_embeddings, min_chunk_size)

    def _add_context(self, sentences, window_size):
        """Join each sentence with up to ``window_size`` neighbours per side."""
        # A negative window would produce empty slices (and empty strings to
        # embed); treat it as "no context".
        window_size = max(0, window_size)
        total = len(sentences)
        return [
            " ".join(sentences[max(0, i - window_size):min(total, i + window_size + 1)])
            for i in range(total)
        ]

    def _calculate_distances(self, embeddings):
        """Return ``1 - cosine similarity`` for each consecutive embedding pair."""
        return [
            1 - cosine_similarity([left], [right])[0][0]
            for left, right in zip(embeddings, embeddings[1:])
        ]

    def _identify_breakpoints(self, distances, threshold_percentile):
        """Return the indices whose distance is strictly above the percentile."""
        threshold = np.percentile(distances, threshold_percentile)
        return [i for i, d in enumerate(distances) if d > threshold]

    def _create_chunks(self, sentences, breakpoints):
        """Join sentences into chunks, ending a chunk after each breakpoint index."""
        chunks, start = [], 0
        for bp in breakpoints:
            chunks.append(" ".join(sentences[start:bp + 1]))
            start = bp + 1
        # Whatever follows the last breakpoint forms the final chunk.
        chunks.append(" ".join(sentences[start:]))
        return chunks

    def _merge_small_chunks(self, chunks, embeddings, min_size):
        """Fold undersized interior chunks into their more similar neighbour.

        The first and last chunks are always kept as anchors. An interior
        chunk is "small" when splitting it on ". " yields fewer than
        ``min_size`` pieces; it is appended to the previous kept chunk when
        that one is strictly more similar, otherwise prepended to the next
        chunk. The inputs are not modified.
        """
        if len(chunks) <= 1:
            return chunks
        # Forward merges rewrite entries below; work on shallow copies so the
        # caller's list and embedding rows stay as they were passed in.
        chunks = list(chunks)
        embeddings = list(embeddings)

        final_chunks = [chunks[0]]
        merged_embeddings = [embeddings[0]]
        for i in range(1, len(chunks) - 1):
            if len(chunks[i].split(". ")) >= min_size:
                final_chunks.append(chunks[i])
                merged_embeddings.append(embeddings[i])
                continue
            prev_sim = cosine_similarity([embeddings[i]], [merged_embeddings[-1]])[0][0]
            next_sim = cosine_similarity([embeddings[i]], [embeddings[i + 1]])[0][0]
            if prev_sim > next_sim:
                final_chunks[-1] = f"{final_chunks[-1]} {chunks[i]}"
                merged_embeddings[-1] = (merged_embeddings[-1] + embeddings[i]) / 2
            else:
                # Carry the small chunk forward; the next iteration (or the
                # final append) sees the combined text and embedding.
                chunks[i + 1] = f"{chunks[i]} {chunks[i + 1]}"
                embeddings[i + 1] = (embeddings[i] + embeddings[i + 1]) / 2
        final_chunks.append(chunks[-1])
        return final_chunks
92
+
93
+
94
+ # ---------- Lifespan ----------
95
  @asynccontextmanager
96
  async def lifespan(app: FastAPI):
97
  print("Loading models...")
 
103
  )
104
  print("✓ contracts_clauses loaded")
105
 
106
+ # 2. Contract NLI
107
  print("Loading contract NLI model...")
108
  models["nli_tokenizer"] = AutoTokenizer.from_pretrained("Syamchand/contract-nli-bert")
109
  models["nli_model"] = AutoModelForSequenceClassification.from_pretrained(
 
115
 
116
  # 3. Clause risk classifier
117
  print("Loading clause risk classifier...")
118
+ models["risk_tokenizer"] = AutoTokenizer.from_pretrained("Syamchand/clause_risk_classifier")
 
 
 
 
 
 
 
 
 
119
  models["risk_model"] = AutoModelForSequenceClassification.from_pretrained(
120
  "Syamchand/clause_risk_classifier"
121
  )
 
126
  # 4. Legal BERT embeddings
127
  print("Loading legal BERT embeddings model...")
128
  models["emb_tokenizer"] = AutoTokenizer.from_pretrained("nlpaueb/bert-base-uncased-contracts")
129
+ models["emb_model"] = AutoModel.from_pretrained("nlpaueb/bert-base-uncased-contracts")
 
 
130
  models["emb_model"].eval()
131
  print("✓ legal BERT loaded")
132
 
133
+ # 5. Semantic chunker — load the backbone model specified in the Raubachm model card
134
  print("Loading semantic chunker model...")
135
+ st_model = SentenceTransformer("sentence-transformers/all-mpnet-base-v1", device="cpu")
136
+ models["chunker"] = TextChunker(st_model)
 
 
137
  print("✓ semantic chunker loaded")
138
 
139
  print("All models ready!")
140
  yield
 
141
  models.clear()
 
142
 
143
 
144
  app = FastAPI(lifespan=lifespan)
145
 
146
+
147
  # ---------- Schemas ----------
148
  class TextRequest(BaseModel):
149
  text: str
 
157
 
class ChunkRequest(BaseModel):
    """Request body for the semantic-chunking endpoint.

    Every field other than ``text`` is forwarded unchanged to
    ``TextChunker.chunk`` as the keyword argument of the same name.
    """

    # Text to be split into chunks.
    text: str
    # Percentile of consecutive-sentence distances above which a split occurs.
    percentile_threshold: float = 95.0
    # Neighbouring sentences (per side) joined to each sentence before embedding.
    context_window: int = 1
    # Interior chunks below this approximate sentence count get merged.
    min_chunk_size: int = 3
163
 
164
  class ClassificationResult(BaseModel):
165
  label: str
 
172
  chunks: List[str]
173
 
174
 
175
+ # ---------- Endpoints ----------
@app.get("/health")
def health():
    """Report that the service is reachable using a fixed status payload."""
    payload = {"status": "ok"}
    return payload
 
182
  def predict_contracts_clauses(req: TextRequest):
183
  preds = models["contracts_clauses"]([req.text])
184
  label_id = int(preds[0])
185
+ label = models["contracts_clauses"].labels[label_id] if hasattr(
186
+ models["contracts_clauses"], "labels") else f"class_{label_id}"
 
 
187
  return ClassificationResult(label=label, score=1.0)
188
 
189
 
190
  @app.post("/predict/nli", response_model=ClassificationResult)
191
  def predict_nli(req: PairRequest):
192
+ inputs = models["nli_tokenizer"](
193
+ req.premise, req.hypothesis, return_tensors="pt", truncation=True
194
+ )
195
  with torch.no_grad():
196
+ logits = models["nli_model"](**inputs).logits
197
  probs = torch.nn.functional.softmax(logits, dim=-1)
198
  class_id = torch.argmax(probs, dim=-1).item()
199
  return ClassificationResult(
 
204
 
205
  @app.post("/predict/risk", response_model=ClassificationResult)
206
  def predict_risk(req: TextRequest):
207
+ inputs = models["risk_tokenizer"](
208
+ req.text, return_tensors="pt", truncation=True, max_length=512
209
+ )
210
  with torch.no_grad():
211
+ logits = models["risk_model"](**inputs).logits
212
  probs = torch.nn.functional.softmax(logits, dim=-1)
213
  class_id = torch.argmax(probs, dim=-1).item()
214
  return ClassificationResult(
 
219
 
def mean_pooling(model_output, attention_mask):
    """Average ``last_hidden_state`` over dimension 1, weighted by the mask.

    Positions where ``attention_mask`` is 0 contribute nothing to the sum,
    and the divisor is clamped to 1e-9 so a fully masked row yields zeros
    instead of a division by zero.
    """
    hidden = model_output.last_hidden_state
    # Broadcast the mask across the hidden dimension so it can weight each value.
    weights = attention_mask.unsqueeze(-1).expand(hidden.size()).float()
    weighted_sum = (hidden * weights).sum(dim=1)
    token_counts = weights.sum(dim=1).clamp(min=1e-9)
    return weighted_sum / token_counts
224
+
225
 
@app.post("/predict/embeddings", response_model=EmbeddingResult)
def get_embeddings(req: EmbeddingRequest):
    """Embed ``req.texts`` with the legal BERT model and mean-pool the tokens.

    The texts are tokenized as one padded, truncated batch; the model runs
    without gradient tracking, and ``mean_pooling`` reduces the output using
    the batch's attention mask. The result is returned as nested lists.
    """
    tokenizer = models["emb_tokenizer"]
    encoder = models["emb_model"]
    batch = tokenizer(req.texts, padding=True, truncation=True, return_tensors="pt")
    with torch.no_grad():
        model_output = encoder(**batch)
    pooled = mean_pooling(model_output, batch["attention_mask"])
    return EmbeddingResult(embeddings=pooled.tolist())
235
 
236
 
@app.post("/predict/semantic_chunks", response_model=ChunkResult)
def semantic_chunking(req: ChunkRequest):
    """Split ``req.text`` with the loaded chunker and return the pieces.

    The tuning fields of the request are handed to the chunker's ``chunk``
    method under the same names.
    """
    chunker = models["chunker"]
    options = {
        "context_window": req.context_window,
        "percentile_threshold": req.percentile_threshold,
        "min_chunk_size": req.min_chunk_size,
    }
    pieces = chunker.chunk(text=req.text, **options)
    return ChunkResult(chunks=pieces)