Spaces:
Sleeping
Sleeping
"""
Cross-contract clause extractor and pair generator.

Uses the Groq API (groq.com) to extract and classify legal clauses from two
contracts, pairs clauses of the same type across the contracts, and feeds
those pairs into Model 3 (an NLI classifier) for conflict detection.

Install: pip install groq
    (the script also imports torch, transformers and python-dotenv)
API key: https://console.groq.com
"""
| import os | |
| import json | |
| import torch | |
| from groq import Groq | |
| from transformers import pipeline as hf_pipeline, AutoTokenizer | |
| from dotenv import load_dotenv | |
load_dotenv()  # load variables such as GROQ_API_KEY from a local .env file, if present

# ── Config ──────────────────────────────────────────────────────────────────
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
# Report only whether the key is present. Printing the secret itself leaks it
# into any captured stdout / log output.
print("API KEY:", "set" if GROQ_API_KEY else "NOT SET")
# Path to the saved Model 3, overridable via the MODEL3_DIR environment
# variable. A relative path is resolved against the current working directory,
# not against this file.
MODEL3_DIR = os.getenv("MODEL3_DIR", "../model_3")
GROQ_MODEL = "openai/gpt-oss-120b"  # Groq-hosted model used for clause extraction
MAX_LEN = 512  # must match Model 3 training config
CONF_THRESHOLD = 0.7  # predictions whose top score is below this are flagged "uncertain"
# Clause types the extractor is asked to choose from (inserted into the
# extraction prompt).
CLAUSE_TYPES = [
    "termination",
    "warranty",
    "indemnification",
    "ip_ownership",
    "dispute_resolution",
    "confidentiality",
    "liability_cap",
    "governing_law",
    "payment",
    "non_compete",
    "force_majeure",
    "assignment",
]

# ── Groq client ─────────────────────────────────────────────────────────────
# Created once at import time and shared by every extraction call.
# NOTE(review): the Groq SDK may raise here when no API key is available.
groq_client = Groq(api_key=GROQ_API_KEY)
# ── Step 1: Extract clauses from a single contract ──────────────────────────
# System prompt: pins the output format (a bare JSON array) and excludes
# numeric financial-covenant clauses.
EXTRACTION_SYSTEM_PROMPT = """You are a legal clause extraction engine.
Your job is to extract distinct legal clauses from contract text.
You must return ONLY a valid JSON array — no explanation, no markdown fences, no preamble.
Never extract financial covenant clauses with numeric thresholds — those are handled separately."""
# User prompt template, filled with str.format(clause_types=..., contract_text=...).
# Literal braces in the JSON example are doubled ({{ }}) so str.format emits
# them as single braces. Braces inside the substituted contract text are safe,
# because format does not re-scan substituted values.
EXTRACTION_USER_PROMPT = """Extract all distinct legal clauses from this contract.
For each clause return:
- clause_type: one of [{clause_types}]
- clause_text: the core legal obligation rewritten concisely in 1-2 sentences. Max 60 words. Do NOT copy verbatim.
Rules:
- One entry per clause_type maximum. If duplicates exist, keep the most restrictive.
- Skip purely numeric clauses like "maintain debt ratio >= 2.5" — financial covenants only.
- Skip any clause that does not fit the listed types.
Return format — JSON array only, nothing else:
[
{{"clause_type": "termination", "clause_text": "Either party may terminate with 30 days written notice."}},
{{"clause_type": "dispute_resolution", "clause_text": "All disputes resolved through binding arbitration in New York."}}
]
Contract text:
{contract_text}"""
def _strip_code_fences(raw: str) -> str:
    """Remove a surrounding ``` / ```json markdown fence, if the model added one."""
    raw = raw.strip()
    if raw.startswith("```"):
        raw = raw.split("```")[1]
        if raw[:4].lower() == "json":
            raw = raw[4:]
    return raw.strip()


