# Fraud-API / fraud_model.py
# Uploaded by VishalBhagat01 ("Upload 6 files", commit 20cbff3, verified).
import os
import torch
import logging
from dotenv import load_dotenv
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# Configure root logging once at import time: timestamped, named, leveled records.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
# Module-level logger shared by everything in this file.
logger = logging.getLogger(__name__)
# Load configuration (MODEL_NAME, thresholds, HF token) from a local .env file.
load_dotenv()
class FraudDetector:
    """Fraud-text classifier wrapping a Hugging Face sequence-classification model.

    Configuration (model id, HF token, risk thresholds, max token length) comes
    from constructor arguments or environment variables (typically a ``.env``
    file loaded at import time). The model is loaded eagerly in ``__init__``.
    """

    def __init__(self, model_name=None, hf_token=None):
        """Resolve configuration and load tokenizer + model.

        Args:
            model_name: Hugging Face model id. Falls back to the MODEL_NAME env
                var, then to "austinb/fraud_text_detection".
            hf_token: Hugging Face access token. Falls back to the
                HUGGINGFACEHUB_API_TOKEN env var (may be None for public models).

        Raises:
            ValueError: if no model name can be resolved.
        """
        self.model_name = model_name or os.getenv("MODEL_NAME", "austinb/fraud_text_detection")
        self.hf_token = hf_token or os.getenv("HUGGINGFACEHUB_API_TOKEN")
        # Scores below low_threshold are "Low Risk"; at/above high_threshold, "High Risk".
        self.low_threshold = float(os.getenv("LOW_THRESHOLD", 0.3))
        self.high_threshold = float(os.getenv("HIGH_THRESHOLD", 0.7))
        self.max_length = int(os.getenv("MAX_LENGTH", 512))
        self.tokenizer = None
        self.model = None
        # Index of the "fraud" class in the model's output logits; set by _load_model().
        self.fraud_index = None
        # Defensive check: unreachable in practice because os.getenv above
        # supplies a default model id, but kept in case that default is removed.
        if not self.model_name:
            raise ValueError("MODEL_NAME not provided and not found in environment variables")
        self._load_model()

    def _load_model(self):
        """Download/load tokenizer and model, then detect the fraud label index.

        The fraud index is found by scanning ``config.id2label`` for a label
        containing "fraud" (or the generic "LABEL_1"), defaulting to index 1
        for binary classifiers, and finally honoring a FRAUD_LABEL_INDEX env
        override.

        Raises:
            Exception: re-raises any failure from the transformers loaders.
        """
        try:
            logger.info("Loading model: %s", self.model_name)
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_name,
                token=self.hf_token
            )
            self.model = AutoModelForSequenceClassification.from_pretrained(
                self.model_name,
                token=self.hf_token
            )
            self.model.eval()  # inference mode: disables dropout etc.
            # Detect fraud label index from model config
            id2label = self.model.config.id2label
            logger.info("Model labels: %s", id2label)
            for idx, label in id2label.items():
                if "fraud" in label.lower() or label == "LABEL_1":
                    self.fraud_index = idx
                    break
            # Fallback: assume index 1 is fraud for binary classifiers
            if self.fraud_index is None:
                self.fraud_index = 1
                logger.warning(
                    "Could not detect fraud label from %s. "
                    "Defaulting to index 1. Set FRAUD_LABEL_INDEX in .env to override.",
                    list(id2label.values())
                )
            # Allow manual override via env (takes precedence over detection).
            env_override = os.getenv("FRAUD_LABEL_INDEX")
            if env_override is not None:
                self.fraud_index = int(env_override)
                logger.info("Fraud label index overridden by env: %s", self.fraud_index)
            logger.info(
                "Model loaded. Fraud index: %s (label: %s)",
                self.fraud_index, id2label.get(self.fraud_index, 'unknown')
            )
        except Exception as e:
            logger.error("Failed to load model: %s", str(e))
            raise

    @staticmethod
    def _preview(text: str) -> str:
        """Return at most 50 characters of *text* (with "..." if truncated) for logs."""
        return text[:50] + ("..." if len(text) > 50 else "")

    def _tokenize(self, texts):
        """Shared tokenizer call with consistent settings (truncation/padding)."""
        return self.tokenizer(
            texts,
            return_tensors="pt",
            truncation=True,
            padding=True,
            max_length=self.max_length
        )

    def get_fraud_score(self, text: str) -> float:
        """Return the fraud-class probability for a single text, in [0, 1]."""
        inputs = self._tokenize(text)
        with torch.no_grad():
            outputs = self.model(**inputs)
            probs = torch.softmax(outputs.logits, dim=1)
        return probs[0][self.fraud_index].item()

    def get_fraud_scores(self, texts: list) -> list:
        """Return fraud-class probabilities for a batch of texts.

        An empty batch returns an empty list (the tokenizer rejects empty input).
        """
        if not texts:
            return []
        inputs = self._tokenize(texts)
        with torch.no_grad():
            outputs = self.model(**inputs)
            probs = torch.softmax(outputs.logits, dim=1)
        return probs[:, self.fraud_index].tolist()

    def risk_label(self, score: float) -> str:
        """Map a fraud score onto Low/Medium/High risk using the configured thresholds."""
        if score < self.low_threshold:
            return "Low Risk"
        elif score < self.high_threshold:
            return "Medium Risk"
        else:
            return "High Risk 🚨"

    def predict(self, text: str) -> dict:
        """Score one text and return {text, fraud_score, risk_level}."""
        score = self.get_fraud_score(text)
        result = {
            "text": text,
            "fraud_score": round(score, 4),
            "risk_level": self.risk_label(score)
        }
        logger.info("Prediction for '%s': %s (%s)",
                    self._preview(text), result['risk_level'], result['fraud_score'])
        return result

    def analyze(self, text: str) -> dict:
        """Returns fraud score + risk level + binary detection in one call."""
        score = self.get_fraud_score(text)
        # Binary decision uses the high threshold only.
        is_fraud = score >= self.high_threshold
        result = {
            "text": text,
            "fraud_score": round(score, 4),
            "risk_level": self.risk_label(score),
            "is_fraud": is_fraud,
            "detection": "Fraud Detected 🚨" if is_fraud else "No Fraud Detected ✅"
        }
        logger.info("Analyze for '%s': %s | %s (%s)",
                    self._preview(text), result['detection'],
                    result['risk_level'], result['fraud_score'])
        return result

    def predict_batch(self, texts: list) -> list:
        """Batch predict with consistent logging; one result dict per input text."""
        scores = self.get_fraud_scores(texts)
        results = []
        for text, score in zip(texts, scores):
            risk = self.risk_label(score)
            logger.info("Batch prediction for '%s': %s (%s)",
                        self._preview(text), risk, round(score, 4))
            results.append({
                "text": text,
                "fraud_score": round(score, 4),
                "risk_level": risk
            })
        return results
# Example Usage: score one hard-coded transaction description and print the result.
if __name__ == "__main__":
    try:
        sample_text = "User transferred ₹50,000 to an unknown account at midnight"
        detector = FraudDetector()
        result = detector.predict(sample_text)
        print("\nPrediction Result:")
        print(result)
    except Exception as e:
        # Demo-level catch-all so a missing model/token prints instead of tracebacking.
        print(f"Error: {e}")