# NOTE: the following header ("Spaces: Sleeping") is residue from the Hugging Face
# Spaces web page this file was scraped from; it is not part of the program.
| from fastapi import FastAPI, HTTPException | |
| from pydantic import BaseModel | |
| from typing import Optional, Dict, Any | |
| import pandas as pd | |
| from sklearn.model_selection import train_test_split | |
| from sklearn.feature_extraction.text import TfidfVectorizer | |
| from sklearn.linear_model import LogisticRegression | |
| from sklearn.multioutput import MultiOutputClassifier | |
| from sklearn.metrics import accuracy_score | |
| import joblib | |
| import os | |
| import config | |
| app = FastAPI() | |
| # Load model and vectorizer if exist | |
def load_model_vectorizer():
    """Load the persisted classifier and TF-IDF vectorizer from disk.

    Returns:
        tuple: ``(model, vectorizer)`` when both artifact files exist at
        ``config.MODEL_PATH`` / ``config.TFIDF_PATH``; ``(None, None)`` when
        either is missing (i.e. the model has not been trained yet).
    """
    # Guard clause: bail out early unless both artifacts are present.
    if not (os.path.exists(config.MODEL_PATH) and os.path.exists(config.TFIDF_PATH)):
        return None, None
    return joblib.load(config.MODEL_PATH), joblib.load(config.TFIDF_PATH)
| # Pydantic model for prediction input (add all your fields as needed) | |
# Pydantic model for prediction input (add all your fields as needed)
class TransactionData(BaseModel):
    """Single sanctions-screening transaction record submitted for prediction.

    Field values are concatenated into one free-text string by ``predict`` and
    fed to the TF-IDF vectorizer; all fields are required unless defaulted.
    """
    Transaction_Id: str
    Origin: str
    Designation: str
    Keywords: str
    Name: str
    SWIFT_Tag: str
    Currency: str
    Entity: str
    Message: str
    City: str
    Country: str
    State: str
    Hit_Type: str
    Record_Matching_String: str
    WatchList_Match_String: str
    # Optional with empty-string default (note: "Reciever" spelling is part of
    # the external API contract — do not "fix" without coordinating callers).
    Payment_Sender_Name: Optional[str] = ""
    Payment_Reciever_Name: Optional[str] = ""
    Swift_Message_Type: str
    Text_Sanction_Data: str
    Matched_Sanctioned_Entity: str
    Red_Flag_Reason: str
    Risk_Level: str
    # Numeric risk score; range/scale not visible from this file — TODO confirm.
    Risk_Score: float
    CDD_Level: str
    PEP_Status: str
    Sanction_Description: str
    Checker_Notes: str
    Sanction_Context: str
# NOTE(review): no route decorator is attached here — as written this is a
# plain coroutine, not a registered endpoint. A decorator such as
# @app.get("/") was presumably lost in extraction; confirm before deploying.
async def root():
    """Return a static welcome message for the API landing route."""
    return {"message": "Welcome to the Logistic Regression API"}
| # --- TRAIN --- | |
def train():
    """Train the multi-label TF-IDF + LogisticRegression model and persist it.

    Reads the CSV at ``config.DATA_PATH``, splits 80/20, fits a
    ``MultiOutputClassifier`` over ``config.LABEL_COLUMNS``, saves both the
    model and the vectorizer, and reports per-label hold-out accuracy.

    Returns:
        dict: ``{"message", "accuracy"}`` on success, ``{"error": str}`` on
        any failure (errors are reported, not raised).
    """
    try:
        os.makedirs(config.MODEL_SAVE_DIR, exist_ok=True)
        frame = pd.read_csv(config.DATA_PATH)
        texts = frame[config.TEXT_COLUMN]
        labels = frame[config.LABEL_COLUMNS]
        texts_tr, texts_te, labels_tr, labels_te = train_test_split(
            texts, labels, test_size=0.2, random_state=42
        )
        tfidf = TfidfVectorizer()
        features_tr = tfidf.fit_transform(texts_tr)
        features_te = tfidf.transform(texts_te)
        clf = MultiOutputClassifier(LogisticRegression(max_iter=1000))
        clf.fit(features_tr, labels_tr)
        predictions = clf.predict(features_te)
        # Per-label accuracy: column i of the prediction matrix vs. its label.
        accuracy = {}
        for idx, col in enumerate(labels.columns):
            accuracy[col] = accuracy_score(labels_te[col], [row[idx] for row in predictions])
        joblib.dump(clf, config.MODEL_PATH)
        joblib.dump(tfidf, config.TFIDF_PATH)
        return {"message": "Training completed", "accuracy": accuracy}
    except Exception as e:
        # Best-effort API surface: report the failure instead of raising.
        return {"error": str(e)}
| # --- PREDICT --- | |
class PredictionRequest(BaseModel):
    """Request envelope for ``predict``: wraps a single transaction record."""
    # The transaction whose label columns will be predicted.
    transaction_data: TransactionData
