File size: 6,001 Bytes
20cbff3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
import os
import torch
import logging
from dotenv import load_dotenv
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Set up root logging at import time: INFO and above, each record stamped with
# time, logger name and level (basicConfig is a no-op if the root logger
# already has handlers).
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Pull variables from a .env file (python-dotenv) so the os.getenv() lookups
# in FraudDetector can see them.
load_dotenv()

class FraudDetector:
    """Score free text for fraud risk with a Hugging Face sequence classifier.

    Configuration comes from constructor arguments first, then environment
    variables: MODEL_NAME, HUGGINGFACEHUB_API_TOKEN, LOW_THRESHOLD,
    HIGH_THRESHOLD, MAX_LENGTH and FRAUD_LABEL_INDEX. The tokenizer and model
    are loaded eagerly in ``__init__``.
    """

    def __init__(self, model_name=None, hf_token=None):
        """Read configuration and load the tokenizer/model.

        Args:
            model_name: Hugging Face model id or local path. Falls back to the
                MODEL_NAME env var, then "austinb/fraud_text_detection".
            hf_token: Hub access token. Falls back to HUGGINGFACEHUB_API_TOKEN.

        Raises:
            ValueError: if the resolved model name is empty, if LOW_THRESHOLD
                exceeds HIGH_THRESHOLD, or if the fraud label index is out of
                range for the loaded model.
            Exception: any error raised while loading is logged and re-raised.
        """
        self.model_name = model_name or os.getenv("MODEL_NAME", "austinb/fraud_text_detection")
        self.hf_token = hf_token or os.getenv("HUGGINGFACEHUB_API_TOKEN")
        self.low_threshold = float(os.getenv("LOW_THRESHOLD", 0.3))
        self.high_threshold = float(os.getenv("HIGH_THRESHOLD", 0.7))
        self.max_length = int(os.getenv("MAX_LENGTH", 512))

        self.tokenizer = None
        self.model = None
        self.fraud_index = None

        if not self.model_name:
            # Only reachable when MODEL_NAME is set to an empty string.
            raise ValueError("MODEL_NAME not provided and not found in environment variables")
        if self.low_threshold > self.high_threshold:
            # With inverted thresholds risk_label() could never return "Medium Risk".
            raise ValueError(
                f"LOW_THRESHOLD ({self.low_threshold}) must not exceed "
                f"HIGH_THRESHOLD ({self.high_threshold})"
            )

        self._load_model()

    def _load_model(self):
        """Load the tokenizer and model, then resolve ``self.fraud_index``.

        Raises:
            ValueError: if the resolved fraud index is out of range.
            Exception: any loading error is logged with its traceback and re-raised.
        """
        try:
            logger.info("Loading model: %s", self.model_name)
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_name,
                token=self.hf_token
            )
            self.model = AutoModelForSequenceClassification.from_pretrained(
                self.model_name,
                token=self.hf_token
            )
            # Inference only.
            self.model.eval()

            id2label = self.model.config.id2label
            logger.info("Model labels: %s", id2label)

            self.fraud_index = self._resolve_fraud_index(id2label)

            logger.info(
                "Model loaded. Fraud index: %s (label: %s)",
                self.fraud_index,
                id2label.get(self.fraud_index, "unknown"),
            )

        except Exception as e:
            logger.exception("Failed to load model: %s", e)
            raise

    def _resolve_fraud_index(self, id2label):
        """Choose which model output is treated as the fraud probability.

        FRAUD_LABEL_INDEX from the environment takes precedence. Otherwise the
        index is detected from label names, falling back to 1 with a warning.
        Negative indices count from the end (as in tensor indexing) and are
        normalized to their non-negative equivalent.

        Raises:
            ValueError: if the index is outside the model's label range.
        """
        env_override = os.getenv("FRAUD_LABEL_INDEX")
        if env_override is not None:
            index = int(env_override)
            logger.info("Fraud label index overridden by env: %s", index)
        else:
            index = self._find_fraud_index(id2label)
            if index is None:
                # Assume index 1 is fraud for binary classifiers.
                index = 1
                logger.warning(
                    "Could not detect fraud label from %s. "
                    "Defaulting to index 1. Set FRAUD_LABEL_INDEX in .env to override.",
                    list(id2label.values()),
                )

        num_labels = len(id2label)
        if not -num_labels <= index < num_labels:
            # Fail at load time rather than with an IndexError at inference.
            raise ValueError(
                f"Fraud label index {index} is out of range for a model "
                f"with {num_labels} label(s)"
            )
        # Same column torch would select for a negative index.
        return index % num_labels

    @staticmethod
    def _find_fraud_index(id2label):
        """Return the index whose label names the fraud class, or None.

        Each label is lower-cased and stripped of non-alphanumerics. It matches
        when it contains "fraud" and does not start with "no" (which covers
        "no", "not" and "non"). So "Fraud", "is_fraud" and "fraudulent" match,
        while "not_fraud", "non-fraud" and "no fraud" do not. A plain substring
        test would wrongly pick index 0 for {0: "not_fraud", 1: "fraud"}.
        If no label matches by name, a generic "LABEL_1" label is accepted.
        """
        for idx, label in id2label.items():
            key = "".join(ch for ch in str(label).lower() if ch.isalnum())
            if "fraud" in key and not key.startswith("no"):
                return idx
        for idx, label in id2label.items():
            if label == "LABEL_1":
                return idx
        return None

    def _tokenize(self, texts):
        """Shared tokenizer call with consistent settings."""
        return self.tokenizer(
            texts,
            return_tensors="pt",
            truncation=True,
            padding=True,
            max_length=self.max_length
        )

    @staticmethod
    def _preview(text, limit=50):
        """Return ``text`` cut to ``limit`` characters (plus "...") for log lines."""
        return text[:limit] + ("..." if len(text) > limit else "")

    def get_fraud_score(self, text: str) -> float:
        """Return the fraud-class probability for a single text."""
        return self.get_fraud_scores([text])[0]

    def get_fraud_scores(self, texts: list) -> list:
        """Return fraud-class probabilities for a batch of texts, in input order.

        Probabilities are the softmax over the model's logits, taken at
        ``self.fraud_index``. An empty batch returns ``[]`` without calling the
        tokenizer or model.
        """
        if not texts:
            return []
        inputs = self._tokenize(texts)
        with torch.no_grad():
            logits = self.model(**inputs).logits
        probs = torch.softmax(logits, dim=1)
        return probs[:, self.fraud_index].tolist()

    def risk_label(self, score: float) -> str:
        """Map a probability to a risk bucket.

        score < low_threshold -> "Low Risk"; score < high_threshold ->
        "Medium Risk"; otherwise "High Risk 🚨".
        """
        if score < self.low_threshold:
            return "Low Risk"
        elif score < self.high_threshold:
            return "Medium Risk"
        else:
            return "High Risk 🚨"

    def predict(self, text: str) -> dict:
        """Score one text.

        Returns:
            dict with "text", "fraud_score" (rounded to 4 places) and "risk_level".
        """
        score = self.get_fraud_score(text)
        result = {
            "text": text,
            "fraud_score": round(score, 4),
            "risk_level": self.risk_label(score)
        }
        logger.info(
            "Prediction for '%s': %s (%s)",
            self._preview(text), result["risk_level"], result["fraud_score"],
        )
        return result

    def analyze(self, text: str) -> dict:
        """Return fraud score, risk level and binary detection in one call.

        ``is_fraud`` is True when the score is at or above high_threshold.
        """
        score = self.get_fraud_score(text)
        is_fraud = score >= self.high_threshold
        result = {
            "text": text,
            "fraud_score": round(score, 4),
            "risk_level": self.risk_label(score),
            "is_fraud": is_fraud,
            "detection": "Fraud Detected 🚨" if is_fraud else "No Fraud Detected ✅"
        }
        logger.info(
            "Analyze for '%s': %s | %s (%s)",
            self._preview(text), result["detection"], result["risk_level"], result["fraud_score"],
        )
        return result

    def predict_batch(self, texts: list) -> list:
        """Score several texts in one forward pass.

        Returns one ``predict``-shaped dict per input, in input order.
        """
        scores = self.get_fraud_scores(texts)
        results = []
        for text, score in zip(texts, scores):
            risk = self.risk_label(score)
            rounded = round(score, 4)
            logger.info("Batch prediction for '%s': %s (%s)", self._preview(text), risk, rounded)
            results.append({
                "text": text,
                "fraud_score": rounded,
                "risk_level": risk
            })
        return results


# Example usage: score one sample transaction description.
if __name__ == "__main__":
    import sys

    try:
        detector = FraudDetector()
        sample_text = "User transferred ₹50,000 to an unknown account at midnight"
        result = detector.predict(sample_text)
        print("\nPrediction Result:")
        print(result)
    except Exception as e:
        # Report on stderr and exit non-zero so shells/CI see the failure
        # instead of a successful (status 0) run.
        print(f"Error: {e}", file=sys.stderr)
        sys.exit(1)