# train_rf.py — train a Random Forest SMS-spam classifier and export it to Dart.
# (Hugging Face listing metadata: repo "prem", author NitishStark, commit c25dcd7,
# "Deploy source code to Hugging Face without binaries".)
import pandas as pd
import numpy as np
from datasets import load_dataset
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
import m2cgen as m2c
import re
print("Loading dataset alanjoshua2005/india-spam-sms...")
dataset = load_dataset("alanjoshua2005/india-spam-sms")

# Hub datasets normally expose a 'train' split; materialize it as a DataFrame.
df = pd.DataFrame(dataset['train'])
print(df.columns)

# Normalize columns to 'text' / 'label' (1 = spam, 0 = ham). Known layouts:
# ('Message', 'Category'), ('text', 'label'), ('v1', 'v2'); otherwise guess by name.
if 'Message' in df.columns and 'Category' in df.columns:
    df['text'] = df['Message']
    df['label'] = df['Category'].apply(lambda x: 1 if x == 'spam' else 0)
elif 'text' in df.columns and 'label' in df.columns:
    pass
elif 'v1' in df.columns and 'v2' in df.columns:
    df['text'] = df['v2']
    df['label'] = df['v1'].apply(lambda x: 1 if x == 'spam' else 0)
else:
    # Generic fallback: pick columns whose names look like text/label holders.
    for col in df.columns:
        if 'text' in col.lower() or 'msg' in col.lower() or 'message' in col.lower():
            df['text'] = df[col]
        if 'label' in col.lower() or 'spam' in col.lower() or 'category' in col.lower():
            df['label'] = df[col]
    # Fail fast with a clear message instead of an opaque KeyError downstream.
    if 'text' not in df.columns or 'label' not in df.columns:
        raise ValueError(
            "Could not locate text/label columns in " + str(list(df.columns))
        )
    if df['label'].dtype == object:
        df['label'] = df['label'].apply(lambda x: 1 if 'spam' in str(x).lower() else 0)
print(df.head())
def extract_features(row):
    """Build the 13-element feature vector for one SMS row.

    row: mapping with 'text' (message body) and 'label' (1 = spam, 0 = ham).
    Returns a list of 13 floats in the order expected by the Dart wrapper.

    NOTE: sender/context features (age, contacts, message count, carrier,
    timing, scam similarity) are *simulated from the label* because the
    dataset only contains text — they leak the label by construction and
    exist to mock runtime context that the on-device app will supply.
    """
    text = str(row['text']).lower()
    is_spam = row['label'] == 1

    def _hit(word):
        # Match keywords at a word start so short tokens like 'rs' do not
        # fire on suffixes ("yours", "offers") while prefixes still count
        # ('offer' -> "offers", 'pay' -> "payment"). Currency symbols such
        # as '₹' are non-word characters where \b does not apply, so they
        # fall back to plain substring containment.
        if word.replace(' ', '').isalpha():
            return re.search(r'\b' + re.escape(word), text) is not None
        return word in text

    # 1-4. Simulated sender context (see docstring).
    senderAgeScore = np.random.uniform(0.0, 0.3) if is_spam else np.random.uniform(0.5, 1.0)
    senderInContacts = 0.0 if is_spam else float(np.random.rand() > 0.2)
    senderMessageCount = np.random.randint(0, 3) if is_spam else np.random.randint(10, 100)
    senderCarrierScore = np.random.uniform(0.0, 0.4) if is_spam else np.random.uniform(0.6, 1.0)

    # 5. urgencyScore: fraction of urgency cue words present (0.0 to 1.0).
    urgent_words = ['urgent', 'hurry', 'limited', 'immediate', 'action required', 'alert', 'warning', 'expires', 'claim']
    urgencyScore = sum(1 for w in urgent_words if _hit(w)) / len(urgent_words)

    # 6. domainCount: number of URL-looking substrings.
    domainCount = len(re.findall(r"(http[s]?://|www\.)\S+", text))

    # 7. domainLevenshtein (simulated; no real domain comparison here).
    domainLevenshtein = np.random.uniform(0.2, 0.9) if (is_spam and domainCount > 0) else 1.0

    # 8. containsMoneyAction: any money/offer cue present.
    money_words = ['rs', '₹', 'free', 'win', 'prize', 'cash', 'money', 'credit', 'loan', 'offer', 'pay', 'collect', 'bank']
    containsMoneyAction = float(any(_hit(w) for w in money_words))

    # 9. capsRatio: uppercase letters / all letters in the original casing.
    orig_text = str(row['text'])
    letters_count = sum(1 for c in orig_text if c.isalpha())
    caps_count = sum(1 for c in orig_text if c.isupper())
    capsRatio = caps_count / letters_count if letters_count > 0 else 0.0

    # 10-13. Simulated timing / similarity context (spam skews to odd hours).
    hourOfDay = float(np.random.choice([0, 1, 2, 3, 4, 22, 23])) if is_spam else float(np.random.randint(8, 20))
    isWeekend = float(np.random.rand() > (0.3 if is_spam else 0.8))
    isFestivalPeriod = float(np.random.rand() > 0.8)
    recentScamSenderSim = np.random.uniform(0.5, 1.0) if is_spam else np.random.uniform(0.0, 0.2)

    return [
        senderAgeScore, senderInContacts, float(senderMessageCount), senderCarrierScore,
        urgencyScore, float(domainCount), domainLevenshtein, containsMoneyAction, capsRatio,
        hourOfDay, isWeekend, isFestivalPeriod, recentScamSenderSim
    ]
