# Hugging Face Space metadata (commit note): "device fixed", revision a529e2b (verified)
import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.nn import MessagePassing
import pandas as pd
import numpy as np
from sklearn.preprocessing import StandardScaler
import os
from datetime import datetime
import re
from textblob import TextBlob
from preprocessing_test import Preprocessor
from src.model import EnergyMPNN
from loguru import logger
from huggingface_hub import HfApi, HfFolder, upload_folder, Repository,hf_hub_download
# Default values
# One real sample row from the TikTok dataset. Any form field left blank in
# the Gradio UI falls back to the value here; the matching types are declared
# in `expected_types` below.
default_values = {
    'sec_id': 'MS4wLjABAAAAUfldJaS79jt92MrYh5qLtoGwq7okyY7wAB...',  # secure user id (truncated sample)
    'create_time': 1692246654.0,  # Unix timestamp of the post
    'height': 720,   # video height in pixels
    'width': 720,    # video width in pixels
    'ratio': '720p',
    'duration': 29.0,  # video length in seconds
    'digg_count': 192897,  # like count
    'share_count': 4180,
    'music_count': 0,
    'play_count': 2155824,
    'comment_count': 2496,
    'forward_count': 0,
    'download_count': 39,
    'desc': '☠️☠️☠️ #aleksandrsorokin #sorokin #ultrarunner...',  # post caption (truncated sample)
    'title': 'original sound - musterpoint77',
    'share_title': 'Check out Fad.Run’s video! #TikTok >',
    'favoriting_count': 1126,
    'follower_count': 22741,
    'following_count': 142,
    'gender': 0,
    'has_email': False,
    'is_mute': 0,
    'language': 'id',
    'mention_status': 1,
    'user_rate': 1,
    'aweme_count': 50,  # number of posts by the user
    'birthday': '1900-01-01',  # presumably an "unset birthday" placeholder — verify against dataset
    'friends_status': 0,
    'signature': 'Run Enthusiast',  # profile bio text
    'total_favorited': 2051954,
    'id_str': 7268144018282580992.0,  # numeric id stored as float per the schema below
    'topic': 'diy'
}
# Expected types based on schema
# Consumed by `predict_fake_user` both to validate raw widget values and to
# cast them before preprocessing; `str` fields are passed through unchanged.
expected_types = {
    'sec_id': str,
    'create_time': float,
    'height': int,
    'width': int,
    'ratio': str,
    'duration': float,
    'digg_count': int,
    'share_count': int,
    'music_count': int,
    'play_count': int,
    'comment_count': int,
    'forward_count': int,
    'download_count': int,
    'desc': str,
    'title': str,
    'share_title': str,
    'favoriting_count': int,
    'follower_count': int,
    'following_count': int,
    'gender': int,
    'has_email': bool,
    'is_mute': int,
    'language': str,
    'mention_status': int,
    'user_rate': int,
    'aweme_count': int,
    'birthday': str,
    'friends_status': int,
    'signature': str,
    'total_favorited': int,
    'id_str': float,  # ids exceed exact float precision; kept as float to match the dataset
    'topic': str
}
# Preprocess single-row DataFrame
def preprocess_single_row(df_row, scaler_user=None, scaler_topic=None, scaler_edge=None):
    """Convert one preprocessed row into a 2-node user<->topic graph.

    Args:
        df_row: single-row pandas DataFrame that has been through
            `Preprocessor.run_pipeline()` and contains every column listed
            in `required_columns` below.
        scaler_user, scaler_topic, scaler_edge: optional fitted scalers from
            training. When None, raw features are used with NaNs replaced by
            0, which may reduce prediction accuracy.

    Returns:
        torch_geometric.data.Data with:
            x          - 2 x 6 node-feature matrix (row 0 = user, row 1 = topic)
            edge_index - the symmetric edge pair 0->1 and 1->0
            edge_attr  - 2 x 5 edge features (same features in both directions)
            pos        - 2 x 3 random position vectors
            y          - dummy label (0); unused at inference
        plus `data.num_users = 1`.

    Raises:
        ValueError: if any required column is missing from `df_row`.
    """
    required_columns = [
        'sec_id', 'topic', 'create_days_since_creation', 'post_length',
        'sentiment_score', 'lexical_diversity', 'create_hour',
        'time_since_prev_post', 'lexical_similarity', 'digg_count',
        'comment_count', 'share_count'
    ]
    if not all(col in df_row.columns for col in required_columns):
        missing = [col for col in required_columns if col not in df_row.columns]
        raise ValueError(f"Missing required columns: {missing}")
    # --- User node features ---
    user_features = pd.DataFrame({
        'sec_id': [df_row['sec_id'].iloc[0]],
        # clamp to >= 1 day so posting_frequency below never divides by zero
        'create_days_since_creation': [max(df_row['create_days_since_creation'].iloc[0], 1)],
        # NOTE(review): hard-coded to 1 regardless of df_row['topic'] — confirm intended.
        'topic': [1],  # Assuming topic is encoded (e.g., 'diy' -> 1)
        'post_length': [df_row['post_length'].iloc[0]],
        'sentiment_score': [df_row['sentiment_score'].iloc[0]],
        'lexical_diversity': [df_row['lexical_diversity'].iloc[0]]
    })
    user_features['posting_frequency'] = 1 / user_features['create_days_since_creation']
    user_node_features = user_features[[
        'posting_frequency', 'topic', 'post_length',
        'sentiment_score', 'lexical_diversity'
    ]].values
    # Pad with a zero column so the user node has 6 features, matching the topic node.
    user_node_features = np.hstack([user_node_features, np.zeros((1, 1))])
    if scaler_user:
        user_node_features = scaler_user.transform(user_node_features)
    else:
        user_node_features = np.nan_to_num(user_node_features)
    # --- Topic node features: single-row stand-ins for training-time aggregates ---
    topic_features = pd.DataFrame({
        'topic': [df_row['topic'].iloc[0]],
        'popularity': [1],  # only one post is visible at inference time
        'sentiment_mean': [df_row['sentiment_score'].iloc[0]],
        'sentiment_var': [0],  # variance of a single sample is 0
        'digg_count_mean': [df_row['digg_count'].iloc[0]],
        'comment_count_mean': [df_row['comment_count'].iloc[0]],
        'share_count_mean': [df_row['share_count'].iloc[0]]
    })
    topic_node_features = topic_features[[
        'popularity', 'sentiment_mean', 'sentiment_var',
        'digg_count_mean', 'comment_count_mean', 'share_count_mean'
    ]].values
    if scaler_topic:
        topic_node_features = scaler_topic.transform(topic_node_features)
    else:
        topic_node_features = np.nan_to_num(topic_node_features)
    # Stack user (row 0) and topic (row 1) into the graph's node matrix.
    node_features = np.vstack([user_node_features, topic_node_features])
    node_features = torch.tensor(np.nan_to_num(node_features), dtype=torch.float32)
    # --- Edge features: identical attributes on both edge directions ---
    edge_columns = ['post_length', 'sentiment_score', 'create_hour', 'time_since_prev_post', 'lexical_similarity']
    edge_features = np.array([[df_row[col].iloc[0] for col in edge_columns]])
    edge_features = np.repeat(edge_features, 2, axis=0)
    if scaler_edge:
        edge_features = scaler_edge.transform(edge_features)
    else:
        edge_features = np.nan_to_num(edge_features)
    edge_features = torch.tensor(edge_features, dtype=torch.float32)
    # Undirected user<->topic connection expressed as two directed edges.
    edge_index = torch.tensor([[0, 1], [1, 0]], dtype=torch.long).t().contiguous()
    # NOTE(review): random positions make repeated predictions on the same
    # input non-deterministic — confirm this matches how `pos` was used in training.
    position_vectors = torch.randn(2, 3)
    y = torch.tensor([0], dtype=torch.float32)  # dummy label; ignored at inference
    data = Data(
        x=node_features,
        edge_index=edge_index,
        edge_attr=edge_features,
        y=y,
        pos=position_vectors
    )
    data.num_users = 1  # passed to the model's forward (see predict_single_row)
    return data
