Spaces:
Sleeping
Sleeping
File size: 6,001 Bytes
import os
import torch
import logging
from dotenv import load_dotenv
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# Logging setup: INFO and above, each record stamped with time, logger name and level.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)

# Runs before FraudDetector reads its settings through os.getenv
# (presumably this loads them from a local .env file - confirm against python-dotenv).
load_dotenv()
class FraudDetector:
    """Score text for fraud using a sequence-classification model.

    The tokenizer and model are loaded once, in the constructor. Scores are
    the softmax probability of the class identified as "fraud", and are
    bucketed into risk levels using two thresholds.

    Configuration is read from environment variables when not passed in:

    * ``MODEL_NAME`` - model identifier (default
      ``austinb/fraud_text_detection``).
    * ``HUGGINGFACEHUB_API_TOKEN`` - access token forwarded to the loaders.
    * ``LOW_THRESHOLD`` / ``HIGH_THRESHOLD`` - risk bucket boundaries
      (defaults 0.3 and 0.7).
    * ``MAX_LENGTH`` - tokenizer truncation length (default 512).
    * ``FRAUD_LABEL_INDEX`` - explicit index of the fraud class; takes
      precedence over auto-detection.
    """

    # Words that flip the meaning of a label when they directly precede
    # "fraud" (e.g. "not_fraud", "Non-Fraud", "no fraud", "nonfraud").
    _NEGATIONS = ("not", "non", "no")

    # Number of characters of the input text echoed into log messages.
    _PREVIEW_CHARS = 50

    def __init__(self, model_name=None, hf_token=None):
        """Read configuration and load the model.

        Args:
            model_name: Model identifier; falls back to ``MODEL_NAME`` and
                then to the built-in default.
            hf_token: Access token; falls back to
                ``HUGGINGFACEHUB_API_TOKEN``.

        Raises:
            ValueError: If no model name can be resolved (for example
                ``MODEL_NAME`` is set to an empty string), or if
                ``FRAUD_LABEL_INDEX`` is invalid.
            Exception: Anything raised while loading the tokenizer or model
                is logged and re-raised.
        """
        self.model_name = model_name or os.getenv("MODEL_NAME", "austinb/fraud_text_detection")
        self.hf_token = hf_token or os.getenv("HUGGINGFACEHUB_API_TOKEN")
        self.low_threshold = float(os.getenv("LOW_THRESHOLD", 0.3))
        self.high_threshold = float(os.getenv("HIGH_THRESHOLD", 0.7))
        self.max_length = int(os.getenv("MAX_LENGTH", 512))
        self.tokenizer = None
        self.model = None
        self.fraud_index = None
        if not self.model_name:
            raise ValueError("MODEL_NAME not provided and not found in environment variables")
        self._load_model()

    @classmethod
    def _is_negated(cls, label_lower):
        """Return True if the word right before "fraud" negates it."""
        prefix = label_lower.split("fraud", 1)[0]
        # Treat "-", " " and "_" as equivalent word separators.
        words = [w for w in prefix.replace("-", "_").replace(" ", "_").split("_") if w]
        return bool(words) and words[-1] in cls._NEGATIONS

    @classmethod
    def _detect_fraud_index(cls, id2label):
        """Find the class index whose label denotes fraud.

        A label matches when it is exactly ``LABEL_1`` or contains "fraud"
        (case-insensitive) without a preceding negation. Without the
        negation check, a mapping such as ``{0: "not_fraud", 1: "fraud"}``
        would resolve to index 0 and invert every score.

        Args:
            id2label: Mapping of class index to label name.

        Returns:
            The first matching index, or ``None`` if no label matches.
        """
        for idx, label in id2label.items():
            name = str(label)
            lowered = name.lower()
            if name == "LABEL_1":
                return idx
            if "fraud" in lowered and not cls._is_negated(lowered):
                return idx
        return None

    def _resolve_fraud_index(self, id2label):
        """Pick the fraud class index: env override, detection, then default.

        Raises:
            ValueError: If ``FRAUD_LABEL_INDEX`` is not an integer or does
                not address one of the model's labels.
        """
        env_override = os.getenv("FRAUD_LABEL_INDEX")
        if env_override is not None:
            try:
                index = int(env_override)
            except ValueError as err:
                raise ValueError(
                    f"FRAUD_LABEL_INDEX must be an integer, got {env_override!r}"
                ) from err
            num_labels = len(id2label)
            # Fail here rather than with an IndexError on the first prediction.
            # Negative values are allowed because they index from the end.
            if num_labels and not -num_labels <= index < num_labels:
                raise ValueError(
                    f"FRAUD_LABEL_INDEX {index} is out of range for {num_labels} labels"
                )
            logger.info(f"Fraud label index overridden by env: {index}")
            return index

        index = self._detect_fraud_index(id2label)
        if index is None:
            # Fallback: assume index 1 is fraud for binary classifiers
            index = 1
            logger.warning(
                f"Could not detect fraud label from {list(id2label.values())}. "
                f"Defaulting to index 1. Set FRAUD_LABEL_INDEX in .env to override."
            )
        return index

    def _load_model(self):
        """Load tokenizer and model, then resolve ``self.fraud_index``.

        Raises:
            Exception: Any failure is logged and re-raised unchanged.
        """
        try:
            logger.info(f"Loading model: {self.model_name}")
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_name,
                token=self.hf_token
            )
            self.model = AutoModelForSequenceClassification.from_pretrained(
                self.model_name,
                token=self.hf_token
            )
            self.model.eval()

            id2label = self.model.config.id2label
            logger.info(f"Model labels: {id2label}")
            self.fraud_index = self._resolve_fraud_index(id2label)
            logger.info(
                f"Model loaded. Fraud index: {self.fraud_index} "
                f"(label: {id2label.get(self.fraud_index, 'unknown')})"
            )
        except Exception as e:
            logger.error(f"Failed to load model: {str(e)}")
            raise

    @staticmethod
    def _as_text_list(texts):
        """Normalise batch input to a list.

        A bare string is treated as a batch of one rather than being split
        into characters; any other iterable (including a generator) is
        materialised so it can be traversed more than once.
        """
        if isinstance(texts, str):
            return [texts]
        return list(texts)

    @classmethod
    def _preview(cls, text):
        """Return the first characters of ``text`` for logging, with "..." if cut."""
        limit = cls._PREVIEW_CHARS
        return text[:limit] + ("..." if len(text) > limit else "")

    def _tokenize(self, texts):
        """Shared tokenizer call with consistent settings."""
        return self.tokenizer(
            texts,
            return_tensors="pt",
            truncation=True,
            padding=True,
            max_length=self.max_length
        )

    def get_fraud_score(self, text: str) -> float:
        """Return the fraud-class probability for a single text."""
        inputs = self._tokenize(text)
        with torch.no_grad():
            outputs = self.model(**inputs)
        probs = torch.softmax(outputs.logits, dim=1)
        return probs[0][self.fraud_index].item()

    def get_fraud_scores(self, texts: list) -> list:
        """Return fraud-class probabilities for a batch of texts.

        Args:
            texts: Iterable of strings. An empty iterable yields ``[]``
                without calling the tokenizer or the model.

        Returns:
            One float per input text, in input order.
        """
        texts = self._as_text_list(texts)
        if not texts:
            return []
        inputs = self._tokenize(texts)
        with torch.no_grad():
            outputs = self.model(**inputs)
        probs = torch.softmax(outputs.logits, dim=1)
        return probs[:, self.fraud_index].tolist()

    def risk_label(self, score: float) -> str:
        """Bucket a score: below low -> Low, below high -> Medium, else High."""
        if score < self.low_threshold:
            return "Low Risk"
        if score < self.high_threshold:
            return "Medium Risk"
        return "High Risk 🚨"

    def predict(self, text: str) -> dict:
        """Score one text and return ``text``, ``fraud_score`` and ``risk_level``."""
        score = self.get_fraud_score(text)
        result = {
            "text": text,
            "fraud_score": round(score, 4),
            "risk_level": self.risk_label(score)
        }
        logger.info(
            f"Prediction for '{self._preview(text)}': "
            f"{result['risk_level']} ({result['fraud_score']})"
        )
        return result

    def analyze(self, text: str) -> dict:
        """Returns fraud score + risk level + binary detection in one call.

        ``is_fraud`` is True when the score reaches ``high_threshold``,
        which is the same boundary ``risk_label`` uses for "High Risk".
        """
        score = self.get_fraud_score(text)
        is_fraud = score >= self.high_threshold
        result = {
            "text": text,
            "fraud_score": round(score, 4),
            "risk_level": self.risk_label(score),
            "is_fraud": is_fraud,
            "detection": "Fraud Detected 🚨" if is_fraud else "No Fraud Detected ✅"
        }
        logger.info(
            f"Analyze for '{self._preview(text)}': {result['detection']} | "
            f"{result['risk_level']} ({result['fraud_score']})"
        )
        return result

    def predict_batch(self, texts: list) -> list:
        """Batch predict with consistent logging.

        Args:
            texts: Iterable of strings; an empty iterable yields ``[]``.

        Returns:
            One dict per text with ``text``, ``fraud_score`` and
            ``risk_level``, in input order.
        """
        # Materialise once: the texts are traversed for scoring and again below.
        texts = self._as_text_list(texts)
        scores = self.get_fraud_scores(texts)
        results = []
        for text, score in zip(texts, scores):
            risk = self.risk_label(score)
            rounded = round(score, 4)
            logger.info(f"Batch prediction for '{self._preview(text)}': {risk} ({rounded})")
            results.append({
                "text": text,
                "fraud_score": rounded,
                "risk_level": risk
            })
        return results
# Example Usage
if __name__ == "__main__":
    # Demo run: build a detector from the environment and score one sample.
    # Any failure is reported on stdout instead of propagating.
    try:
        demo_detector = FraudDetector()
        demo_text = "User transferred ₹50,000 to an unknown account at midnight"
        outcome = demo_detector.predict(demo_text)
        print("\nPrediction Result:", outcome, sep="\n")
    except Exception as exc:
        print(f"Error: {exc}")