viswanani commited on
Commit
4994596
·
verified ·
1 Parent(s): 1c81c8f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +55 -35
app.py CHANGED
@@ -1,50 +1,70 @@
1
- import torch
2
- from transformers import pipeline
3
- from fastapi import FastAPI
4
  import base64
5
  import pdfplumber
 
 
 
 
6
 
7
  # Initialize FastAPI app
8
  app = FastAPI()
9
 
10
- # Load BERT-based model for clause classification
11
  classifier = pipeline("text-classification", model="distilbert-base-uncased")
12
 
13
- # Function to extract text from PDF
14
- def extract_text_from_pdf(pdf_data):
 
15
  with pdfplumber.open(pdf_data) as pdf:
16
  text = ""
17
  for page in pdf.pages:
18
  text += page.extract_text() or ""
19
  return text
20
 
 
 
 
 
 
 
21
  @app.post("/analyze_contract")
22
- async def analyze_contract(file: str):
23
- # Decode base64 PDF
24
- pdf_data = base64.b64decode(file)
25
-
26
- # Extract text
27
- contract_text = extract_text_from_pdf(pdf_data)
28
-
29
- # Split into clauses (simple split for demo; use regex for production)
30
- clauses = contract_text.split(". ")
31
-
32
- # Analyze each clause
33
- results = []
34
- for clause in clauses:
35
- if clause.strip():
36
- result = classifier(clause)
37
- risk_score = result[0]["score"] if result[0]["label"] == "POSITIVE" else 1 - result[0]["score"]
38
- results.append({
39
- "clause": clause,
40
- "risk_level": "High" if risk_score > 0.7 else "Medium" if risk_score > 0.4 else "Low",
41
- "risk_score": risk_score
42
- })
43
-
44
- # Calculate overall risk score
45
- overall_score = sum(r["risk_score"] for r in results) / len(results) if results else 0
46
-
47
- return {
48
- "clauses": results,
49
- "overall_score": overall_score
50
- }
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException
2
+ from pydantic import BaseModel
 
3
  import base64
4
  import pdfplumber
5
+ from transformers import pipeline
6
+ import torch
7
+ from typing import List, Dict
8
+
9
 
10
  # Initialize FastAPI app
11
  app = FastAPI()
12
 
13
+ # Load the pre-trained BERT model for contract clause classification
14
  classifier = pipeline("text-classification", model="distilbert-base-uncased")
15
 
16
+
17
+ # Function to extract text from PDF file
18
+ def extract_text_from_pdf(pdf_data: bytes) -> str:
19
  with pdfplumber.open(pdf_data) as pdf:
20
  text = ""
21
  for page in pdf.pages:
22
  text += page.extract_text() or ""
23
  return text
24
 
25
+
26
+ # Define request body structure
27
+ class ContractFile(BaseModel):
28
+ file: str # Base64-encoded PDF file
29
+
30
+
31
  @app.post("/analyze_contract")
32
+ async def analyze_contract(data: ContractFile):
33
+ try:
34
+ # Decode base64 PDF data
35
+ pdf_data = base64.b64decode(data.file)
36
+
37
+ # Extract text from the PDF
38
+ contract_text = extract_text_from_pdf(pdf_data)
39
+
40
+ # Split contract text into clauses (naive split by ".")
41
+ clauses = contract_text.split(". ")
42
+
43
+ # Analyze each clause for risk level using the classifier
44
+ results = []
45
+ for clause in clauses:
46
+ if clause.strip():
47
+ result = classifier(clause)
48
+ risk_score = result[0]["score"] if result[0]["label"] == "POSITIVE" else 1 - result[0]["score"]
49
+ risk_level = "High" if risk_score > 0.7 else "Medium" if risk_score > 0.4 else "Low"
50
+ results.append({
51
+ "clause": clause,
52
+ "risk_level": risk_level,
53
+ "risk_score": risk_score
54
+ })
55
+
56
+ # Calculate the overall risk score for the contract
57
+ overall_score = sum(r["risk_score"] for r in results) / len(results) if results else 0
58
+
59
+ return {
60
+ "clauses": results,
61
+ "overall_score": overall_score
62
+ }
63
+ except Exception as e:
64
+ raise HTTPException(status_code=500, detail=f"Error processing the contract: {str(e)}")
65
+
66
+
67
+ # Test root endpoint
68
+ @app.get("/")
69
+ async def read_root():
70
+ return {"message": "Welcome to the Contract Risk Heatmap Generator API!"}