import torch
import numpy as np
from contextlib import asynccontextmanager
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModel, AutoModelForSeq2SeqLM
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
from setfit import SetFitModel
from gliner import GLiNER
from typing import List, Optional
import os

# Registry of loaded models/tokenizers, filled by `lifespan` and read by the endpoints.
models = {}

# Token budget passed to every tokenizer call that truncates its input.
MAX_INPUT_TOKENS = 512

# NLTK resources needed by `sent_tokenize`.
_NLTK_RESOURCES = ("punkt", "punkt_tab")
_nltk_ready = False


def _ensure_nltk_resources() -> None:
    """Download the NLTK tokenizer resources once per process.

    The download is attempted again on a later call if any resource
    reported failure, so a transient error does not disable retries.
    """
    global _nltk_ready
    if _nltk_ready:
        return
    import nltk

    # Build a list (not a generator) so every resource is attempted.
    results = [nltk.download(name, quiet=True) for name in _NLTK_RESOURCES]
    _nltk_ready = all(results)


# ---------- TextChunker (from Raubachm/sentence-transformers-semantic-chunker) ----------
class TextChunker:
    """Split text into chunks at points where adjacent sentence embeddings diverge."""

    def __init__(self, st_model: SentenceTransformer):
        self.model = st_model

    def chunk(
        self,
        text: str,
        context_window: int = 1,
        percentile_threshold: float = 95,
        min_chunk_size: int = 3,
    ) -> List[str]:
        """Return the list of chunks for `text`.

        Args:
            text: Input text to split.
            context_window: Number of neighbouring sentences joined to each
                sentence (on both sides) before embedding.
            percentile_threshold: Percentile of the adjacent-sentence distances
                above which a breakpoint is placed.
            min_chunk_size: Chunks with fewer than this many ". "-separated
                pieces are merged into a neighbour.

        Returns:
            The chunks, or `[text]` when the text yields no sentences or only
            a single sentence.
        """
        _ensure_nltk_resources()
        from nltk.tokenize import sent_tokenize

        sentences = sent_tokenize(text)
        if not sentences:
            return [text]

        contextualized = self._add_context(sentences, context_window)
        embeddings = self.model.encode(contextualized)
        distances = self._calculate_distances(embeddings)
        if not distances:
            # A single sentence has no neighbour to compare against.
            return [text]

        breakpoints = self._identify_breakpoints(distances, percentile_threshold)
        initial_chunks = self._create_chunks(sentences, breakpoints)
        chunk_embeddings = self.model.encode(initial_chunks)
        return self._merge_small_chunks(initial_chunks, chunk_embeddings, min_chunk_size)

    def _add_context(self, sentences, window_size):
        """Join each sentence with up to `window_size` neighbours on each side."""
        total = len(sentences)
        result = []
        for i in range(total):
            start = max(0, i - window_size)
            end = min(total, i + window_size + 1)
            result.append(" ".join(sentences[start:end]))
        return result

    def _calculate_distances(self, embeddings):
        """Return `1 - cosine_similarity` for each pair of consecutive embeddings."""
        distances = []
        for i in range(len(embeddings) - 1):
            sim = cosine_similarity([embeddings[i]], [embeddings[i + 1]])[0][0]
            distances.append(1 - sim)
        return distances

    def _identify_breakpoints(self, distances, threshold_percentile):
        """Return indices whose distance is strictly above the given percentile."""
        threshold = np.percentile(distances, threshold_percentile)
        return [i for i, d in enumerate(distances) if d > threshold]

    def _create_chunks(self, sentences, breakpoints):
        """Join sentences into chunks, ending a chunk after each breakpoint index."""
        chunks, start = [], 0
        for bp in breakpoints:
            chunks.append(" ".join(sentences[start:bp + 1]))
            start = bp + 1
        # Whatever follows the last breakpoint forms the final chunk.
        chunks.append(" ".join(sentences[start:]))
        return chunks

    def _merge_small_chunks(self, chunks, embeddings, min_size):
        """Merge each small interior chunk into its more similar neighbour.

        Only chunks strictly between the first and the last are examined.
        The inputs are copied, so the caller's `chunks` and `embeddings`
        are left untouched.
        """
        if len(chunks) <= 1:
            return chunks

        # Work on copies: forward merges rewrite the next chunk and its embedding.
        chunks = list(chunks)
        embeddings = list(embeddings)

        final_chunks = [chunks[0]]
        merged_embeddings = [embeddings[0]]
        for i in range(1, len(chunks) - 1):
            if len(chunks[i].split(". ")) < min_size:
                prev_sim = cosine_similarity([embeddings[i]], [merged_embeddings[-1]])[0][0]
                next_sim = cosine_similarity([embeddings[i]], [embeddings[i + 1]])[0][0]
                if prev_sim > next_sim:
                    # Fold into the previously emitted chunk.
                    final_chunks[-1] = f"{final_chunks[-1]} {chunks[i]}"
                    merged_embeddings[-1] = (merged_embeddings[-1] + embeddings[i]) / 2
                else:
                    # Fold into the next chunk, which is handled on a later iteration
                    # (or appended as the last chunk below).
                    chunks[i + 1] = f"{chunks[i]} {chunks[i + 1]}"
                    embeddings[i + 1] = (embeddings[i] + embeddings[i + 1]) / 2
            else:
                final_chunks.append(chunks[i])
                merged_embeddings.append(embeddings[i])
        final_chunks.append(chunks[-1])
        return final_chunks


