Spaces:
Sleeping
Sleeping
| import torch | |
| import numpy as np | |
| from contextlib import asynccontextmanager | |
| from fastapi import FastAPI | |
| from pydantic import BaseModel | |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModel, AutoModelForSeq2SeqLM | |
| from sentence_transformers import SentenceTransformer | |
| from sklearn.metrics.pairwise import cosine_similarity | |
| from setfit import SetFitModel | |
| from gliner import GLiNER | |
| from typing import List, Optional | |
| import os | |
# Process-wide registry of loaded models, tokenizers and label maps.
# It is filled during application startup and read by the request handlers.
models: dict = {}
# ---------- TextChunker (from Raubachm/sentence-transformers-semantic-chunker) ----------
class TextChunker:
    """Split text into semantically coherent groups of sentences.

    The text is split into sentences with NLTK, each sentence is embedded
    together with its neighbours, and a chunk boundary is placed wherever the
    cosine distance between consecutive embeddings exceeds a percentile
    threshold.  Interior chunks that look too short are then merged into the
    more similar neighbouring chunk.
    """

    # NLTK resources fetched before tokenizing.
    _NLTK_RESOURCES = ("punkt", "punkt_tab")
    # Set once every resource has been downloaded successfully, so that the
    # download (which contacts the NLTK index) is not repeated per request.
    _nltk_ready = False

    def __init__(self, st_model: "SentenceTransformer"):
        """Store the embedding model; it must provide ``encode(list_of_str)``."""
        self.model = st_model

    @classmethod
    def _ensure_nltk_resources(cls) -> None:
        """Download the tokenizer data once per process (best effort)."""
        if cls._nltk_ready:
            return
        import nltk

        # Evaluate every download (no short-circuit) and only stop retrying
        # once all of them reported success.
        outcomes = [nltk.download(name, quiet=True) for name in cls._NLTK_RESOURCES]
        cls._nltk_ready = all(outcomes)

    def chunk(self, text: str, context_window: int = 1,
              percentile_threshold: float = 95, min_chunk_size: int = 3) -> List[str]:
        """Return ``text`` split into semantic chunks.

        Args:
            text: The document to split.
            context_window: Number of neighbouring sentences on each side that
                are joined to a sentence before it is embedded.
            percentile_threshold: Percentile of the consecutive-sentence
                distances; distances strictly above it become breakpoints.
            min_chunk_size: Interior chunks with fewer than this many
                ``". "``-separated parts are merged into a neighbour.

        Returns:
            The list of chunk strings, or ``[text]`` when the text yields no
            sentences or only a single one.
        """
        self._ensure_nltk_resources()
        from nltk.tokenize import sent_tokenize

        sentences = sent_tokenize(text)
        if not sentences:
            return [text]

        contextualized = self._add_context(sentences, context_window)
        embeddings = self.model.encode(contextualized)
        distances = self._calculate_distances(embeddings)
        if not distances:
            # A single sentence has no neighbour to compare against.
            return [text]

        breakpoints = self._identify_breakpoints(distances, percentile_threshold)
        initial_chunks = self._create_chunks(sentences, breakpoints)
        chunk_embeddings = self.model.encode(initial_chunks)
        return self._merge_small_chunks(initial_chunks, chunk_embeddings, min_chunk_size)

    def _add_context(self, sentences, window_size):
        """Join each sentence with up to ``window_size`` neighbours per side."""
        total = len(sentences)
        return [
            " ".join(sentences[max(0, i - window_size):min(total, i + window_size + 1)])
            for i in range(total)
        ]

    def _calculate_distances(self, embeddings):
        """Return the cosine distance between each pair of consecutive embeddings."""
        return [
            1 - cosine_similarity([current], [following])[0][0]
            for current, following in zip(embeddings[:-1], embeddings[1:])
        ]

    def _identify_breakpoints(self, distances, threshold_percentile):
        """Return the indices whose distance is strictly above the percentile."""
        threshold = np.percentile(distances, threshold_percentile)
        return [i for i, d in enumerate(distances) if d > threshold]

    def _create_chunks(self, sentences, breakpoints):
        """Cut ``sentences`` after every breakpoint index and join each piece."""
        chunks, start = [], 0
        for bp in breakpoints:
            chunks.append(" ".join(sentences[start:bp + 1]))
            start = bp + 1
        # Whatever follows the last breakpoint forms the final chunk.
        chunks.append(" ".join(sentences[start:]))
        return chunks

    def _merge_small_chunks(self, chunks, embeddings, min_size):
        """Merge short interior chunks into their more similar neighbour.

        Only chunks strictly between the first and the last are size-checked.
        Size is estimated by splitting on ``". "``, so sentences ending in
        other punctuation are not counted separately.  The inputs are not
        modified; merging happens on local copies.
        """
        if len(chunks) <= 1:
            return list(chunks)

        # Work on copies so the caller's list/array is left untouched.
        chunks = list(chunks)
        embeddings = list(embeddings)

        final_chunks = [chunks[0]]
        merged_embeddings = [embeddings[0]]
        for i in range(1, len(chunks) - 1):
            if len(chunks[i].split(". ")) >= min_size:
                final_chunks.append(chunks[i])
                merged_embeddings.append(embeddings[i])
                continue

            prev_sim = cosine_similarity([embeddings[i]], [merged_embeddings[-1]])[0][0]
            next_sim = cosine_similarity([embeddings[i]], [embeddings[i + 1]])[0][0]
            if prev_sim > next_sim:
                # Fold into the chunk already emitted.
                final_chunks[-1] = f"{final_chunks[-1]} {chunks[i]}"
                merged_embeddings[-1] = (merged_embeddings[-1] + embeddings[i]) / 2
            else:
                # Prepend to the following chunk, which is handled next.
                chunks[i + 1] = f"{chunks[i]} {chunks[i + 1]}"
                embeddings[i + 1] = (embeddings[i] + embeddings[i + 1]) / 2
        final_chunks.append(chunks[-1])
        return final_chunks