print("Extracting features...")
X = np.array([extract_features(row) for _, row in df.iterrows()])
y = df['label'].values

# Hold out 20% for evaluation; fixed seed for a reproducible split.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

print("Training Random Forest...")
# Small forest / shallow trees keep the transpiled Dart compact enough to ship in-app.
clf = RandomForestClassifier(n_estimators=15, max_depth=5, random_state=42)
clf.fit(X_train, y_train)

print("Evaluating...")
preds = clf.predict(X_test)
print(classification_report(y_test, preds))

print("Exporting model to Dart with m2cgen...")
dart_code = m2c.export_to_dart(clf)

# Wrapper class: maps an SmsFeatureVector onto the raw input list and returns
# the class-1 (spam) probability from the generated score_features function.
file_path = "lib/kavacha/models/random_forest_model.dart"
header = """// Auto-generated using m2cgen
import 'sms_feature_vector.dart';
class RandomForestModel {
double score(SmsFeatureVector features) {
List<double> input = [
features.senderAgeScore,
features.senderInContacts ? 1.0 : 0.0,
features.senderMessageCount.toDouble(),
features.senderCarrierScore,
features.urgencyScore,
features.domainCount.toDouble(),
features.domainLevenshtein,
features.containsMoneyAction ? 1.0 : 0.0,
features.capsRatio,
features.hourOfDay.toDouble(),
features.isWeekend ? 1.0 : 0.0,
features.isFestivalPeriod ? 1.0 : 0.0,
features.recentScamSenderSim
];
List<double> output = score_features(input);
// output[1] is the probability of class 1 (spam)
return output[1];
}
"""

# m2cgen's Dart exporter emits `List<double> score(List<double> input)`; the
# old Java-style replace ("public double[] score") never matched, which left
# a second `score` member colliding with the wrapper method and an unresolved
# `score_features` call. Rename the generated function with a word-boundary
# regex instead (the `double[]` replace is kept as a harmless safety net).
code = dart_code.replace("double[]", "List<double>")
code = re.sub(r"\bscore\b", "score_features", code)

with open(file_path, "w") as f:
    f.write(header)
    f.write(code)
    f.write("\n}\n")  # close class RandomForestModel
print("Done writing to " + file_path)