# Predict function
def predict_single_row(df_row, model_path='output_files/model_outputs/model_checkpoint/best_model.pth', scaler_user=None, scaler_topic=None, scaler_edge=None):
    """Score one preprocessed row with the EnergyMPNN fake-user model.

    Args:
        df_row: single-row DataFrame accepted by `preprocess_single_row`.
        model_path: kept for backward compatibility; the checkpoint is always
            fetched from the Hugging Face Hub and this value is overwritten
            with the downloaded file's local path.
        scaler_user, scaler_topic, scaler_edge: optional fitted scalers
            forwarded to `preprocess_single_row`.

    Returns:
        (prob, pred): sigmoid probability of the user being fake, and the
        0/1 class obtained by comparing against the decision threshold.

    Raises:
        RuntimeError: wrapping any preprocessing, model-loading, or
            inference failure.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    try:
        data = preprocess_single_row(df_row, scaler_user, scaler_topic, scaler_edge)
        data = data.to(device)
    except Exception as e:
        raise RuntimeError(f"Preprocessing failed: {str(e)}")
    # Architecture must match the trained checkpoint exactly.
    model = EnergyMPNN(
        input_node_dim=6,
        edge_dim=5,
        hidden_dim=64,
        pos_dim=3,
        num_layers=2,
        dropout=0.2
    ).to(device)
    try:
        # Always download the checkpoint from the Hub; `model_path` is
        # reassigned to the local cache path of the downloaded file.
        repo_id = "Askhedi/fake_user_detection"
        model_path = hf_hub_download(repo_id=repo_id, filename="best_model.pth")
        model.load_state_dict(torch.load(model_path, map_location=device))
    except Exception as e:
        raise RuntimeError(f"Failed to load model from {model_path}: {str(e)}")
    model.eval()  # inference mode (disables dropout)
    try:
        with torch.no_grad():
            user_scores, _ = model(data.x, data.edge_index, data.edge_attr, data.pos, data.num_users)
            prob = torch.sigmoid(user_scores).item()
            # Use the tuned threshold from training metrics when available,
            # otherwise fall back to 0.5.
            metrics_path = 'output_files/model_outputs/test_metrics.csv'
            threshold = 0.5
            if os.path.exists(metrics_path):
                try:
                    threshold = pd.read_csv(metrics_path)['threshold'].iloc[0]
                except Exception as e:
                    logger.warning(f"Could not load threshold: {str(e)}. Using default 0.5.")
            pred = 1 if prob > threshold else 0
    except Exception as e:
        raise RuntimeError(f"Prediction failed: {str(e)}")
    return prob, pred
# Gradio prediction function
def _validate_type(key, value, expected_type):
    """Return an error message if *value* is incompatible with *expected_type*, else None.

    Validation is deliberately loose to match what Gradio widgets emit:
    int fields accept digit-strings (Dropdowns return "0"/"1"/...) and
    whole-number floats (gr.Number returns floats); float fields accept any
    string that `float()` can parse.
    """
    try:
        if expected_type == str:
            if not isinstance(value, str):
                return f"'{key}' has value '{value}', expected string"
        elif expected_type == int:
            if isinstance(value, str):
                # For Dropdown inputs, value may be str ("0", "1")
                if not value.isdigit():
                    return f"'{key}' has value '{value}', expected integer"
            elif not isinstance(value, (int, float)) or (isinstance(value, float) and not value.is_integer()):
                return f"'{key}' has value '{value}', expected integer"
        elif expected_type == float:
            if isinstance(value, str):
                float(value)  # raises ValueError for non-numeric strings
            elif not isinstance(value, (int, float)):
                return f"'{key}' has value '{value}', expected float"
        elif expected_type == bool:
            if not isinstance(value, bool):
                return f"'{key}' has value '{value}', expected boolean"
    except (ValueError, TypeError):
        return f"'{key}' has value '{value}', expected {expected_type.__name__}"
    return None


def predict_fake_user(
    sec_id, create_time, height, width, ratio, duration, digg_count,
    share_count, music_count, play_count, comment_count, forward_count,
    download_count, desc, title, share_title, favoriting_count,
    follower_count, following_count, gender, has_email, is_mute,
    language, mention_status, user_rate, aweme_count, birthday,
    friends_status, signature, total_favorited, id_str, topic
):
    """Gradio callback: validate form values, substitute defaults for blank
    fields, run the preprocessing pipeline and the model.

    Returns:
        (result_markdown, probability, warnings_markdown). On validation or
        processing failure the first element carries the error message and
        the probability is None.
    """
    inputs = {
        'sec_id': sec_id,
        'create_time': create_time,
        'height': height,
        'width': width,
        'ratio': ratio,
        'duration': duration,
        'digg_count': digg_count,
        'share_count': share_count,
        'music_count': music_count,
        'play_count': play_count,
        'comment_count': comment_count,
        'forward_count': forward_count,
        'download_count': download_count,
        'desc': desc,
        'title': title,
        'share_title': share_title,
        'favoriting_count': favoriting_count,
        'follower_count': follower_count,
        'following_count': following_count,
        'gender': gender,
        'has_email': has_email,
        'is_mute': is_mute,
        'language': language,
        'mention_status': mention_status,
        'user_rate': user_rate,
        'aweme_count': aweme_count,
        'birthday': birthday,
        'friends_status': friends_status,
        'signature': signature,
        'total_favorited': total_favorited,
        'id_str': id_str,
        'topic': topic
    }
    input_dict = {}
    default_used = []
    type_errors = []
    # Substitute defaults for missing values; type-check only user-provided ones.
    for key, value in inputs.items():
        if value is None or value == "" or (isinstance(value, float) and np.isnan(value)):
            input_dict[key] = default_values[key]
            default_used.append(key)
        else:
            input_dict[key] = value
            err = _validate_type(key, value, expected_types[key])
            if err:
                type_errors.append(err)
    # Return type errors if any
    if type_errors:
        error_msg = "Input errors:\n" + "\n".join(f"- {err}" for err in type_errors)
        logger.error(error_msg)
        return error_msg, None, None
    # Cast inputs to their schema types. The failing field is tracked in
    # `cast_key` so the error names the right field (the previous version
    # interpolated a stale loop variable here).
    cast_key = None
    try:
        for cast_key, expected in expected_types.items():
            if expected is int:
                input_dict[cast_key] = int(input_dict[cast_key])
            elif expected is float:
                input_dict[cast_key] = float(input_dict[cast_key])
            elif expected is bool:
                input_dict[cast_key] = bool(input_dict[cast_key])
            # String fields (sec_id, ratio, desc, title, share_title,
            # language, birthday, signature, topic) remain as-is.
    except (ValueError, TypeError) as e:
        error_msg = f"Input error: Failed to cast {cast_key}: {str(e)}"
        logger.error(error_msg)
        return error_msg, None, None
    # Prepare warnings shown alongside the prediction.
    warnings = []
    if default_used:
        warnings.append(f"Using default values for: {', '.join(default_used)}")
    warnings.append("*Note*: Prediction may be less accurate because scaler parameters from training are not available.")
    warning_msg = "\n\n".join(warnings)
    try:
        df_ = pd.DataFrame([input_dict])
        logger.info("TEST DATASET")
        logger.info(f"\n{df_}")
        preprocessor = Preprocessor(df_)
        df_row = preprocessor.run_pipeline()
    except Exception as e:
        logger.error(f"Preprocessing error: {str(e)}")
        return f"Preprocessing error: {str(e)}", None, warning_msg
    try:
        prob, pred = predict_single_row(df_row)
        result = f"**Probability of being fake: {prob:.4f}**\n\n"
        result += f"**Predicted Class: {'Fake' if pred == 1 else 'Not Fake'}**"
        logger.info(f"Prediction successful: Probability={prob:.4f}, Class={'Fake' if pred == 1 else 'Not Fake'}")
        return result, prob, warning_msg
    except Exception as e:
        logger.error(f"Prediction error: {str(e)}")
        return f"Prediction error: {str(e)}", None, warning_msg
# Gradio interface
# Two-column form covering the full input schema; blank fields fall back to
# `default_values` inside `predict_fake_user`.
with gr.Blocks(title="Fake User Predictor") as demo:
    gr.Markdown("# Fake User Predictor")
    gr.Markdown("Enter user and post details to predict if the user is fake. Leave fields blank to use default values. Note: Scaler parameters are not available, which may affect accuracy.")
    with gr.Row():
        # Left column: post/video-level fields.
        with gr.Column():
            sec_id = gr.Textbox(label="User ID (sec_id)", placeholder=default_values['sec_id'])
            create_time = gr.Number(label="Create Time (Unix timestamp)", value=None, info=f"Default: {default_values['create_time']}")
            height = gr.Number(label="Video Height", value=None, info=f"Default: {default_values['height']}")
            width = gr.Number(label="Video Width", value=None, info=f"Default: {default_values['width']}")
            ratio = gr.Textbox(label="Video Ratio", placeholder=default_values['ratio'])
            duration = gr.Number(label="Video Duration (seconds)", value=None, info=f"Default: {default_values['duration']}")
            digg_count = gr.Number(label="Digg Count", value=None, info=f"Default: {default_values['digg_count']}")
            share_count = gr.Number(label="Share Count", value=None, info=f"Default: {default_values['share_count']}")
            music_count = gr.Number(label="Music Count", value=None, info=f"Default: {default_values['music_count']}")
            play_count = gr.Number(label="Play Count", value=None, info=f"Default: {default_values['play_count']}")
            comment_count = gr.Number(label="Comment Count", value=None, info=f"Default: {default_values['comment_count']}")
            forward_count = gr.Number(label="Forward Count", value=None, info=f"Default: {default_values['forward_count']}")
            download_count = gr.Number(label="Download Count", value=None, info=f"Default: {default_values['download_count']}")
            desc = gr.Textbox(label="Post Description", placeholder=default_values['desc'])
            title = gr.Textbox(label="Post Title", placeholder=default_values['title'])
            share_title = gr.Textbox(label="Share Title", placeholder=default_values['share_title'])
        # Right column: user/profile-level fields. Dropdowns emit strings
        # ("0"/"1"), which `predict_fake_user` validates and casts to int.
        with gr.Column():
            favoriting_count = gr.Number(label="Favoriting Count", value=None, info=f"Default: {default_values['favoriting_count']}")
            follower_count = gr.Number(label="Follower Count", value=None, info=f"Default: {default_values['follower_count']}")
            following_count = gr.Number(label="Following Count", value=None, info=f"Default: {default_values['following_count']}")
            gender = gr.Dropdown(label="Gender", choices=["0", "1", "2"], value=None, allow_custom_value=False)
            has_email = gr.Checkbox(label="Has Email", value=False)
            is_mute = gr.Dropdown(label="Is Mute", choices=["0", "1"], value=None, allow_custom_value=False)
            language = gr.Textbox(label="Language", placeholder=default_values['language'])
            mention_status = gr.Dropdown(label="Mention Status", choices=["0", "1"], value=None, allow_custom_value=False)
            user_rate = gr.Number(label="User Rate", value=None, info=f"Default: {default_values['user_rate']}")
            aweme_count = gr.Number(label="Post Count (aweme_count)", value=None, info=f"Default: {default_values['aweme_count']}")
            birthday = gr.Textbox(label="Birthday", placeholder=default_values['birthday'])
            friends_status = gr.Dropdown(label="Friends Status", choices=["0", "1"], value=None, allow_custom_value=False)
            signature = gr.Textbox(label="Signature", placeholder=default_values['signature'])
            total_favorited = gr.Number(label="Total Favorited", value=None, info=f"Default: {default_values['total_favorited']}")
            id_str = gr.Number(label="ID String", value=None, info=f"Default: {default_values['id_str']}")
            topic = gr.Textbox(label="Topic", placeholder=default_values['topic'])
    predict_btn = gr.Button("Predict")
    output_text = gr.Markdown(label="Prediction Result")
    prob_output = gr.Number(label="Probability", visible=False)  # hidden; carries the raw probability
    warning_output = gr.Markdown(label="Warnings")
    # Input order must match predict_fake_user's parameter order exactly.
    predict_btn.click(
        fn=predict_fake_user,
        inputs=[
            sec_id, create_time, height, width, ratio, duration, digg_count,
            share_count, music_count, play_count, comment_count, forward_count,
            download_count, desc, title, share_title, favoriting_count,
            follower_count, following_count, gender, has_email, is_mute,
            language, mention_status, user_rate, aweme_count, birthday,
            friends_status, signature, total_favorited, id_str, topic
        ],
        outputs=[output_text, prob_output, warning_output]
    )
# Launch the interface immediately
if __name__ == "__main__":
    demo.launch(share=True)  # share=True publishes a temporary public Gradio URL