def _coerce_clause_list(parsed, contract_label: str) -> list[dict]:
    """Normalise decoded model output into a list of well-formed clause dicts.

    Accepts either a JSON array or an object wrapping one (e.g.
    ``{"clauses": [...]}``). Items that are not objects, whose ``clause_type``
    is not in CLAUSE_TYPES, or whose ``clause_text`` is not a string are skipped
    with a warning. Each kept item gets a ``"contract"`` key set to
    ``contract_label``; any other keys the model returned are preserved.
    """
    if isinstance(parsed, dict):
        # Take the first list-valued field instead of blindly taking the first value.
        parsed = next((v for v in parsed.values() if isinstance(v, list)), None)
    if not isinstance(parsed, list):
        print(f"[ERROR] Expected a JSON array of clauses for {contract_label}")
        return []

    clauses = []
    for item in parsed:
        if (
            isinstance(item, dict)
            and item.get("clause_type") in CLAUSE_TYPES
            and isinstance(item.get("clause_text"), str)
        ):
            clauses.append({**item, "contract": contract_label})
        else:
            print(f"[WARN] {contract_label}: skipping malformed clause {item!r:.120}")
    return clauses


def extract_clauses(contract_text: str, contract_label: str) -> list[dict]:
    """Extract and classify clauses from one contract via the Groq chat API.

    Args:
        contract_text: Raw contract text. Surrounding whitespace is stripped
            before it is inserted into EXTRACTION_USER_PROMPT.
        contract_label: Human-readable name (e.g. "Contract A"). Used in log
            output and stored on every returned clause under ``"contract"``.

    Returns:
        A list of dicts with at least ``clause_type``, ``clause_text`` and
        ``contract`` keys. Only clause types listed in CLAUSE_TYPES are kept.
        Returns ``[]`` when the response is empty, is not valid JSON, or
        contains no JSON array.

    Raises:
        Any exception raised by the Groq client call (not caught here).
    """
    prompt = EXTRACTION_USER_PROMPT.format(
        clause_types=", ".join(CLAUSE_TYPES),
        contract_text=contract_text.strip(),
    )

    # Non-streaming: the full response is needed before JSON parsing.
    completion = groq_client.chat.completions.create(
        model=GROQ_MODEL,
        messages=[
            {"role": "system", "content": EXTRACTION_SYSTEM_PROMPT},
            {"role": "user", "content": prompt},
        ],
        temperature=0,  # deterministic extraction
        max_completion_tokens=2048,
        top_p=1,
        reasoning_effort="medium",
        stream=False,  # must be False: the whole JSON is needed before parsing
        stop=None,
    )

    # message.content is optional in the chat-completion schema. Treat None as
    # empty so it takes the JSON-error path below instead of crashing on .strip().
    raw = _strip_code_fences(completion.choices[0].message.content or "")

    try:
        parsed = json.loads(raw)
    except json.JSONDecodeError as e:
        print(f"[ERROR] JSON parse failed for {contract_label}: {e}")
        print(f"Raw response was:\n{raw[:400]}")
        return []

    clauses = _coerce_clause_list(parsed, contract_label)

    print(f"\n[{contract_label}] Extracted {len(clauses)} clauses:")
    for c in clauses:
        text = c["clause_text"]
        preview = text if len(text) <= 80 else text[:80] + "..."
        print(f"  [{c['clause_type']}] {preview}")
    return clauses
# ── Step 2: Pair same-type clauses across contracts ─────────────────────────
def generate_pairs(
    clauses_a: list[dict],
    clauses_b: list[dict],
) -> list[dict]:
    """Pair clauses of the same type across Contract A and Contract B.

    Args:
        clauses_a: Clause dicts for Contract A, each with ``clause_type`` and
            ``clause_text`` keys.
        clauses_b: Same shape, for Contract B.

    Returns:
        One ``{"clause_type", "clause_a", "clause_b"}`` dict per clause type
        present in both inputs, ordered alphabetically by clause type so the
        output is reproducible between runs. If an input holds several clauses
        of the same type, the last one wins.
    """
    index_a = {c["clause_type"]: c["clause_text"] for c in clauses_a}
    index_b = {c["clause_type"]: c["clause_text"] for c in clauses_b}

    # Sort explicitly. Iteration order of a set of str varies between
    # interpreter runs (hash randomization), which made pair order unstable.
    matched_types = sorted(index_a.keys() & index_b.keys())
    unmatched_types = sorted(index_a.keys() ^ index_b.keys())

    pairs = [
        {
            "clause_type": clause_type,
            "clause_a": index_a[clause_type],
            "clause_b": index_b[clause_type],
        }
        for clause_type in matched_types
    ]

    print(f"\n[PAIRING] {len(pairs)} matching types: {matched_types}")
    if unmatched_types:
        print(f"[PAIRING] Only in one contract (skipped): {unmatched_types}")
    return pairs
