# red_ml-models / app.py
# Last commit: 8fcdf14 "Update app.py" by Syamchand (verified)
import os
from contextlib import asynccontextmanager
from typing import List, Optional

import numpy as np
import torch
from fastapi import FastAPI
from gliner import GLiNER
from pydantic import BaseModel, Field
from sentence_transformers import SentenceTransformer
from setfit import SetFitModel
from sklearn.metrics.pairwise import cosine_similarity
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModel, AutoModelForSeq2SeqLM
# Process-wide registry of loaded models, tokenizers and label maps, keyed by
# name. Filled by lifespan() at startup, read by the endpoints, cleared on shutdown.
models: dict = {}
# ---------- TextChunker (from Raubachm/sentence-transformers-semantic-chunker) ----------
class TextChunker:
    """Split text into semantically coherent chunks of whole sentences.

    Adapted from Raubachm/sentence-transformers-semantic-chunker. Each sentence
    is embedded together with its neighbours, a chunk boundary is placed
    wherever the cosine distance between consecutive embeddings exceeds a
    percentile threshold, and chunks that look too short are merged into
    whichever neighbour they are more similar to.
    """

    # Flipped to True once the NLTK punkt data is confirmed installed, so the
    # lookup (and any download) runs once per process instead of per request.
    _nltk_ready = False

    # Annotation is quoted so the class can be defined without importing
    # sentence_transformers; any object with a compatible ``encode`` works.
    def __init__(self, st_model: "SentenceTransformer"):
        """Wrap an embedding model whose ``encode(list_of_str)`` yields one vector per string."""
        self.model = st_model

    def chunk(self, text: str, context_window: int = 1,
              percentile_threshold: float = 95, min_chunk_size: int = 3) -> List[str]:
        """Split ``text`` into semantic chunks.

        Args:
            text: Input document.
            context_window: Number of sentences on each side embedded together
                with each sentence.
            percentile_threshold: Percentile (0-100) of consecutive-sentence
                distances above which a chunk boundary is placed.
            min_chunk_size: Chunks with fewer (approximate) sentences than this
                are merged into a neighbour; the last chunk is never merged.

        Returns:
            The chunks in document order; ``[text]`` when the text has fewer
            than two sentences.
        """
        self._ensure_nltk_data()
        from nltk.tokenize import sent_tokenize

        sentences = sent_tokenize(text)
        if not sentences:
            return [text]
        contextualized = self._add_context(sentences, context_window)
        distances = self._calculate_distances(self.model.encode(contextualized))
        if not distances:
            return [text]
        breakpoints = self._identify_breakpoints(distances, percentile_threshold)
        initial_chunks = self._create_chunks(sentences, breakpoints)
        if len(initial_chunks) <= 1:
            # Nothing can be merged, so skip the second (costly) encode pass.
            return initial_chunks
        chunk_embeddings = self.model.encode(initial_chunks)
        return self._merge_small_chunks(initial_chunks, chunk_embeddings, min_chunk_size)

    @classmethod
    def _ensure_nltk_data(cls):
        """Make the NLTK punkt tokenizer data available, downloading only what is missing."""
        if cls._nltk_ready:
            return
        import nltk

        ready = True
        for resource in ("punkt", "punkt_tab"):
            try:
                nltk.data.find(f"tokenizers/{resource}")
            except LookupError:
                # nltk.download reports failure via its return value; retry next call.
                if not nltk.download(resource, quiet=True):
                    ready = False
        cls._nltk_ready = ready

    def _add_context(self, sentences, window_size):
        """Return each sentence joined with up to ``window_size`` neighbours on each side."""
        result = []
        for i in range(len(sentences)):
            start = max(0, i - window_size)
            end = min(len(sentences), i + window_size + 1)
            result.append(" ".join(sentences[start:end]))
        return result

    @staticmethod
    def _normalize_rows(vectors):
        """Scale each row to unit L2 norm; all-zero rows stay zero (cosine similarity 0)."""
        norms = np.linalg.norm(vectors, axis=-1, keepdims=True)
        return np.divide(vectors, norms, out=np.zeros_like(vectors), where=norms != 0)

    @classmethod
    def _cosine(cls, a, b):
        """Cosine similarity of two 1-D vectors (0.0 if either is all zeros)."""
        pair = cls._normalize_rows(np.asarray([a, b], dtype=float))
        return float(pair[0] @ pair[1])

    def _calculate_distances(self, embeddings):
        """Return ``1 - cosine_similarity`` for each consecutive pair of embeddings."""
        vectors = np.asarray(embeddings, dtype=float)
        if len(vectors) < 2:
            return []
        unit = self._normalize_rows(vectors)
        # Row-wise dot product of neighbours, computed in one vectorized pass.
        sims = np.einsum("ij,ij->i", unit[:-1], unit[1:])
        return [1.0 - float(s) for s in sims]

    def _identify_breakpoints(self, distances, threshold_percentile):
        """Indices whose distance to the next sentence exceeds the given percentile."""
        threshold = np.percentile(distances, threshold_percentile)
        return [i for i, d in enumerate(distances) if d > threshold]

    def _create_chunks(self, sentences, breakpoints):
        """Join sentences into chunks, ending a chunk after each breakpoint index."""
        chunks, start = [], 0
        for bp in breakpoints:
            chunks.append(" ".join(sentences[start:bp + 1]))
            start = bp + 1
        chunks.append(" ".join(sentences[start:]))
        return chunks

    def _merge_small_chunks(self, chunks, embeddings, min_size):
        """Merge interior chunks shorter than ``min_size`` into their more similar neighbour.

        A chunk counts as small when splitting it on ". " yields fewer than
        ``min_size`` parts -- a rough sentence count that ignores "?"/"!"
        endings. The first and last chunks are never merged away. The inputs
        are not modified.
        """
        if len(chunks) <= 1:
            return list(chunks)
        # Work on copies: forward merges rewrite the *next* entry in place.
        pending = list(chunks)
        pending_emb = [np.asarray(e, dtype=float) for e in embeddings]
        final_chunks = [pending[0]]
        merged_embeddings = [pending_emb[0]]
        for i in range(1, len(pending) - 1):
            if len(pending[i].split(". ")) < min_size:
                prev_sim = self._cosine(pending_emb[i], merged_embeddings[-1])
                next_sim = self._cosine(pending_emb[i], pending_emb[i + 1])
                if prev_sim > next_sim:
                    final_chunks[-1] = f"{final_chunks[-1]} {pending[i]}"
                    merged_embeddings[-1] = (merged_embeddings[-1] + pending_emb[i]) / 2
                else:
                    pending[i + 1] = f"{pending[i]} {pending[i + 1]}"
                    pending_emb[i + 1] = (pending_emb[i] + pending_emb[i + 1]) / 2
            else:
                final_chunks.append(pending[i])
                merged_embeddings.append(pending_emb[i])
        final_chunks.append(pending[-1])
        return final_chunks
