Spaces:
Build error
Build error
| # fraud_detector.py | |
| import torch | |
| import pandas as pd | |
| from transformers import BertTokenizer, BertForSequenceClassification | |
| from torch.utils.data import Dataset | |
| from sklearn.model_selection import train_test_split | |
| from sklearn.metrics import accuracy_score, precision_recall_fscore_support | |
class FinancialFraudDataset(Dataset):
    """Map-style PyTorch dataset pairing tokenizer encodings with labels.

    Args:
        encodings: Mapping from feature name (e.g. ``input_ids``,
            ``attention_mask``) to a per-example indexable sequence, such as
            the output of a Hugging Face tokenizer.
        labels: Per-example labels (list, array, or pandas Series). Items are
            looked up by *position*, so a Series whose index was shuffled by
            ``train_test_split`` is handled correctly.
    """

    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __len__(self):
        """Return the number of examples (one per label)."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return example ``idx`` as a dict of tensors plus a ``labels`` entry."""
        # as_tensor accepts both Python lists and existing tensors; torch.tensor
        # warns (and copies) when the encodings were built with return_tensors="pt".
        item = {key: torch.as_tensor(val[idx]) for key, val in self.encodings.items()}
        # A pandas Series' [] is label-based: after train_test_split (which keeps
        # the original index) labels[idx] raises KeyError or returns the wrong
        # row. .iloc gives the positional lookup the Dataset protocol requires.
        labels = self.labels
        label = labels.iloc[idx] if hasattr(labels, "iloc") else labels[idx]
        item["labels"] = torch.as_tensor(label)
        return item
| class FinancialFraudTrainer: | |
| def __init__(self, data_path=None): | |
| self.data_path = data_path | |
| self.tokenizer = None | |
| self.model = None | |
| def load_model(self): | |
| # 從 Hugging Face 模型倉庫載入(或用你訓練好的模型路徑) | |
| self.model = BertForSequenceClassification.from_pretrained("hfl/chinese-roberta-wwm-ext", num_labels=2) | |
| self.tokenizer = BertTokenizer.from_pretrained("hfl/chinese-roberta-wwm-ext") | |
| self.model.eval() | |
| def predict_transaction(self, text): | |
| try: | |
| self.model.eval() | |
| inputs = self.tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=128) | |
| with torch.no_grad(): | |
| outputs = self.model(**inputs) | |
| probs = torch.softmax(outputs.logits, dim=1) | |
| prediction = torch.argmax(probs, dim=1).item() | |
| confidence = probs[0][prediction].item() | |
| except Exception as e: | |
| return f"Error: {str(e)}" | |