# ── Step 3: Validate token lengths before inference ─────────────────────────
def check_token_length(
    tokenizer,
    clause_a: str,
    clause_b: str,
    max_len: int,
    warn_ratio: float = 0.85,
) -> int:
    """Return the token count of a clause pair, warning when it nears/exceeds max_len.

    The pair is encoded as ``"<clause_a> [SEP] <clause_b>"`` (the same string
    score_pairs sends to the pipeline) without truncation, so the count is the
    true length under the tokenizer's default special-token handling.

    Args:
        tokenizer: A Hugging Face tokenizer, or anything callable as
            ``tokenizer(text, truncation=False)`` that returns a mapping with
            an ``"input_ids"`` sequence.
        clause_a: Clause text from Contract A.
        clause_b: Clause text from Contract B.
        max_len: Truncation length used at inference.
        warn_ratio: Fraction of ``max_len`` above which a "close to limit"
            warning is printed (default 0.85).

    Returns:
        The number of token ids produced for the pair.
    """
    # No return_tensors: only the length is needed, so skip building a tensor.
    encoded = tokenizer(f"{clause_a} [SEP] {clause_b}", truncation=False)
    length = len(encoded["input_ids"])
    if length > max_len:
        print(f" [WARN] {length} tokens > MAX_LEN {max_len} — will be truncated")
    elif length > int(max_len * warn_ratio):
        print(f" [WARN] {length} tokens is close to limit ({max_len})")
    return length
# ── Step 4: Load and run Model 3 ────────────────────────────────────────────
def load_model3(model_dir: str, max_len: int):
    """Build the Model 3 text-classification pipeline and its tokenizer.

    Args:
        model_dir: Directory holding the fine-tuned Model 3 weights and tokenizer.
        max_len: Truncation length handed to the tokenizer through the pipeline.

    Returns:
        ``(pipe, tokenizer)``. ``pipe`` is built with ``top_k=None`` so it
        returns a score for every label. ``tokenizer`` is the tokenizer loaded
        from ``model_dir``, which is also used by the pipeline.
    """
    on_gpu = torch.cuda.is_available()
    tokenizer = AutoTokenizer.from_pretrained(model_dir)

    pipe = hf_pipeline(
        "text-classification",
        model=model_dir,
        tokenizer=tokenizer,
        device=0 if on_gpu else -1,  # first CUDA device, otherwise CPU
        top_k=None,  # scores for all labels, not just the argmax
        # The remaining keyword arguments are forwarded by the pipeline to its
        # tokenizer call.
        truncation=True,
        max_length=max_len,
        # NOTE(review): presumably disabled because Model 3's architecture
        # does not accept token_type_ids; confirm against the model config.
        return_token_type_ids=False,
    )

    print(f"\n[MODEL3] Loaded from '{model_dir}' on {'GPU' if on_gpu else 'CPU'}")
    return pipe, tokenizer