# ---------- Lifespan ----------
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load every model into the module-level ``models`` registry.

    The code before ``yield`` runs once at application startup. The code after
    it runs at shutdown. Request handlers only read from ``models``.

    Args:
        app: The FastAPI application. It is unused, but the lifespan protocol
            requires it.
    """
    print("Loading models...")

    def load_quantized_bert(model_id, num_labels=None):
        """Load a tokenizer and an int8 weight-quantized model (via optimum-quanto).

        Args:
            model_id: Hugging Face Hub model id.
            num_labels: When truthy, load a sequence-classification head with
                this many labels. Otherwise load the bare encoder (``AutoModel``).
                Both variants receive the same int8 quantization config.

        Returns:
            ``(model, tokenizer)``, with the model already in eval mode.
        """
        from transformers import QuantoConfig

        quant_config = QuantoConfig(weights="int8")
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        if num_labels:
            model = AutoModelForSequenceClassification.from_pretrained(
                model_id,
                num_labels=num_labels,
                quantization_config=quant_config,
                # Loads even if the checkpoint's head shape differs from num_labels.
                # NOTE(review): mismatched head weights would be freshly initialized
                # rather than raising -- confirm the checkpoints really have 3 labels.
                ignore_mismatched_sizes=True,
            )
        else:
            model = AutoModel.from_pretrained(
                model_id,
                quantization_config=quant_config,
            )
        model.eval()
        return model, tokenizer

    # 1. SetFit contracts clauses
    print("Loading SetFit contracts clauses model...")
    models["contracts_clauses"] = SetFitModel.from_pretrained(
        "scholarly360/setfit-contracts-clauses"
    )
    print("✓ contracts_clauses loaded")

    # 2. Contract NLI
    print("Loading contract NLI model(int8)...")
    nli_model, nli_tokenizer = load_quantized_bert("Syamchand/contract-nli-bert", num_labels=3)
    models["nli_tokenizer"] = nli_tokenizer
    models["nli_model"] = nli_model
    models["nli_id2label"] = {0: "entailment", 1: "neutral", 2: "contradiction"}
    print("✓ contract-nli loaded (int8) ")

    # 3. Clause risk classifier (eval() is already applied by load_quantized_bert)
    print("Loading clause risk classifier(int8)...")
    risk_model, risk_tokenizer = load_quantized_bert("Syamchand/clause_risk_classifier", num_labels=3)
    models["risk_tokenizer"] = risk_tokenizer
    models["risk_model"] = risk_model
    models["risk_id2label"] = {0: "low", 1: "medium", 2: "high"}
    print("✓ clause_risk_classifier loaded (int8)")

    # 4. Legal BERT embeddings (bare encoder, no classification head)
    print("Loading legal BERT embeddings model...")
    emb_model, emb_tokenizer = load_quantized_bert("nlpaueb/bert-base-uncased-contracts", num_labels=None)
    models["emb_tokenizer"] = emb_tokenizer
    models["emb_model"] = emb_model
    print("✓ legal BERT loaded(int8)")

    # 4b. Text explanation / summarization model (Flan-T5 small, float16)
    print("Loading text explanation/summarization model...")
    explain_tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small")
    # NOTE(review): float16 halves memory, but half precision on CPU can be slow
    # and T5-family models may be numerically fragile in fp16. If outputs
    # degrade, consider float32 or bfloat16.
    explain_model = AutoModelForSeq2SeqLM.from_pretrained(
        "google/flan-t5-small",
        torch_dtype=torch.float16,
    ).eval()
    models["explain_model"] = explain_model
    models["explain_tokenizer"] = explain_tokenizer
    print("✓ explain/summarize model loaded (flan-t5-small, float16)")

    # 5. Semantic chunker -- backbone model named in the Raubachm model card
    print("Loading semantic chunker model...")
    st_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2", device="cpu")
    models["chunker"] = TextChunker(st_model)
    print("✓ semantic chunker loaded")

    # 6. NuNER_Zero NER model
    print("Loading NuNER_Zero NER model(int8)...")
    models["ner"] = GLiNER.from_pretrained("numind/NuNER_Zero", quantize=True)
    print("✓ NuNER_Zero loaded (int8)")

    print("All models ready!")
    try:
        yield
    finally:
        # Drop model references on shutdown, even if the app exits with an error.
        models.clear()
# ASGI application. `lifespan` loads every model before requests are served.
app: FastAPI = FastAPI(lifespan=lifespan)
# ---------- Schemas ----------
class TextRequest(BaseModel):
    """Request body carrying a single text (used by /predict/contracts_clauses and /predict/risk)."""
    text: str
class PairRequest(BaseModel):
    """Request body for /predict/nli: a premise/hypothesis sentence pair."""
    premise: str
    hypothesis: str
class EmbeddingRequest(BaseModel):
    """Request body for /predict/embeddings: the texts to embed, encoded as one batch."""
    texts: List[str]
class ChunkRequest(BaseModel):
    """Request body for /predict/semantic_chunks; fields are passed to TextChunker.chunk()."""
    text: str
    # Percentile of consecutive-sentence distances above which a boundary is
    # placed. np.percentile rejects values outside [0, 100], so out-of-range
    # input is rejected here with a 422 instead of failing with a 500.
    percentile_threshold: float = Field(95.0, ge=0, le=100)
    # Neighbouring sentences on each side embedded with each sentence. A
    # negative value would produce empty context strings.
    context_window: int = Field(1, ge=0)
    # Chunks with fewer (approximate) sentences than this are merged into a neighbour.
    min_chunk_size: int = 3
class ExplanationRequest(BaseModel):
    """Request body for /predict/explain."""
    text: str
    # "explain" requests a detailed summary. Any other value (e.g. "summarize")
    # gets a plain summary.
    mode: str = "explain"
class ClassificationResult(BaseModel):
    """A predicted label and its confidence (the endpoints round it to 4 decimals)."""
    label: str
    score: float
class EmbeddingResult(BaseModel):
    """One mean-pooled embedding vector per input text, in request order."""
    embeddings: List[List[float]]
class ChunkResult(BaseModel):
    """Semantic chunks of the input text, in document order."""
    chunks: List[str]
class NERRequest(BaseModel):
    """Request body for /predict/ner."""
    text: str
    # Zero-shot entity labels to extract. When None or empty, the endpoint
    # falls back to its default contract-oriented label set.
    entity_types: Optional[List[str]] = None
class Entity(BaseModel):
    """A single entity span, built from one GLiNER ``predict_entities`` result dict."""
    text: str
    label: str
    score: float
    # presumably character offsets into the request text -- confirm against GLiNER docs
    start: int
    end: int
class NERResult(BaseModel):
    """All entities found in the request text."""
    entities: List[Entity]
# ---------- Endpoints ----------
@app.get("/health")
def health():
    """Liveness probe: always answers ``{"status": "ok"}``."""
    status = {"status": "ok"}
    return status
# cgroup v1 reports "no limit" as LONG_MAX rounded down to a page boundary
# (~2**63). Any value at or above this threshold is treated as unlimited.
_CGROUP_V1_UNLIMITED = 1 << 62


def _read_cgroup_value(path):
    """Return the stripped contents of a cgroup pseudo-file, or None if it cannot be read."""
    try:
        with open(path) as f:
            return f.read().strip()
    except OSError:
        return None


@app.get("/memory")
def container_memory():
    """Report the container's memory usage and limit from cgroup files.

    Tries cgroup v2 first (``memory.current`` / ``memory.max``), then falls
    back to cgroup v1. Returns ``usage_mb``, ``limit_mb`` and ``percent``.
    When no limit is set, ``limit_mb`` is "unlimited" and ``percent`` is
    "unknown". When neither cgroup layout is readable, returns an ``error``
    entry instead.
    """
    usage_raw = _read_cgroup_value("/sys/fs/cgroup/memory.current")
    if usage_raw is not None:
        # cgroup v2: the literal "max" means no limit. A missing file is treated the same way.
        limit_raw = _read_cgroup_value("/sys/fs/cgroup/memory.max")
        limit = int(limit_raw) if limit_raw not in (None, "max") else None
    else:
        usage_raw = _read_cgroup_value("/sys/fs/cgroup/memory/memory.usage_in_bytes")
        if usage_raw is None:
            return {"error": "Cannot read container memory"}
        limit_raw = _read_cgroup_value("/sys/fs/cgroup/memory/memory.limit_in_bytes")
        limit = int(limit_raw) if limit_raw is not None else None
        if limit is not None and limit >= _CGROUP_V1_UNLIMITED:
            limit = None
    usage = int(usage_raw)
    usage_mb = round(usage / (1024 * 1024), 2)
    if limit is None:
        return {"usage_mb": usage_mb, "limit_mb": "unlimited", "percent": "unknown"}
    return {
        "usage_mb": usage_mb,
        "limit_mb": round(limit / (1024 * 1024), 2),
        "percent": round(usage / limit * 100, 1),
    }
@app.post("/predict/contracts_clauses", response_model=ClassificationResult)
def predict_contracts_clauses(req: TextRequest):
    """Classify a contract clause with the SetFit model.

    Returns the predicted label. The score is the probability of that label
    when the model exposes ``predict_proba``, and falls back to 1.0 otherwise.
    """
    model = models["contracts_clauses"]
    label = model.predict([req.text])[0]
    score = 1.0
    if hasattr(model, "predict_proba"):
        try:
            # NOTE(review): this runs a second forward pass. Deriving the label
            # from predict_proba alone would halve the cost.
            probs = model.predict_proba([req.text])[0]
            # model.labels (when set) lists label names in predict_proba's column order.
            labels = getattr(model, "labels", None)
            if labels is not None and label in labels:
                score = probs[list(labels).index(label)]
            else:
                score = max(probs)
        except Exception:
            # The confidence is best-effort; a failure here must not fail the request.
            score = 1.0
    # predict() may return non-str values (e.g. integer class ids or numpy
    # scalars). Coerce so the response's `label: str` always validates.
    return ClassificationResult(label=str(label), score=round(float(score), 4))
@app.post("/predict/nli", response_model=ClassificationResult)
def predict_nli(req: PairRequest):
    """Classify how ``hypothesis`` relates to ``premise``.

    Returns entailment / neutral / contradiction together with its softmax
    probability.
    """
    inputs = models["nli_tokenizer"](
        req.premise,
        req.hypothesis,
        return_tensors="pt",
        truncation=True,
        # Explicit cap, as in /predict/risk, so long pairs are truncated even if
        # the tokenizer config defines no max length. Assumes a 512-position BERT.
        max_length=512,
    )
    with torch.no_grad():
        logits = models["nli_model"](**inputs).logits
    probs = torch.nn.functional.softmax(logits, dim=-1)
    class_id = torch.argmax(probs, dim=-1).item()
    return ClassificationResult(
        label=models["nli_id2label"][class_id],
        score=round(probs[0][class_id].item(), 4),
    )
@app.post("/predict/risk", response_model=ClassificationResult)
def predict_risk(req: TextRequest):
    """Classify a clause's risk level (low / medium / high) with its softmax probability."""
    tokenizer, model = models["risk_tokenizer"], models["risk_model"]
    encoded = tokenizer(req.text, return_tensors="pt", truncation=True, max_length=512)
    with torch.no_grad():
        scores = model(**encoded).logits.softmax(dim=-1)
    best = scores.argmax(dim=-1).item()
    confidence = round(scores[0, best].item(), 4)
    return ClassificationResult(label=models["risk_id2label"][best], score=confidence)