# ---------- Lifespan ----------
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load every model into `models` at startup and clear the registry at shutdown."""
    print("Loading models...")

    def load_quantized_bert(model_id, num_labels=None):
        """Load a tokenizer and an int8-weight model for `model_id`.

        A sequence-classification head is loaded when `num_labels` is truthy;
        otherwise the bare `AutoModel` encoder is loaded.
        """
        from transformers import QuantoConfig

        quant_config = QuantoConfig(weights="int8")
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        if num_labels:
            model = AutoModelForSequenceClassification.from_pretrained(
                model_id,
                num_labels=num_labels,
                quantization_config=quant_config,
                ignore_mismatched_sizes=True,
            )
        else:
            model = AutoModel.from_pretrained(
                model_id,
                quantization_config=quant_config,
            )
        model.eval()
        return model, tokenizer

    # 1. SetFit contracts clauses
    print("Loading SetFit contracts clauses model...")
    models["contracts_clauses"] = SetFitModel.from_pretrained(
        "scholarly360/setfit-contracts-clauses"
    )
    print("✓ contracts_clauses loaded")

    # 2. Contract NLI
    print("Loading contract NLI model(int8)...")
    nli_model, nli_tokenizer = load_quantized_bert("Syamchand/contract-nli-bert", num_labels=3)
    models["nli_tokenizer"] = nli_tokenizer
    models["nli_model"] = nli_model
    models["nli_id2label"] = {0: "entailment", 1: "neutral", 2: "contradiction"}
    print("✓ contract-nli loaded (int8) ")

    # 3. Clause risk classifier
    print("Loading clause risk classifier(int8)...")
    risk_model, risk_tokenizer = load_quantized_bert("Syamchand/clause_risk_classifier", num_labels=3)
    models["risk_tokenizer"] = risk_tokenizer
    models["risk_model"] = risk_model
    models["risk_id2label"] = {0: "low", 1: "medium", 2: "high"}
    print("✓ clause_risk_classifier loaded (int8)")

    # 4. Legal BERT embeddings
    print("Loading legal BERT embeddings model...")
    emb_model, emb_tokenizer = load_quantized_bert("nlpaueb/bert-base-uncased-contracts", num_labels=None)
    models["emb_tokenizer"] = emb_tokenizer
    models["emb_model"] = emb_model
    print("✓ legal BERT loaded(int8)")

    # 4b. Text explanation / summarization model (Flan-T5 small, float16)
    # NOTE(review): float16 inference on CPU may be slow or unsupported for some
    # ops depending on the torch version -- confirm on the deployment target.
    print("Loading text explanation/summarization model...")
    explain_tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small")
    explain_model = AutoModelForSeq2SeqLM.from_pretrained(
        "google/flan-t5-small",
        torch_dtype=torch.float16,
    ).eval()
    models["explain_model"] = explain_model
    models["explain_tokenizer"] = explain_tokenizer
    print("✓ explain/summarize model loaded (flan-t5-small, float16)")

    # 5. Semantic chunker -- backbone model named in the Raubachm model card
    print("Loading semantic chunker model...")
    st_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2", device="cpu")
    models["chunker"] = TextChunker(st_model)
    print("✓ semantic chunker loaded")

    # 6. NuNER_Zero NER model
    # NOTE(review): `quantize=True` is forwarded to GLiNER.from_pretrained as in
    # the original code -- confirm the installed gliner version honours it.
    print("Loading NuNER_Zero NER model(int8)...")
    models["ner"] = GLiNER.from_pretrained("numind/NuNER_Zero", quantize=True)
    print("✓ NuNER_Zero loaded (int8)")

    print("All models ready!")
    try:
        yield
    finally:
        # Drop the references even when shutdown is triggered by an exception.
        models.clear()


app = FastAPI(lifespan=lifespan)


# ---------- Schemas ----------
class TextRequest(BaseModel):
    text: str


class PairRequest(BaseModel):
    premise: str
    hypothesis: str


class EmbeddingRequest(BaseModel):
    texts: List[str]


class ChunkRequest(BaseModel):
    text: str
    percentile_threshold: float = 95.0
    context_window: int = 1
    min_chunk_size: int = 3


class ExplanationRequest(BaseModel):
    text: str
    mode: str = "explain"  # "summarize" or "explain"


class ClassificationResult(BaseModel):
    label: str
    score: float


class EmbeddingResult(BaseModel):
    embeddings: List[List[float]]


class ChunkResult(BaseModel):
    chunks: List[str]


class NERRequest(BaseModel):
    text: str
    entity_types: Optional[List[str]] = None


class Entity(BaseModel):
    text: str
    label: str
    score: float
    start: int
    end: int


class NERResult(BaseModel):
    entities: List[Entity]


# ---------- Endpoints ----------
@app.get("/health")
def health():
    """Liveness probe."""
    return {"status": "ok"}


@app.get("/memory")
def container_memory():
    """Report container memory usage and limit read from the cgroup filesystem.

    Returns an `error` payload when neither the cgroup v2 nor the cgroup v1
    files are present.
    """
    # Try cgroup v2 first (most common on HF Spaces)
    if os.path.exists("/sys/fs/cgroup/memory.current"):
        with open("/sys/fs/cgroup/memory.current") as f:
            usage = int(f.read().strip())
        with open("/sys/fs/cgroup/memory.max") as f:
            limit_str = f.read().strip()
        # cgroup v2 writes the literal "max" when no limit is configured.
        limit = int(limit_str) if limit_str != "max" else None
    # Fallback to cgroup v1
    elif os.path.exists("/sys/fs/cgroup/memory/memory.usage_in_bytes"):
        with open("/sys/fs/cgroup/memory/memory.usage_in_bytes") as f:
            usage = int(f.read().strip())
        with open("/sys/fs/cgroup/memory/memory.limit_in_bytes") as f:
            limit = int(f.read().strip())
    else:
        return {"error": "Cannot read container memory"}

    usage_mb = round(usage / (1024 * 1024), 2)
    if limit is None:
        return {"usage_mb": usage_mb, "limit_mb": "unlimited", "percent": "unknown"}
    return {
        "usage_mb": usage_mb,
        "limit_mb": round(limit / (1024 * 1024), 2),
        "percent": round(usage / limit * 100, 1),
    }


@app.post("/predict/contracts_clauses", response_model=ClassificationResult)
def predict_contracts_clauses(req: TextRequest):
    """Classify a clause with the SetFit model and report a best-effort score.

    The score defaults to 1.0 when the model exposes no `predict_proba` or
    when probability lookup fails.
    """
    model = models["contracts_clauses"]
    preds = model.predict([req.text])
    label = preds[0]  # the original code treats this as a string label, e.g. 'terms'

    score = 1.0
    if hasattr(model, "predict_proba"):
        try:
            probs = model.predict_proba([req.text])[0]
            labels = getattr(model, "labels", None)
            if labels is not None and label in labels:
                # Use the probability at the predicted label's position.
                score = probs[labels.index(label)]
            else:
                score = max(probs)
        except Exception:
            # Deliberate best-effort: the label is still valid without a score.
            score = 1.0

    return ClassificationResult(label=label, score=round(float(score), 4))


@app.post("/predict/nli", response_model=ClassificationResult)
def predict_nli(req: PairRequest):
    """Classify a premise/hypothesis pair as entailment, neutral or contradiction."""
    inputs = models["nli_tokenizer"](
        req.premise,
        req.hypothesis,
        return_tensors="pt",
        truncation=True,
        # Explicit cap, matching predict_risk, so truncation does not depend on
        # the tokenizer config carrying a usable model_max_length.
        max_length=MAX_INPUT_TOKENS,
    )
    with torch.no_grad():
        logits = models["nli_model"](**inputs).logits
    probs = torch.nn.functional.softmax(logits, dim=-1)
    class_id = torch.argmax(probs, dim=-1).item()
    return ClassificationResult(
        label=models["nli_id2label"][class_id],
        score=round(probs[0][class_id].item(), 4),
    )


@app.post("/predict/risk", response_model=ClassificationResult)
def predict_risk(req: TextRequest):
    """Classify a clause as low, medium or high risk."""
    inputs = models["risk_tokenizer"](
        req.text,
        return_tensors="pt",
        truncation=True,
        max_length=MAX_INPUT_TOKENS,
    )
    with torch.no_grad():
        logits = models["risk_model"](**inputs).logits
    probs = torch.nn.functional.softmax(logits, dim=-1)
    class_id = torch.argmax(probs, dim=-1).item()
    return ClassificationResult(
        label=models["risk_id2label"][class_id],
        score=round(probs[0][class_id].item(), 4),
    )


def mean_pooling(model_output, attention_mask):
    """Average `last_hidden_state` over dimension 1, weighted by the attention mask."""
    token_embeddings = model_output.last_hidden_state
    mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    summed = torch.sum(token_embeddings * mask_expanded, 1)
    # Clamp the denominator so an all-zero mask cannot divide by zero.
    counts = torch.clamp(mask_expanded.sum(1), min=1e-9)
    return summed / counts


@app.post("/predict/embeddings", response_model=EmbeddingResult)
def get_embeddings(req: EmbeddingRequest):
    """Return one mean-pooled embedding per input text."""
    if not req.texts:
        # Nothing to encode; avoid handing the tokenizer an empty batch.
        return EmbeddingResult(embeddings=[])

    encoded = models["emb_tokenizer"](
        req.texts,
        padding=True,
        truncation=True,
        max_length=MAX_INPUT_TOKENS,
        return_tensors="pt",
    )
    with torch.no_grad():
        outputs = models["emb_model"](**encoded)
    embeddings = mean_pooling(outputs, encoded["attention_mask"])
    return EmbeddingResult(embeddings=embeddings.tolist())


@app.post("/predict/semantic_chunks", response_model=ChunkResult)
def semantic_chunking(req: ChunkRequest):
    """Split the request text into semantic chunks with the loaded chunker."""
    chunks = models["chunker"].chunk(
        text=req.text,
        context_window=req.context_window,
        percentile_threshold=req.percentile_threshold,
        min_chunk_size=req.min_chunk_size,
    )
    return ChunkResult(chunks=chunks)


@app.post("/predict/explain")
def explain_text(req: ExplanationRequest):
    """Generate a summary (or a more detailed summary when mode is "explain")."""
    tokenizer = models["explain_tokenizer"]
    model = models["explain_model"]

    # Any mode other than "explain" falls back to the plain summarization prompt.
    if req.mode == "explain":
        input_text = f"summarize in detail: {req.text}"
    else:
        input_text = f"summarize: {req.text}"

    inputs = tokenizer(
        input_text,
        return_tensors="pt",
        truncation=True,
        max_length=MAX_INPUT_TOKENS,
    )
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=150,
            num_beams=5,
            length_penalty=2.0,       # encourage longer generation
            no_repeat_ngram_size=3,   # prevent repetition
            early_stopping=True,
        )
    result = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return {"mode": req.mode, "generated_text": result}


@app.post("/predict/ner", response_model=NERResult)
def predict_ner(req: NERRequest):
    """Extract entities from the text for the requested (or default) entity types."""
    # Default entity types suitable for freelancer contracts
    default_types = [
        "freelancer", "company", "contract value", "governing law", "jurisdiction",
        "duration", "payment terms", "termination", "liability", "intellectual property",
    ]
    labels = req.entity_types if req.entity_types else default_types
    # Labels are lower-cased before being handed to the model.
    labels = [label.lower() for label in labels]
    raw_entities = models["ner"].predict_entities(req.text, labels)
    return NERResult(entities=[Entity(**ent) for ent in raw_entities])