def score_pairs(
    pairs: list[dict],
    pipe,
    tokenizer,
    max_len: int,
    conf_threshold: float,
) -> list[dict]:
    """Run Model 3 on every clause pair and attach its predictions.

    Each pair is sent to ``pipe`` as the single string
    ``"<clause_a> [SEP] <clause_b>"``.
    NOTE(review): whether a literal "[SEP]" matches the separator Model 3 was
    trained with depends on its tokenizer; confirm.

    Args:
        pairs: Dicts with ``clause_type``, ``clause_a`` and ``clause_b`` keys.
        pipe: Text-classification pipeline returning ``{"label", "score"}``
            dicts for every label.
        tokenizer: Tokenizer used only to measure pair length.
        max_len: Truncation length, used for length warnings.
        conf_threshold: Predictions whose top score is below this are marked
            ``uncertain``.

    Returns:
        One dict per input pair. Nothing is filtered. Pairs predicted as
        "contradiction" come first, then all others; each group is ordered by
        ``predicted_score`` descending.

    Raises:
        ValueError: If the pipeline returns no label scores for a pair.
    """
    results = []
    for pair in pairs:
        clause_type = pair["clause_type"]
        clause_a = pair["clause_a"]
        clause_b = pair["clause_b"]

        token_len = check_token_length(tokenizer, clause_a, clause_b, max_len)

        raw_result = pipe(f"{clause_a} [SEP] {clause_b}")
        # A single input may come back as [dict, ...] or wrapped as
        # [[dict, ...]]; unwrap the nested form.
        if raw_result and isinstance(raw_result[0], list):
            raw_result = raw_result[0]
        scores = {r["label"]: r["score"] for r in raw_result}
        if not scores:
            raise ValueError(f"Model 3 returned no label scores for the '{clause_type}' pair")

        predicted_label = max(scores, key=scores.get)
        predicted_score = scores[predicted_label]
        # NOTE(review): assumes Model 3's label names include exactly
        # "contradiction". Any other spelling yields 0.0 here and defeats the
        # contradiction-first sort below.
        contradiction_score = scores.get("contradiction", 0.0)

        results.append({
            "clause_type": clause_type,
            "clause_a": clause_a,
            "clause_b": clause_b,
            "predicted_label": predicted_label,
            "predicted_score": round(predicted_score, 4),
            "contradiction_score": round(contradiction_score, 4),
            "all_scores": {k: round(v, 4) for k, v in scores.items()},
            "token_length": token_len,
            "uncertain": predicted_score < conf_threshold,
        })

    # Ordering only: callers such as print_results decide what to surface.
    # False sorts before True, so contradictions come first.
    results.sort(key=lambda x: (
        x["predicted_label"] != "contradiction",
        -x["predicted_score"],
    ))
    return results
# ── Step 5: Print results ───────────────────────────────────────────────────
def print_results(results: list[dict], strong_threshold: float = 0.75) -> None:
    """Print a two-section conflict report for the output of score_pairs.

    Sections:
      * HIGH-CONFIDENCE CONFLICTS: predicted "contradiction", not flagged
        uncertain, and ``predicted_score >= strong_threshold``.
      * UNCERTAIN / REVIEW NEEDED: every result flagged uncertain (any label),
        plus contradictions scoring below ``strong_threshold``.

    Confident non-contradiction results are not printed.

    Args:
        results: Dicts with at least ``clause_type``, ``clause_a``,
            ``clause_b``, ``predicted_label``, ``predicted_score`` and
            ``uncertain`` keys.
        strong_threshold: Minimum score for a contradiction to count as
            high-confidence (default 0.75).
    """
    strong = [
        r for r in results
        if r["predicted_label"] == "contradiction"
        and not r["uncertain"]
        and r["predicted_score"] >= strong_threshold
    ]
    # Contradictions that clear the uncertainty cutoff but not strong_threshold
    # must land here; otherwise they would appear in neither section.
    review = [
        r for r in results
        if r["uncertain"]
        or (r["predicted_label"] == "contradiction"
            and r["predicted_score"] < strong_threshold)
    ]

    print("\n" + "=" * 65)
    print("CONTRACT CONFLICT ANALYSIS")
    print("=" * 65)

    # Primary section
    print("\n── HIGH-CONFIDENCE CONFLICTS ─────────────────────────────")
    print(f"[INFO] These are reliable contradictions (>= {strong_threshold:.0%})")
    if strong:
        for r in strong:
            print(f"\n[{r['clause_type'].upper()}] {r['predicted_score']:.2%}")
            print(f"Contract A: {r['clause_a']}")
            print(f"Contract B: {r['clause_b']}")
    else:
        print(" None found.")

    # Secondary section
    print("\n── UNCERTAIN / REVIEW NEEDED ─────────────────────────────")
    print("[INFO] Lower-confidence predictions — require human validation")
    if review:
        for r in review:
            print(f"\n[{r['clause_type'].upper()}] "
                  f"{r['predicted_label']} ({r['predicted_score']:.2%})")
            print(f"Contract A: {r['clause_a']}")
            print(f"Contract B: {r['clause_b']}")
    else:
        print(" None.")

    print("\n" + "=" * 65)