def mean_pooling(model_output, attention_mask):
    """Average token embeddings over the positions marked 1 in ``attention_mask``.

    Args:
        model_output: An object exposing ``last_hidden_state`` (batch, seq, dim).
        attention_mask: A (batch, seq) mask of 0/1 values.

    Returns:
        A (batch, dim) tensor. Rows with an all-zero mask come out as zeros,
        because the divisor is clamped to at least 1e-9.
    """
    hidden = model_output.last_hidden_state
    weights = attention_mask.unsqueeze(-1).expand(hidden.size()).float()
    summed = (hidden * weights).sum(dim=1)
    counts = weights.sum(dim=1).clamp(min=1e-9)
    return summed / counts
@app.post("/predict/embeddings", response_model=EmbeddingResult)
def get_embeddings(req: EmbeddingRequest):
    """Embed each text with the legal-BERT encoder, using attention-masked mean pooling.

    Returns one vector per input text, in order. An empty ``texts`` list yields
    an empty result instead of an error.
    """
    if not req.texts:
        # The tokenizer/model cannot handle an empty batch, so short-circuit.
        return EmbeddingResult(embeddings=[])
    encoded = models["emb_tokenizer"](
        req.texts,
        padding=True,
        truncation=True,
        # Explicit cap, consistent with /predict/risk. Assumes a 512-position BERT.
        max_length=512,
        return_tensors="pt",
    )
    with torch.no_grad():
        outputs = models["emb_model"](**encoded)
    embeddings = mean_pooling(outputs, encoded["attention_mask"])
    return EmbeddingResult(embeddings=embeddings.tolist())
