"""Contract risk API: classify each clause of a base64-encoded PDF contract."""

import base64
import io
import re
from typing import Dict, List

import pdfplumber
import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import (
    DistilBertForSequenceClassification,
    DistilBertTokenizerFast,
    pipeline,
)

# Initialize FastAPI app
app = FastAPI()

# Load the pre-trained model and tokenizer for classification.
# NOTE: this is the base DistilBERT checkpoint used as-is; it should be
# fine-tuned for the task before its scores are relied upon.
model_name = "distilbert-base-uncased"
tokenizer = DistilBertTokenizerFast.from_pretrained(model_name)
model = DistilBertForSequenceClassification.from_pretrained(model_name)

# Use Hugging Face's pipeline for text classification
classifier = pipeline("text-classification", model=model, tokenizer=tokenizer)

# Labels treated as the "risky" class. A checkpoint without custom label
# names is expected to report "LABEL_0"/"LABEL_1" rather than "POSITIVE";
# without "LABEL_1" here the positive branch would never be taken.
# NOTE(review): assumes class index 1 means "risky" - confirm once the model
# is fine-tuned.
_POSITIVE_LABELS = frozenset({"POSITIVE", "LABEL_1"})

# A clause ends at a period followed by whitespace. Text extracted from a PDF
# breaks lines with "\n", so a space-only delimiter would miss clause ends
# that fall at the end of a line.
_CLAUSE_BOUNDARY = re.compile(r"\.\s+")


def extract_text_from_pdf(pdf_data: bytes) -> str:
    """Return the concatenated text of every page in a PDF.

    Args:
        pdf_data: Raw bytes of the PDF document.

    Returns:
        The text of all pages joined in page order; pages for which no text
        is extracted contribute an empty string.
    """
    # pdfplumber.open expects a path or a file-like object, so the raw bytes
    # are wrapped in an in-memory stream.
    with pdfplumber.open(io.BytesIO(pdf_data)) as pdf:
        return "".join(page.extract_text() or "" for page in pdf.pages)


def _risk_level(risk_score: float) -> str:
    """Map a risk score to the "High" / "Medium" / "Low" bucket."""
    if risk_score > 0.7:
        return "High"
    if risk_score > 0.4:
        return "Medium"
    return "Low"


# Define request body structure
class ContractFile(BaseModel):
    """Request body for ``/analyze_contract``."""

    file: str  # Base64-encoded PDF file


@app.post("/analyze_contract")
async def analyze_contract(data: ContractFile):
    """Score every clause of the uploaded contract and the contract overall.

    Args:
        data: Request body carrying the base64-encoded PDF.

    Returns:
        A dict with ``"clauses"`` (one entry per non-blank clause holding the
        clause text, its risk level and its risk score) and
        ``"overall_score"`` (mean of the clause scores, or 0 if there are no
        clauses).

    Raises:
        HTTPException: With status 500 if decoding, text extraction or
            classification fails.
    """
    try:
        # Decode base64 PDF data
        pdf_data = base64.b64decode(data.file)

        # Extract text from the PDF
        contract_text = extract_text_from_pdf(pdf_data)

        # Split contract text into clauses (naive split on sentence ends)
        clauses = _CLAUSE_BOUNDARY.split(contract_text)

        # Analyze each clause for risk level using the classifier
        results = []
        for clause in clauses:
            if not clause.strip():
                continue
            # truncation=True keeps an over-long clause from exceeding the
            # model's maximum input length and failing the whole request.
            prediction = classifier(clause, truncation=True)[0]
            score = prediction["score"]
            # The pipeline reports the winning label's probability; convert
            # it to the probability of the positive class.
            if prediction["label"] in _POSITIVE_LABELS:
                risk_score = score
            else:
                risk_score = 1 - score
            results.append({
                "clause": clause,
                "risk_level": _risk_level(risk_score),
                "risk_score": risk_score,
            })

        # Calculate the overall risk score for the contract
        overall_score = (
            sum(r["risk_score"] for r in results) / len(results) if results else 0
        )

        return {
            "clauses": results,
            "overall_score": overall_score,
        }
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Error processing the contract: {str(e)}",
        ) from e


# Test root endpoint
@app.get("/")
async def read_root():
    """Return a static welcome message."""
    return {"message": "Welcome to the Contract Risk Heatmap Generator API!"}