Spaces:
Runtime error
Runtime error
File size: 2,623 Bytes
4994596 2bcddaa 2d98bd2 4994596 2bcddaa 2d98bd2 2bcddaa 2d98bd2 4994596 2bcddaa 4994596 2bcddaa 4994596 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 |
import base64
import io
from typing import List, Dict

import pdfplumber
import torch
from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel
from transformers import DistilBertForSequenceClassification, DistilBertTokenizerFast, pipeline
# The FastAPI application; the route decorators further down register on it.
app = FastAPI()

# Hugging Face checkpoint backing the clause classifier.
# NOTE: this is the base DistilBERT checkpoint used as-is; its
# sequence-classification head has not been fine-tuned for contract risk,
# so the scores it produces should not be trusted until it is.
model_name = "distilbert-base-uncased"
tokenizer = DistilBertTokenizerFast.from_pretrained(model_name)
model = DistilBertForSequenceClassification.from_pretrained(model_name)

# Bundle the model and tokenizer into a single text-classification callable;
# each call yields a list of {"label": ..., "score": ...} predictions.
classifier = pipeline(task="text-classification", tokenizer=tokenizer, model=model)
# Function to extract text from PDF file
def extract_text_from_pdf(pdf_data: bytes) -> str:
with pdfplumber.open(pdf_data) as pdf:
text = ""
for page in pdf.pages:
text += page.extract_text() or ""
return text
# Define request body structure
class ContractFile(BaseModel):
    """JSON request body accepted by ``POST /analyze_contract``."""

    # The contract PDF, transmitted as a base64-encoded string.
    file: str
# Classifier labels treated as the "risky" class. "POSITIVE" covers
# sentiment-style heads; a DistilBERT head that has not been fine-tuned
# presumably reports generic "LABEL_0"/"LABEL_1" names instead (transformers'
# default id2label) -- confirm against model.config.id2label.
_RISKY_LABELS = frozenset({"POSITIVE", "LABEL_1"})


def _risk_level(risk_score: float) -> str:
    """Bucket a risk score: "High" above 0.7, "Medium" above 0.4, else "Low"."""
    if risk_score > 0.7:
        return "High"
    if risk_score > 0.4:
        return "Medium"
    return "Low"


def _analyze_text(contract_text: str) -> dict:
    """Split ``contract_text`` into clauses and score each with ``classifier``.

    Returns:
        ``{"clauses": [...], "overall_score": float}``. Each clause entry has
        ``clause``, ``risk_level`` and ``risk_score``. ``overall_score`` is
        the mean clause score, or 0 when no clauses were found.
    """
    # Naive clause split on ". "; blank fragments are dropped.
    clauses = [c.strip() for c in contract_text.split(". ") if c.strip()]
    # truncation=True keeps clauses longer than the model's maximum input
    # length from crashing inference.
    predictions = classifier(clauses, truncation=True) if clauses else []

    results = []
    for clause, prediction in zip(clauses, predictions):
        score = prediction["score"]
        risk_score = score if prediction["label"] in _RISKY_LABELS else 1 - score
        results.append({
            "clause": clause,
            "risk_level": _risk_level(risk_score),
            "risk_score": risk_score,
        })

    overall_score = sum(r["risk_score"] for r in results) / len(results) if results else 0
    return {"clauses": results, "overall_score": overall_score}


def _run_analysis(pdf_data: bytes) -> dict:
    """Extract text from ``pdf_data`` and score it (blocking, CPU-bound)."""
    return _analyze_text(extract_text_from_pdf(pdf_data))


@app.post("/analyze_contract")
async def analyze_contract(data: ContractFile):
    """Score each clause of a base64-encoded PDF contract for risk.

    Raises:
        HTTPException(400): the payload is not valid base64 or decodes to
            nothing.
        HTTPException(500): text extraction or classification failed.
    """
    try:
        # Whitespace (e.g. line-wrapped base64) is stripped explicitly, and
        # validate=True rejects any other non-alphabet characters instead of
        # silently dropping them and producing a corrupt PDF.
        # binascii.Error is a subclass of ValueError.
        pdf_data = base64.b64decode("".join(data.file.split()), validate=True)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=f"Invalid base64 payload: {e}") from e
    if not pdf_data:
        raise HTTPException(status_code=400, detail="Empty file payload")

    try:
        # PDF parsing and model inference block; run them in the thread pool
        # so the event loop stays free to serve other requests.
        return await run_in_threadpool(_run_analysis, pdf_data)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing the contract: {str(e)}") from e
# Test root endpoint
@app.get("/")
async def read_root():
    """Landing endpoint: reply with a fixed welcome message."""
    greeting = "Welcome to the Contract Risk Heatmap Generator API!"
    return {"message": greeting}
|