@app.post("/predict/semantic_chunks", response_model=ChunkResult)
def semantic_chunking(req: ChunkRequest):
    """Split ``req.text`` into semantic chunks using the request's tuning parameters."""
    chunker = models["chunker"]
    pieces = chunker.chunk(
        req.text,
        context_window=req.context_window,
        percentile_threshold=req.percentile_threshold,
        min_chunk_size=req.min_chunk_size,
    )
    return ChunkResult(chunks=pieces)
@app.post("/predict/explain")
def explain_text(req: ExplanationRequest):
    """Summarize or explain ``req.text`` with Flan-T5.

    ``mode == "explain"`` uses a "summarize in detail:" prompt. Any other mode
    uses a plain "summarize:" prompt. Returns the mode echoed back together
    with the generated text.
    """
    prefix = "summarize in detail: " if req.mode == "explain" else "summarize: "
    tok = models["explain_tokenizer"]
    model = models["explain_model"]
    encoded = tok(prefix + req.text, return_tensors="pt", truncation=True, max_length=512)
    generation_kwargs = dict(
        max_new_tokens=150,
        num_beams=5,
        length_penalty=2.0,  # >1 favours longer beams
        no_repeat_ngram_size=3,  # block repeated trigrams
        early_stopping=True,
    )
    with torch.no_grad():
        generated = model.generate(**encoded, **generation_kwargs)
    text_out = tok.decode(generated[0], skip_special_tokens=True)
    return {"mode": req.mode, "generated_text": text_out}
@app.post("/predict/ner", response_model=NERResult)
def predict_ner(req: NERRequest):
    """Extract entities from ``req.text`` with the zero-shot GLiNER model.

    Uses ``req.entity_types`` when it is non-empty. Otherwise it falls back to
    a default label set aimed at freelancer contracts. Labels are lower-cased
    before prediction.
    """
    requested = req.entity_types or [
        "freelancer", "company", "contract value", "governing law",
        "jurisdiction", "duration", "payment terms", "termination",
        "liability", "intellectual property",
    ]
    # GLiNER reportedly performs best with lowercase labels.
    lowered = [name.lower() for name in requested]
    found = models["ner"].predict_entities(req.text, lowered)
    return NERResult(entities=[Entity(**item) for item in found])