Komal133's picture
Create app.py
4a567d3 verified
raw
history blame
2.47 kB
import os
import json
from transformers import BertTokenizer, BertForSequenceClassification
import torch
from flask import Flask, request, jsonify
from pathlib import Path
from datetime import datetime
# Initialize the Flask application that serves the contract-risk endpoint.
app = Flask(__name__)

# Location of the BERT sequence-classification model: whatever
# `from_pretrained` accepts (a local directory or a Hugging Face model id).
# Read from the MODEL_PATH environment variable so a deployment can point at
# a real model without editing this file; the placeholder below remains the
# fallback and must be replaced (or overridden) before the app can start.
MODEL_PATH = os.environ.get("MODEL_PATH", "path/to/your/model")

# Load the tokenizer and model once at import time so that every request
# reuses the same instances instead of reloading them.
tokenizer = BertTokenizer.from_pretrained(MODEL_PATH)
model = BertForSequenceClassification.from_pretrained(MODEL_PATH)
def classify_clause(contract_text):
    """Classify contract text into a risk level using the module-level model.

    The text is tokenized (truncated to at most 512 tokens) and passed through
    ``model`` with gradient tracking disabled. The index of the largest logit
    is mapped onto the labels ``low`` / ``medium`` / ``high``.

    Args:
        contract_text: The contract text to classify.

    Returns:
        A dict with two keys: ``predicted_risk`` (the selected label) and
        ``confidence_score`` (the softmax probability of that label).
    """
    encoded = tokenizer(
        contract_text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=512,
    )

    # Inference only: no gradients are needed.
    with torch.no_grad():
        logits = model(**encoded).logits

    # `.item()` requires a single prediction, i.e. one input sequence.
    class_index = logits.argmax(dim=-1).item()

    # NOTE(review): assumes the model's class indices 0/1/2 correspond to
    # low/medium/high in this order -- confirm against the training setup.
    labels = ["low", "medium", "high"]
    label = labels[class_index]

    # Softmax over the class dimension turns logits into probabilities;
    # the confidence is the probability of the predicted class.
    probabilities = torch.softmax(logits, dim=-1).squeeze().tolist()

    return {
        "predicted_risk": label,
        "confidence_score": probabilities[class_index],
    }
# Define route to handle file uploads
@app.route("/upload_contract", methods=["POST"])
def upload_contract():
    """Accept an uploaded contract file and return its risk assessment as JSON.

    Expects a multipart POST with the contract under the ``file`` form field.
    The file content must be UTF-8 encoded text.

    Returns:
        On success, a JSON object with ``contract_title``,
        ``overall_risk_score``, ``high_risk_clauses``, ``risk_map_url`` and
        ``evaluation_date``. Only ``overall_risk_score`` and
        ``evaluation_date`` are computed; the other fields are placeholders.
        On a bad request, a JSON object with an ``error`` message and HTTP
        status 400 (missing file part, empty filename, content that is not
        valid UTF-8, or an empty file).
    """
    if 'file' not in request.files:
        return jsonify({"error": "No file part"}), 400

    uploaded = request.files['file']
    if uploaded.filename == '':
        return jsonify({"error": "No selected file"}), 400

    # The upload is untrusted input: a binary or differently-encoded file is
    # a client error, so report it as 400 rather than letting the decode
    # failure surface as an unhandled 500.
    try:
        contract_text = uploaded.read().decode('utf-8')
    except UnicodeDecodeError:
        return jsonify({"error": "File must be UTF-8 encoded text"}), 400

    # Nothing to classify in an empty or whitespace-only document.
    if not contract_text.strip():
        return jsonify({"error": "Uploaded file is empty"}), 400

    # Classify the contract text
    result = classify_clause(contract_text)

    # Prepare JSON response
    response_data = {
        "contract_title": "Sample Contract",  # Placeholder, can be parsed from the file
        "overall_risk_score": result["predicted_risk"],  # Risk classification
        "high_risk_clauses": ["Termination Clause", "Penalty Clause"],  # Example (this should be dynamically extracted)
        "risk_map_url": "https://example.com/risk_map",  # Placeholder (use actual URL for visualization)
        "evaluation_date": datetime.now().strftime("%Y-%m-%d"),
    }

    # Return response as JSON
    return jsonify(response_data)
if __name__ == "__main__":
    # The Werkzeug interactive debugger lets anyone who can reach the server
    # execute arbitrary Python. Because the server binds to all interfaces
    # (0.0.0.0), debug mode must be off by default; opt in explicitly with
    # FLASK_DEBUG=1 for local development only.
    debug_enabled = os.environ.get("FLASK_DEBUG", "").strip().lower() in {
        "1",
        "true",
        "yes",
        "on",
    }
    app.run(debug=debug_enabled, host="0.0.0.0", port=5000)