Spaces:
Runtime error
Runtime error
import base64
import io
import os
import re
from typing import List, Dict

import pdfplumber
import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import DistilBertForSequenceClassification, DistilBertTokenizerFast, pipeline
# Initialize FastAPI app
app = FastAPI()

# Classification model setup.
# The checkpoint can be overridden with the CONTRACT_RISK_MODEL environment
# variable, so a fine-tuned model can be deployed without editing this file.
# The default, "distilbert-base-uncased", is a base checkpoint without a trained
# sequence-classification head. transformers initializes that head randomly at
# load time, so its scores are not meaningful (and can differ between process
# restarts) until a fine-tuned checkpoint is supplied.
model_name = os.environ.get("CONTRACT_RISK_MODEL", "distilbert-base-uncased")
tokenizer = DistilBertTokenizerFast.from_pretrained(model_name)
model = DistilBertForSequenceClassification.from_pretrained(model_name)

# Hugging Face pipeline wrapping tokenization + inference for text classification
classifier = pipeline("text-classification", model=model, tokenizer=tokenizer)
# Function to extract text from PDF file
def extract_text_from_pdf(pdf_data: bytes) -> str:
    """Extract the text layer of every page of an in-memory PDF.

    Args:
        pdf_data: Raw bytes of a PDF document.

    Returns:
        The text of all pages in page order, separated by newlines. A page for
        which ``extract_text()`` returns ``None`` (no text layer) contributes
        an empty string.
    """
    # pdfplumber.open() takes a file path or a seekable file-like object; raw
    # bytes are neither, so wrap them in an in-memory binary stream.
    with pdfplumber.open(io.BytesIO(pdf_data)) as pdf:
        # Separate pages with "\n" so the last word of one page is not glued
        # to the first word of the next.
        return "\n".join(page.extract_text() or "" for page in pdf.pages)
# Request body schema for the contract-analysis handler
class ContractFile(BaseModel):
    """JSON request body carrying the contract to analyse."""

    # The PDF document, encoded as a base64 string.
    file: str
async def analyze_contract(data: ContractFile):
    """Score each clause of a base64-encoded PDF contract for risk.

    NOTE(review): no route decorator (e.g. ``@app.post(...)``) is visible on
    this handler, so as shown it is never registered with ``app``; confirm
    the decorator was not lost. The PDF parsing and model inference below
    are synchronous calls made directly inside this coroutine.

    Args:
        data: Request body whose ``file`` field holds the base64-encoded PDF.

    Returns:
        A dict with:
        - ``"clauses"``: one entry per non-empty clause, holding the clause
          text, its ``"risk_level"`` ("High" / "Medium" / "Low") and its
          ``"risk_score"``.
        - ``"overall_score"``: the mean clause score, or 0 when no clauses
          were found.

    Raises:
        HTTPException: 400 if ``file`` is not valid base64; 500 for any
            failure while extracting or classifying the contract.
    """
    # Decode outside the catch-all handler below so malformed client input
    # is reported as a client error (400) rather than a server error (500).
    try:
        pdf_data = base64.b64decode(data.file)
    except ValueError as e:  # binascii.Error is a subclass of ValueError
        raise HTTPException(
            status_code=400, detail=f"Invalid base64-encoded file: {e}"
        ) from e

    try:
        contract_text = extract_text_from_pdf(pdf_data)

        # Naive clause split: a period followed by any whitespace. Matching
        # \s+ (not only a space) also splits where a period is followed by a
        # newline, as at the end of a line of extracted PDF text.
        clauses = [c.strip() for c in re.split(r"\.\s+", contract_text)]

        results = []
        for clause in clauses:
            if not clause:
                continue
            # truncation=True stops clauses longer than the model's maximum
            # input length from raising inside the pipeline.
            prediction = classifier(clause, truncation=True)[0]
            # NOTE(review): an untuned "distilbert-base-uncased" head likely
            # emits "LABEL_0"/"LABEL_1", never "POSITIVE", in which case this
            # always takes 1 - score; align with the fine-tuned model's labels.
            if prediction["label"] == "POSITIVE":
                risk_score = prediction["score"]
            else:
                risk_score = 1 - prediction["score"]

            if risk_score > 0.7:
                risk_level = "High"
            elif risk_score > 0.4:
                risk_level = "Medium"
            else:
                risk_level = "Low"

            results.append({
                "clause": clause,
                "risk_level": risk_level,
                "risk_score": risk_score,
            })

        # Overall contract risk is the mean of the per-clause scores.
        overall_score = (
            sum(r["risk_score"] for r in results) / len(results) if results else 0
        )
        return {
            "clauses": results,
            "overall_score": overall_score,
        }
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"Error processing the contract: {str(e)}"
        ) from e
# Root endpoint: a static greeting, handy as a liveness check
async def read_root():
    """Return a fixed welcome payload for the API root."""
    greeting = "Welcome to the Contract Risk Heatmap Generator API!"
    return dict(message=greeting)