import os
import json
from transformers import BertTokenizer, BertForSequenceClassification
import torch
from flask import Flask, request, jsonify
from pathlib import Path
from datetime import datetime

# Initialize Flask app
app = Flask(__name__)

# Load pre-trained model and tokenizer.
# MODEL_PATH may be overridden with the MODEL_PATH environment variable; the
# literal default is a placeholder that must be replaced before deployment.
MODEL_PATH = os.environ.get("MODEL_PATH", "path/to/your/model")
tokenizer = BertTokenizer.from_pretrained(MODEL_PATH)
model = BertForSequenceClassification.from_pretrained(MODEL_PATH)
# Inference only: make sure dropout etc. are disabled.
model.eval()

# Risk labels indexed by the model's predicted class id.
# NOTE(review): assumes the model was fine-tuned with exactly these three
# classes in this order -- confirm against the model's training config.
RISK_LABELS = ["low", "medium", "high"]

# Maximum number of tokens passed to the model. Text beyond this limit is
# silently truncated, so only the beginning of a long contract is classified.
MAX_TOKENS = 512


def classify_clause(contract_text):
    """Classify *contract_text* into one of RISK_LABELS.

    Only the first MAX_TOKENS tokens of the text are seen by the model.

    Args:
        contract_text: The text to classify.

    Returns:
        A dict with "predicted_risk" (an entry of RISK_LABELS) and
        "confidence_score" (softmax probability of the predicted class).

    Raises:
        IndexError: If the model predicts a class id that has no entry in
            RISK_LABELS (i.e. the model has more output classes than labels).
    """
    inputs = tokenizer(
        contract_text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=MAX_TOKENS,
    )
    with torch.no_grad():
        logits = model(**inputs).logits

    # A single input string yields logits of shape (1, num_labels); index the
    # batch row explicitly rather than squeeze(), which collapses to a scalar
    # when num_labels == 1.
    probabilities = torch.softmax(logits, dim=-1)[0]
    predicted_class = int(torch.argmax(logits, dim=-1)[0])

    if not 0 <= predicted_class < len(RISK_LABELS):
        raise IndexError(
            f"Model predicted class {predicted_class}, but only "
            f"{len(RISK_LABELS)} risk labels are defined"
        )

    return {
        "predicted_risk": RISK_LABELS[predicted_class],
        "confidence_score": probabilities[predicted_class].item(),
    }


@app.route("/upload_contract", methods=["POST"])
def upload_contract():
    """Accept an uploaded text contract and return a risk report as JSON.

    Expects a multipart/form-data request with the contract under the
    "file" field, encoded as UTF-8 text (a leading BOM is tolerated).

    Returns:
        200 with the risk report, or 400 with {"error": ...} when the file
        part is missing, has no filename, is not UTF-8 text, or is empty.
    """
    if 'file' not in request.files:
        return jsonify({"error": "No file part"}), 400
    file = request.files['file']
    if file.filename == '':
        return jsonify({"error": "No selected file"}), 400

    try:
        # utf-8-sig also strips a leading BOM (common in Windows-saved files).
        contract_text = file.read().decode('utf-8-sig')
    except UnicodeDecodeError:
        # Binary formats (PDF, DOCX, ...) are not supported; report a client
        # error instead of letting the exception surface as a 500.
        return jsonify({"error": "File must be UTF-8 encoded text"}), 400

    if not contract_text.strip():
        return jsonify({"error": "File is empty"}), 400

    result = classify_clause(contract_text)

    response_data = {
        "contract_title": "Sample Contract",  # Placeholder, can be parsed from the file
        "overall_risk_score": result["predicted_risk"],  # Risk classification
        "confidence_score": result["confidence_score"],  # Probability of the predicted class
        "high_risk_clauses": ["Termination Clause", "Penalty Clause"],  # Example (this should be dynamically extracted)
        "risk_map_url": "https://example.com/risk_map",  # Placeholder (use actual URL for visualization)
        "evaluation_date": datetime.now().strftime("%Y-%m-%d"),
    }
    return jsonify(response_data)


if __name__ == "__main__":
    # The Werkzeug debugger permits arbitrary code execution, so it must never
    # be on by default while listening on all interfaces. Opt in for local
    # development with FLASK_DEBUG=1.
    debug_enabled = os.environ.get("FLASK_DEBUG", "0") == "1"
    app.run(debug=debug_enabled, host="0.0.0.0", port=5000)