# viswanani's picture
# Update app.py
# 2d98bd2 verified
import base64
import io
from typing import List, Dict

import pdfplumber
import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import DistilBertForSequenceClassification, DistilBertTokenizerFast, pipeline

# Initialize the FastAPI application; the route decorators further down register on it.
app = FastAPI()
# Load the tokenizer and the sequence-classification model once, at import time,
# so the same objects are reused by every request.
# NOTE(review): "distilbert-base-uncased" is the base checkpoint and, as the original
# author noted, has not been fine-tuned for this task. Presumably its classification
# head is untrained and its labels are not "POSITIVE"/"NEGATIVE", which the label
# check in analyze_contract relies on — confirm and swap in a fine-tuned checkpoint.
model_name = "distilbert-base-uncased"
tokenizer = DistilBertTokenizerFast.from_pretrained(model_name)
model = DistilBertForSequenceClassification.from_pretrained(model_name)
# Wrap the model and tokenizer in a Hugging Face text-classification pipeline.
classifier = pipeline("text-classification", model=model, tokenizer=tokenizer)
def extract_text_from_pdf(pdf_data: bytes) -> str:
    """Extract the text of every page of an in-memory PDF.

    Args:
        pdf_data: Raw bytes of the PDF document.

    Returns:
        The text of all pages concatenated in page order. A page for which
        ``extract_text()`` returns nothing contributes an empty string.
    """
    # pdfplumber.open() takes a path or a file-like object, not raw bytes,
    # so wrap the payload in an in-memory binary stream.
    with pdfplumber.open(io.BytesIO(pdf_data)) as pdf:
        return "".join(page.extract_text() or "" for page in pdf.pages)
class ContractFile(BaseModel):
    # Request body accepted by the contract-analysis endpoint.
    # `file` carries the PDF document as a base64-encoded string.
    file: str
def _risk_level(risk_score: float) -> str:
    """Bucket a risk score into "High" (> 0.7), "Medium" (> 0.4) or "Low"."""
    if risk_score > 0.7:
        return "High"
    if risk_score > 0.4:
        return "Medium"
    return "Low"


@app.post("/analyze_contract")
async def analyze_contract(data: ContractFile):
    """Score every clause of a base64-encoded PDF contract.

    The PDF text is split naively on ". " and each non-blank clause is run
    through the module-level ``classifier``.

    Args:
        data: Request body whose ``file`` field holds the base64-encoded PDF.

    Returns:
        A dict with ``"clauses"`` (one entry per clause containing
        ``clause``, ``risk_level`` and ``risk_score``) and
        ``"overall_score"`` (mean of the clause scores, 0 when there are
        no clauses).

    Raises:
        HTTPException: 400 if ``file`` is not valid base64; 500 if PDF
            extraction or classification fails.
    """
    # A malformed payload is the client's fault, so report it as 400 rather
    # than letting it fall into the generic 500 handler below.
    # binascii.Error (bad padding) is a subclass of ValueError.
    try:
        pdf_data = base64.b64decode(data.file)
    except ValueError as exc:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid base64-encoded file: {exc}",
        ) from exc

    try:
        contract_text = extract_text_from_pdf(pdf_data)

        # Split contract text into clauses (naive split by ". "), dropping
        # fragments that contain only whitespace.
        clauses = [part for part in contract_text.split(". ") if part.strip()]

        results = []
        for clause in clauses:
            # truncation=True keeps clauses longer than the model's maximum
            # sequence length from raising inside the pipeline.
            prediction = classifier(clause, truncation=True)[0]
            score = prediction["score"]
            # The score belongs to the predicted label; flip it when that
            # label is not "POSITIVE" so it always refers to "POSITIVE".
            risk_score = score if prediction["label"] == "POSITIVE" else 1 - score
            results.append({
                "clause": clause,
                "risk_level": _risk_level(risk_score),
                "risk_score": risk_score,
            })

        # Overall contract score is the mean clause score (0 for no clauses).
        overall_score = (
            sum(r["risk_score"] for r in results) / len(results) if results else 0
        )
        return {
            "clauses": results,
            "overall_score": overall_score,
        }
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Error processing the contract: {str(e)}",
        ) from e
# Root endpoint: a simple smoke test that the service is up.
@app.get("/")
async def read_root():
    # Static welcome payload; no request data is read.
    welcome = {"message": "Welcome to the Contract Risk Heatmap Generator API!"}
    return welcome