# Fraud text detection service (HuggingFace Spaces app)
| import os | |
| import torch | |
| import logging | |
| from dotenv import load_dotenv | |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
# Logging: timestamped, INFO-level records for this module.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
)
logger = logging.getLogger(__name__)

# Pull configuration (model name, token, thresholds) from a local .env file.
load_dotenv()
class FraudDetector:
    """Fraud-text classifier wrapping a Hugging Face sequence-classification model.

    Produces a fraud probability in [0, 1] per text and maps it to a risk
    band ("Low Risk" / "Medium Risk" / "High Risk 🚨") via two thresholds.
    Configuration comes from constructor arguments first, then environment
    variables: MODEL_NAME, HUGGINGFACEHUB_API_TOKEN, LOW_THRESHOLD,
    HIGH_THRESHOLD, MAX_LENGTH, FRAUD_LABEL_INDEX.
    """

    def __init__(self, model_name=None, hf_token=None):
        """Resolve configuration and eagerly load tokenizer + model.

        Args:
            model_name: Hugging Face model id. Falls back to the MODEL_NAME
                env var, then to "austinb/fraud_text_detection".
            hf_token: Hugging Face API token. Falls back to
                HUGGINGFACEHUB_API_TOKEN.

        Raises:
            ValueError: if the resolved model name is empty.
        """
        self.model_name = model_name or os.getenv("MODEL_NAME", "austinb/fraud_text_detection")
        # Fail fast with a clear message before any other work
        # (the original validated only after reading the thresholds).
        if not self.model_name:
            raise ValueError("MODEL_NAME not provided and not found in environment variables")
        self.hf_token = hf_token or os.getenv("HUGGINGFACEHUB_API_TOKEN")
        # Risk bands: score < low -> Low, score < high -> Medium, else High.
        self.low_threshold = float(os.getenv("LOW_THRESHOLD", "0.3"))
        self.high_threshold = float(os.getenv("HIGH_THRESHOLD", "0.7"))
        self.max_length = int(os.getenv("MAX_LENGTH", "512"))
        self.tokenizer = None
        self.model = None
        self.fraud_index = None  # softmax column holding P(fraud); set in _load_model
        self._load_model()

    def _load_model(self):
        """Load tokenizer and model, switch to eval mode, resolve the fraud index.

        Logs and re-raises any loading failure.
        """
        try:
            logger.info("Loading model: %s", self.model_name)
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_name,
                token=self.hf_token,
            )
            self.model = AutoModelForSequenceClassification.from_pretrained(
                self.model_name,
                token=self.hf_token,
            )
            self.model.eval()  # inference mode (disables dropout etc.)
            id2label = self.model.config.id2label
            logger.info("Model labels: %s", id2label)
            self.fraud_index = self._resolve_fraud_index(id2label)
            logger.info(
                "Model loaded. Fraud index: %s (label: %s)",
                self.fraud_index,
                id2label.get(self.fraud_index, "unknown"),
            )
        except Exception as e:
            logger.error("Failed to load model: %s", e)
            raise

    def _resolve_fraud_index(self, id2label):
        """Decide which logit column represents 'fraud'.

        Precedence: FRAUD_LABEL_INDEX env override > a label containing
        "fraud" (or the generic "LABEL_1") > fallback to index 1.
        """
        env_override = os.getenv("FRAUD_LABEL_INDEX")
        if env_override is not None:
            # Explicit override wins — and no spurious "could not detect"
            # warning is emitted. Fail loudly on a malformed value
            # (the original let a bare ValueError escape from int()).
            try:
                index = int(env_override)
            except ValueError as exc:
                raise ValueError(
                    f"FRAUD_LABEL_INDEX must be an integer, got {env_override!r}"
                ) from exc
            logger.info("Fraud label index overridden by env: %s", index)
            return index
        for idx, label in id2label.items():
            if "fraud" in label.lower() or label == "LABEL_1":
                return idx
        # Fallback: assume index 1 is fraud for binary classifiers.
        logger.warning(
            "Could not detect fraud label from %s. "
            "Defaulting to index 1. Set FRAUD_LABEL_INDEX in .env to override.",
            list(id2label.values()),
        )
        return 1

    def _tokenize(self, texts):
        """Shared tokenizer call with consistent settings."""
        return self.tokenizer(
            texts,
            return_tensors="pt",
            truncation=True,
            padding=True,
            max_length=self.max_length,
        )

    @staticmethod
    def _preview(text: str) -> str:
        """First 50 characters of *text* (ellipsized) for compact log lines."""
        return text[:50] + ("..." if len(text) > 50 else "")

    def get_fraud_score(self, text: str) -> float:
        """Return P(fraud) for a single text (delegates to the batch path)."""
        return self.get_fraud_scores([text])[0]

    def get_fraud_scores(self, texts: list) -> list:
        """Return P(fraud) for each text.

        An empty batch returns [] (the original crashed in the tokenizer).
        """
        if not texts:
            return []
        inputs = self._tokenize(texts)
        with torch.no_grad():
            logits = self.model(**inputs).logits
        probs = torch.softmax(logits, dim=-1)  # (batch, num_labels)
        return probs[:, self.fraud_index].tolist()

    def risk_label(self, score: float) -> str:
        """Map a fraud probability to its human-readable risk band."""
        if score < self.low_threshold:
            return "Low Risk"
        elif score < self.high_threshold:
            return "Medium Risk"
        else:
            return "High Risk 🚨"

    def predict(self, text: str) -> dict:
        """Score one text; returns text, rounded fraud_score and risk_level."""
        score = self.get_fraud_score(text)
        result = {
            "text": text,
            "fraud_score": round(score, 4),
            "risk_level": self.risk_label(score),
        }
        logger.info(
            "Prediction for '%s': %s (%s)",
            self._preview(text), result["risk_level"], result["fraud_score"],
        )
        return result

    def analyze(self, text: str) -> dict:
        """Returns fraud score + risk level + binary detection in one call.

        "is_fraud" is True when the score reaches the high threshold.
        """
        score = self.get_fraud_score(text)
        is_fraud = score >= self.high_threshold
        result = {
            "text": text,
            "fraud_score": round(score, 4),
            "risk_level": self.risk_label(score),
            "is_fraud": is_fraud,
            "detection": "Fraud Detected 🚨" if is_fraud else "No Fraud Detected ✅",
        }
        logger.info(
            "Analyze for '%s': %s | %s (%s)",
            self._preview(text), result["detection"],
            result["risk_level"], result["fraud_score"],
        )
        return result

    def predict_batch(self, texts: list) -> list:
        """Batch predict with consistent logging; one result dict per text."""
        results = []
        for text, score in zip(texts, self.get_fraud_scores(texts)):
            risk = self.risk_label(score)
            logger.info(
                "Batch prediction for '%s': %s (%s)",
                self._preview(text), risk, round(score, 4),
            )
            results.append({
                "text": text,
                "fraud_score": round(score, 4),
                "risk_level": risk,
            })
        return results
# Example usage: score one hard-coded sample and print the result.
if __name__ == "__main__":
    try:
        fd = FraudDetector()
        demo_text = "User transferred ₹50,000 to an unknown account at midnight"
        outcome = fd.predict(demo_text)
        print("\nPrediction Result:")
        print(outcome)
    except Exception as err:
        print(f"Error: {err}")