# ── Main pipeline ───────────────────────────────────────────────────────────
def run_pipeline(
    contract_a_text: str,
    contract_b_text: str,
    model3_dir: str = MODEL3_DIR,
    max_len: int = MAX_LEN,
    conf_threshold: float = CONF_THRESHOLD,
) -> list[dict]:
    """Run the full cross-contract conflict pipeline and print a report.

    Steps:
      1. Groq extracts and classifies clauses from both contracts.
      2. Same-type clauses are paired across the contracts.
      3. Model 3 scores each pair for contradiction.
      4. The report is printed and the scored pairs are returned, sorted with
         contradictions first, then by confidence.

    Args:
        contract_a_text: Full text of Contract A.
        contract_b_text: Full text of Contract B.
        model3_dir: Location of the saved Model 3.
        max_len: Truncation length for Model 3 inputs.
        conf_threshold: Score below which a prediction is flagged uncertain.

    Returns:
        The scored pair dicts, or ``[]`` if extraction yielded nothing for
        either contract or no clause type appears in both.
    """
    print("\n── STEP 1: Extracting clauses via Groq ────────────────────")
    clauses_a = extract_clauses(contract_a_text, "Contract A")
    clauses_b = extract_clauses(contract_b_text, "Contract B")

    failed = [
        label
        for label, clauses in (("Contract A", clauses_a), ("Contract B", clauses_b))
        if not clauses
    ]
    if failed:
        print(f"[ERROR] Extraction returned empty for {', '.join(failed)}. "
              "Check GROQ_API_KEY and contract text.")
        return []

    print("\n── STEP 2: Generating clause pairs ────────────────────────")
    pairs = generate_pairs(clauses_a, clauses_b)
    if not pairs:
        print("[WARN] No matching clause types between contracts.")
        return []

    # Model 3 is loaded only once there is something to score.
    print(f"\n── STEP 3: Scoring {len(pairs)} pairs with Model 3 ────────")
    pipe, tokenizer = load_model3(model3_dir, max_len)
    results = score_pairs(pairs, pipe, tokenizer, max_len, conf_threshold)

    print_results(results)
    return results
# ── Example usage ───────────────────────────────────────────────────────────
if __name__ == "__main__":
    # Two sample contracts whose termination, warranty, dispute-resolution, IP
    # and governing-law clauses differ, while the confidentiality clauses say
    # the same thing in different words.
    CONTRACT_A = """
    VENDOR AGREEMENT 2024
    Termination: Either party may terminate this agreement for convenience
    upon 30 days written notice to the other party.
    Warranties: Seller warrants that all deliverables shall be free from
    defects for a period of 24 months from the date of acceptance by Buyer.
    Dispute Resolution: All disputes arising under this agreement shall be
    resolved through binding arbitration in New York under AAA rules.
    Intellectual Property: The Licensee is granted an exclusive, worldwide,
    perpetual license to use the Software and all derivative works.
    Confidentiality: Neither party shall disclose Confidential Information
    to any third party without prior written consent of the disclosing party.
    Governing Law: This agreement shall be governed by the laws of Delaware.
    """
    CONTRACT_B = """
    MASTER SERVICES AGREEMENT 2024
    Termination: This agreement may only be terminated for cause, specifically
    material breach that remains uncured for 60 days after written notice.
    Warranties: Seller disclaims all warranties, express or implied, including
    any warranty of merchantability or fitness for a particular purpose.
    Dispute Resolution: Either party may bring suit in any court of competent
    jurisdiction to resolve disputes arising under this agreement.
    Intellectual Property: The license granted herein is non-exclusive, limited
    to the United States, and valid for 12 months only from the effective date.
    Confidentiality: Confidential Information must not be shared with outside
    parties unless the disclosing party agrees in writing beforehand.
    Governing Law: This agreement is governed by the laws of California.
    """

    output_path = "conflict_results.json"
    results = run_pipeline(CONTRACT_A, CONTRACT_B)

    # Persist the scored pairs (an empty list if the pipeline bailed out early).
    with open(output_path, "w") as out_file:
        json.dump(results, out_file, indent=2)
    print(f"\nResults saved to {output_path}")