# ---------- Lifespan ----------
def _load_quantized_bert(model_id, num_labels=None):
    """Load a model and its tokenizer with int8 weight quantization (quanto).

    Args:
        model_id: Hub identifier of the checkpoint.
        num_labels: When truthy, a sequence-classification head with that many
            labels is loaded; otherwise the bare encoder is loaded.

    Returns:
        ``(model, tokenizer)`` with the model switched to eval mode.
    """
    from transformers import QuantoConfig

    quant_config = QuantoConfig(weights="int8")
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    if num_labels:
        # NOTE(review): ignore_mismatched_sizes=True lets a checkpoint whose
        # head size differs from `num_labels` load anyway — confirm the
        # checkpoints really have the expected number of labels.
        model = AutoModelForSequenceClassification.from_pretrained(
            model_id,
            num_labels=num_labels,
            quantization_config=quant_config,
            ignore_mismatched_sizes=True,
        )
    else:
        model = AutoModel.from_pretrained(model_id, quantization_config=quant_config)
    model.eval()
    return model, tokenizer


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load every model into ``models`` at startup and clear it at shutdown.

    FastAPI's ``lifespan`` parameter expects an async context manager, hence
    the ``asynccontextmanager`` decorator.
    """
    print("Loading models...")

    # 1. SetFit contracts clauses
    print("Loading SetFit contracts clauses model...")
    models["contracts_clauses"] = SetFitModel.from_pretrained(
        "scholarly360/setfit-contracts-clauses"
    )
    print("✓ contracts_clauses loaded")

    # 2. Contract NLI
    print("Loading contract NLI model(int8)...")
    nli_model, nli_tokenizer = _load_quantized_bert("Syamchand/contract-nli-bert", num_labels=3)
    models["nli_tokenizer"] = nli_tokenizer
    models["nli_model"] = nli_model
    models["nli_id2label"] = {0: "entailment", 1: "neutral", 2: "contradiction"}
    print("✓ contract-nli loaded (int8) ")

    # 3. Clause risk classifier
    print("Loading clause risk classifier(int8)...")
    risk_model, risk_tokenizer = _load_quantized_bert("Syamchand/clause_risk_classifier", num_labels=3)
    models["risk_tokenizer"] = risk_tokenizer
    models["risk_model"] = risk_model
    models["risk_id2label"] = {0: "low", 1: "medium", 2: "high"}
    print("✓ clause_risk_classifier loaded (int8)")

    # 4. Legal BERT embeddings (bare encoder, no classification head)
    print("Loading legal BERT embeddings model...")
    emb_model, emb_tokenizer = _load_quantized_bert("nlpaueb/bert-base-uncased-contracts", num_labels=None)
    models["emb_tokenizer"] = emb_tokenizer
    models["emb_model"] = emb_model
    print("✓ legal BERT loaded(int8)")

    # 4b. Text explanation / summarization model (Flan-T5 small, float16)
    # NOTE(review): half precision is requested without an explicit device —
    # confirm float16 generation behaves correctly where this runs.
    print("Loading text explanation/summarization model...")
    explain_tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small")
    explain_model = AutoModelForSeq2SeqLM.from_pretrained(
        "google/flan-t5-small",
        torch_dtype=torch.float16,
    ).eval()
    models["explain_model"] = explain_model
    models["explain_tokenizer"] = explain_tokenizer
    print("✓ explain/summarize model loaded (flan-t5-small, float16)")

    # 5. Semantic chunker, backed by a sentence-transformers model on CPU
    print("Loading semantic chunker model...")
    st_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2", device="cpu")
    models["chunker"] = TextChunker(st_model)
    print("✓ semantic chunker loaded")

    # 6. NuNER_Zero NER model
    print("Loading NuNER_Zero NER model(int8)...")
    models["ner"] = GLiNER.from_pretrained("numind/NuNER_Zero", quantize=True)
    print("✓ NuNER_Zero loaded (int8)")

    print("All models ready!")
    try:
        yield
    finally:
        # Drop the references even when shutdown is triggered by an error.
        models.clear()


app = FastAPI(lifespan=lifespan)
# ---------- Schemas ----------
class TextRequest(BaseModel):
    """Request body carrying a single text."""

    text: str


class PairRequest(BaseModel):
    """Request body carrying a premise/hypothesis pair."""

    premise: str
    hypothesis: str


class EmbeddingRequest(BaseModel):
    """Request body carrying the texts to embed."""

    texts: List[str]


class ChunkRequest(BaseModel):
    """Request body for semantic chunking, with its tuning parameters."""

    text: str
    percentile_threshold: float = 95.0
    context_window: int = 1
    min_chunk_size: int = 3


class ExplanationRequest(BaseModel):
    """Request body for text generation in a given mode."""

    text: str
    mode: str = "explain"  # "summarize" or "explain"


class ClassificationResult(BaseModel):
    """A predicted label together with its score."""

    label: str
    score: float


class EmbeddingResult(BaseModel):
    """One embedding vector per input text."""

    embeddings: List[List[float]]


class ChunkResult(BaseModel):
    """The chunks produced for a text."""

    chunks: List[str]


class NERRequest(BaseModel):
    """Request body for entity extraction; entity types are optional."""

    text: str
    entity_types: Optional[List[str]] = None


class Entity(BaseModel):
    """A single extracted entity span."""

    text: str
    label: str
    score: float
    start: int
    end: int


class NERResult(BaseModel):
    """The entities extracted from a text."""

    entities: List[Entity]
# ---------- Endpoints ----------
# NOTE(review): no `@app.<method>(...)` route decorators are visible on the
# handlers in this section — confirm each handler is registered on `app`.
def health():
    """Return a constant payload signalling that the service is up."""
    payload = {"status": "ok"}
    return payload
def container_memory():
    """Report the container's memory usage as read from the cgroup filesystem.

    cgroup v2 files are tried first, then cgroup v1.

    Returns:
        ``{"usage_mb", "limit_mb", "percent"}`` on success, where ``limit_mb``
        is ``"unlimited"`` and ``percent`` is ``"unknown"`` when no limit is
        set, or ``{"error": ...}`` when the cgroup files are missing,
        unreadable or malformed.
    """
    v2_usage_path = "/sys/fs/cgroup/memory.current"
    v2_limit_path = "/sys/fs/cgroup/memory.max"
    v1_usage_path = "/sys/fs/cgroup/memory/memory.usage_in_bytes"
    v1_limit_path = "/sys/fs/cgroup/memory/memory.limit_in_bytes"
    bytes_per_mb = 1024 * 1024
    # cgroup v1 has no "max" keyword: an unset limit shows up as a very large
    # number (typically 0x7FFFFFFFFFFFF000), so anything this big is treated
    # as "no limit" instead of being reported as a real quota.
    v1_unlimited_floor = 1 << 60
    unavailable = {"error": "Cannot read container memory"}

    def read_text(path):
        with open(path) as handle:
            return handle.read().strip()

    try:
        if os.path.exists(v2_usage_path):
            usage = int(read_text(v2_usage_path))
            raw_limit = read_text(v2_limit_path)
            limit = None if raw_limit == "max" else int(raw_limit)
        elif os.path.exists(v1_usage_path):
            usage = int(read_text(v1_usage_path))
            limit = int(read_text(v1_limit_path))
            if limit >= v1_unlimited_floor:
                limit = None
        else:
            return unavailable
    except (OSError, ValueError):
        # A companion file is missing/unreadable or does not hold an integer.
        return unavailable

    usage_mb = round(usage / bytes_per_mb, 2)
    if limit is None:
        return {"usage_mb": usage_mb, "limit_mb": "unlimited", "percent": "unknown"}
    return {
        "usage_mb": usage_mb,
        "limit_mb": round(limit / bytes_per_mb, 2),
        "percent": round(usage / limit * 100, 1),
    }
def predict_contracts_clauses(req: TextRequest):
    """Classify a contract clause with the SetFit model.

    The label comes from ``model.predict``.  The score is the probability of
    that label when ``predict_proba`` is available and the label can be
    located in ``model.labels``; otherwise the highest probability is used.
    If probabilities cannot be computed at all, the failure is logged and the
    score falls back to ``1.0``.

    Returns:
        ``ClassificationResult`` with the label and the score rounded to four
        decimals.
    """
    import logging

    model = models["contracts_clauses"]
    # The model is expected to return the label directly (e.g. 'terms').
    label = model.predict([req.text])[0]

    score = 1.0
    if hasattr(model, "predict_proba"):
        try:
            probs = model.predict_proba([req.text])[0]
            known_labels = getattr(model, "labels", None)
            if known_labels is not None and label in known_labels:
                # Probabilities are assumed to follow the order of model.labels.
                score = probs[known_labels.index(label)]
            else:
                score = max(probs)
        except Exception:
            # Best effort: keep serving the label, but do not hide the failure.
            logging.getLogger(__name__).exception(
                "predict_proba failed; reporting fallback score 1.0"
            )
            score = 1.0
    return ClassificationResult(label=label, score=round(float(score), 4))
def predict_nli(req: PairRequest):
    """Classify the relation between a premise and a hypothesis.

    Returns:
        ``ClassificationResult`` holding the label with the highest softmax
        probability and that probability rounded to four decimals.
    """
    # max_length is pinned explicitly (as in predict_risk) so truncation does
    # not depend on the tokenizer configuration declaring a maximum length.
    inputs = models["nli_tokenizer"](
        req.premise,
        req.hypothesis,
        return_tensors="pt",
        truncation=True,
        max_length=512,
    )
    with torch.no_grad():
        logits = models["nli_model"](**inputs).logits
    probs = torch.nn.functional.softmax(logits, dim=-1)
    class_id = torch.argmax(probs, dim=-1).item()
    return ClassificationResult(
        label=models["nli_id2label"][class_id],
        score=round(probs[0][class_id].item(), 4),
    )
def predict_risk(req: TextRequest):
    """Classify the risk level of a clause.

    The text is tokenized (truncated to 512 tokens), run through the risk
    model without gradient tracking, and the class with the highest softmax
    probability is returned with its probability rounded to four decimals.
    """
    encoded = models["risk_tokenizer"](
        req.text,
        return_tensors="pt",
        truncation=True,
        max_length=512,
    )
    with torch.no_grad():
        raw_scores = models["risk_model"](**encoded).logits

    distribution = torch.nn.functional.softmax(raw_scores, dim=-1)
    winner = torch.argmax(distribution, dim=-1).item()
    confidence = distribution[0][winner].item()
    return ClassificationResult(
        label=models["risk_id2label"][winner],
        score=round(confidence, 4),
    )
def mean_pooling(model_output, attention_mask):
    """Average the token embeddings, counting only unmasked positions.

    ``model_output.last_hidden_state`` is weighted by ``attention_mask``
    (broadcast over the last dimension), summed over dimension 1 and divided
    by the mask total, which is clamped to at least 1e-9 to avoid dividing by
    zero.
    """
    hidden = model_output.last_hidden_state
    weights = attention_mask.unsqueeze(-1).expand(hidden.size()).float()
    weighted_sum = torch.sum(hidden * weights, 1)
    weight_total = torch.clamp(weights.sum(1), min=1e-9)
    return weighted_sum / weight_total
def get_embeddings(req: EmbeddingRequest):
    """Embed each text with the embedding model using mean pooling.

    Returns:
        ``EmbeddingResult`` with one vector per input text, in input order;
        an empty list when no texts were supplied.
    """
    if not req.texts:
        # Nothing to encode: skip the tokenizer/model instead of feeding them
        # an empty batch.
        return EmbeddingResult(embeddings=[])

    # max_length is pinned explicitly (as in predict_risk) so truncation does
    # not depend on the tokenizer configuration declaring a maximum length.
    encoded = models["emb_tokenizer"](
        req.texts,
        padding=True,
        truncation=True,
        max_length=512,
        return_tensors="pt",
    )
    with torch.no_grad():
        outputs = models["emb_model"](**encoded)
    embeddings = mean_pooling(outputs, encoded["attention_mask"])
    return EmbeddingResult(embeddings=embeddings.tolist())
def semantic_chunking(req: ChunkRequest):
    """Split the request text into chunks using the loaded chunker.

    The request's tuning parameters are forwarded unchanged to
    ``chunk`` and the resulting list is wrapped in a ``ChunkResult``.
    """
    chunker = models["chunker"]
    tuning = {
        "context_window": req.context_window,
        "percentile_threshold": req.percentile_threshold,
        "min_chunk_size": req.min_chunk_size,
    }
    pieces = chunker.chunk(text=req.text, **tuning)
    return ChunkResult(chunks=pieces)
def explain_text(req: ExplanationRequest):
    """Generate a summary or a more detailed summary of the request text.

    Mode ``"explain"`` uses a "summarize in detail" prompt; every other mode
    uses the plain "summarize" prompt.  The prompt is truncated to 512 tokens
    and decoded with beam search.

    Returns:
        A dict with the requested ``mode`` and the ``generated_text``.
    """
    tokenizer = models["explain_tokenizer"]
    model = models["explain_model"]

    # An 'explain' request is framed as a more intensive summary.
    if req.mode == "explain":
        prompt = f"summarize in detail: {req.text}"
    else:
        prompt = f"summarize: {req.text}"

    encoded = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
    generation_options = {
        "max_new_tokens": 150,
        "num_beams": 5,
        "length_penalty": 2.0,       # encourage longer generation
        "no_repeat_ngram_size": 3,   # prevent repetition
        "early_stopping": True,
    }
    with torch.no_grad():
        generated = model.generate(**encoded, **generation_options)

    decoded = tokenizer.decode(generated[0], skip_special_tokens=True)
    return {"mode": req.mode, "generated_text": decoded}
def predict_ner(req: NERRequest):
    """Extract entities from the request text with the NER model.

    The caller's ``entity_types`` are used when provided and non-empty;
    otherwise a default set aimed at freelancer contracts is used.  Labels are
    lower-cased before being passed to the model.

    Returns:
        ``NERResult`` with one ``Entity`` per prediction returned by the model.
    """
    # Fallback entity types suitable for freelancer contracts
    fallback_types = [
        "freelancer", "company", "contract value", "governing law",
        "jurisdiction", "duration", "payment terms", "termination",
        "liability", "intellectual property"
    ]
    requested = req.entity_types or fallback_types
    # The model is given lowercase labels
    lowered = [entity_type.lower() for entity_type in requested]

    predictions = models["ner"].predict_entities(req.text, lowered)
    entities = [Entity(**prediction) for prediction in predictions]
    return NERResult(entities=entities)