def predict(request: PredictionRequest):
    """Predict the configured label columns for one transaction.

    Serializes the transaction's fields into a single free-text block (one
    ``Label: value`` line per field), vectorizes it with the persisted TF-IDF
    vectorizer, and runs the persisted multi-output classifier.

    Args:
        request: Envelope carrying the ``TransactionData`` record.

    Returns:
        dict: ``{label_column: predicted_value}`` for every entry of
        ``config.LABEL_COLUMNS``, or ``{"error": str}`` on any failure
        (e.g. model artifacts missing).
    """
    try:
        # Fix: the previous implementation built a one-row DataFrame and read
        # scalars back via `DataFrame.get(...)[0]`, relying on positional
        # `Series[0]` indexing (deprecated in pandas 2.x). A plain dict with
        # `.get(key, '')` gives the same missing-field fallback directly.
        record = request.transaction_data.dict()
        # (display label, record key) pairs, in the original template order.
        # Keys absent from TransactionData (e.g. Maker_Action) resolve to ''
        # exactly as the old `.get(..., [''])[0]` default did.
        field_map = [
            ("Transaction ID", "Transaction_Id"),
            ("Origin", "Origin"),
            ("Designation", "Designation"),
            ("Keywords", "Keywords"),
            ("Name", "Name"),
            ("SWIFT Tag", "SWIFT_Tag"),
            ("Currency", "Currency"),
            ("Entity", "Entity"),
            ("Message", "Message"),
            ("City", "City"),
            ("Country", "Country"),
            ("State", "State"),
            ("Hit Type", "Hit_Type"),
            ("Record Matching String", "Record_Matching_String"),
            ("WatchList Match String", "WatchList_Match_String"),
            ("Payment Sender", "Payment_Sender_Name"),
            ("Payment Receiver", "Payment_Reciever_Name"),
            ("Swift Message Type", "Swift_Message_Type"),
            ("Text Sanction Data", "Text_Sanction_Data"),
            ("Matched Sanctioned Entity", "Matched_Sanctioned_Entity"),
            ("Red Flag Reason", "Red_Flag_Reason"),
            ("Risk Level", "Risk_Level"),
            ("Risk Score", "Risk_Score"),
            ("CDD Level", "CDD_Level"),
            ("PEP Status", "PEP_Status"),
            ("Sanction Description", "Sanction_Description"),
            ("Checker Notes", "Checker_Notes"),
            ("Sanction Context", "Sanction_Context"),
            ("Maker Action", "Maker_Action"),
            ("Customer Type", "Customer_Type"),
            ("Industry", "Industry"),
            ("Transaction Type", "Transaction_Type"),
            ("Transaction Channel", "Transaction_Channel"),
            ("Geographic Origin", "Geographic_Origin"),
            ("Geographic Destination", "Geographic_Destination"),
            ("Risk Category", "Risk_Category"),
            ("Risk Drivers", "Risk_Drivers"),
            ("Alert Status", "Alert_Status"),
            ("Investigation Outcome", "Investigation_Outcome"),
            ("Source of Funds", "Source_Of_Funds"),
            ("Purpose of Transaction", "Purpose_Of_Transaction"),
            ("Beneficial Owner", "Beneficial_Owner"),
        ]
        text_input = "\n".join(
            f"{label}: {record.get(key, '')}" for label, key in field_map
        )
        # Load model and vectorizer (raises if not trained; caught below).
        model = joblib.load(config.MODEL_PATH)
        vectorizer = joblib.load(config.TFIDF_PATH)
        # Vectorize the single text and take the first (only) prediction row.
        X_vec = vectorizer.transform([text_input])
        preds = model.predict(X_vec)[0]
        # Map predictions back to their label-column names.
        return {label: pred for label, pred in zip(config.LABEL_COLUMNS, preds)}
    except Exception as e:
        # Best-effort API surface: report the failure instead of raising.
        return {"error": str(e)}
| # --- TEST --- | |
def test():
    """Evaluate the persisted model against the full CSV at ``config.DATA_PATH``.

    Raises:
        HTTPException: 404 if the data file is absent; 400 if the model or
        vectorizer has not been trained/saved yet.

    Returns:
        dict: ``{"accuracy": {label_column: accuracy}}`` over the whole file.
    """
    if not os.path.exists(config.DATA_PATH):
        raise HTTPException(status_code=404, detail="Test data file not found")
    frame = pd.read_csv(config.DATA_PATH)
    model, vectorizer = load_model_vectorizer()
    if model is None or vectorizer is None:
        raise HTTPException(status_code=400, detail="Model is not trained yet. Please train first.")
    features = vectorizer.transform(frame[config.TEXT_COLUMN])
    labels = frame[config.LABEL_COLUMNS]
    predictions = model.predict(features)
    # Per-label accuracy: compare column idx of the prediction matrix.
    accuracy = {}
    for idx, col in enumerate(labels.columns):
        accuracy[col] = accuracy_score(labels[col], [row[idx] for row in predictions])
    return {"accuracy": accuracy}
| # --- VALIDATE --- | |
def validate(data: dict):
    """Run the persisted model over a batch of free-form input records.

    Args:
        data: dict with key ``"inputs"`` holding a list of records; each
            record's key/value pairs are joined into a ``key: value`` text
            block before vectorization.

    Raises:
        HTTPException: 400 when no inputs are supplied or the model is not
            trained yet.

    Returns:
        dict: ``{"results": [{label_column: prediction}, ...]}`` in input order.
    """
    # Accept dict with list of inputs for batch validation
    batch = data.get("inputs")
    if not batch:
        raise HTTPException(status_code=400, detail="No inputs provided")
    model, vectorizer = load_model_vectorizer()
    if model is None or vectorizer is None:
        raise HTTPException(status_code=400, detail="Model is not trained yet. Please train first.")
    results = []
    for record in batch:
        # Serialize the record into one "key: value" line per field.
        text_input = "".join(f"{key}: {value}\n" for key, value in record.items())
        preds = model.predict(vectorizer.transform([text_input]))[0]
        results.append(dict(zip(config.LABEL_COLUMNS, preds)))
    return {"results": results}