"""Gradio app: predict whether a TikTok-style user/post is fake.

Builds a tiny two-node (user, topic) graph from a single input row, runs it
through a pretrained EnergyMPNN checkpoint fetched from the Hugging Face Hub,
and reports the fake-user probability and class.
"""

import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.nn import MessagePassing
import pandas as pd
import numpy as np
from sklearn.preprocessing import StandardScaler
import os
from datetime import datetime
import re
from textblob import TextBlob
from preprocessing_test import Preprocessor
from src.model import EnergyMPNN
from loguru import logger
from huggingface_hub import HfApi, HfFolder, upload_folder, Repository, hf_hub_download

# Default values substituted for any form field the user leaves blank.
default_values = {
    'sec_id': 'MS4wLjABAAAAUfldJaS79jt92MrYh5qLtoGwq7okyY7wAB...',
    'create_time': 1692246654.0,
    'height': 720,
    'width': 720,
    'ratio': '720p',
    'duration': 29.0,
    'digg_count': 192897,
    'share_count': 4180,
    'music_count': 0,
    'play_count': 2155824,
    'comment_count': 2496,
    'forward_count': 0,
    'download_count': 39,
    'desc': '☠️☠️☠️ #aleksandrsorokin #sorokin #ultrarunner...',
    'title': 'original sound - musterpoint77',
    'share_title': 'Check out Fad.Run’s video! #TikTok >',
    'favoriting_count': 1126,
    'follower_count': 22741,
    'following_count': 142,
    'gender': 0,
    'has_email': False,
    'is_mute': 0,
    'language': 'id',
    'mention_status': 1,
    'user_rate': 1,
    'aweme_count': 50,
    'birthday': '1900-01-01',
    'friends_status': 0,
    'signature': 'Run Enthusiast',
    'total_favorited': 2051954,
    'id_str': 7268144018282580992.0,
    'topic': 'diy'
}

# Expected Python types for each input field, based on the data schema.
expected_types = {
    'sec_id': str,
    'create_time': float,
    'height': int,
    'width': int,
    'ratio': str,
    'duration': float,
    'digg_count': int,
    'share_count': int,
    'music_count': int,
    'play_count': int,
    'comment_count': int,
    'forward_count': int,
    'download_count': int,
    'desc': str,
    'title': str,
    'share_title': str,
    'favoriting_count': int,
    'follower_count': int,
    'following_count': int,
    'gender': int,
    'has_email': bool,
    'is_mute': int,
    'language': str,
    'mention_status': int,
    'user_rate': int,
    'aweme_count': int,
    'birthday': str,
    'friends_status': int,
    'signature': str,
    'total_favorited': int,
    'id_str': float,
    'topic': str
}


def preprocess_single_row(df_row, scaler_user=None, scaler_topic=None, scaler_edge=None):
    """Convert a single-row DataFrame into a two-node PyG ``Data`` graph.

    Node 0 is the user, node 1 is the topic; the two are connected by a
    bidirectional edge carrying post-level features.

    Args:
        df_row: One-row DataFrame produced by ``Preprocessor.run_pipeline``.
        scaler_user / scaler_topic / scaler_edge: Optional fitted scalers.
            When absent, features are only NaN-sanitized (less accurate).

    Returns:
        torch_geometric.data.Data with ``num_users = 1`` attached.

    Raises:
        ValueError: If any required engineered column is missing.
    """
    required_columns = [
        'sec_id', 'topic', 'create_days_since_creation', 'post_length',
        'sentiment_score', 'lexical_diversity', 'create_hour',
        'time_since_prev_post', 'lexical_similarity',
        'digg_count', 'comment_count', 'share_count'
    ]
    if not all(col in df_row.columns for col in required_columns):
        missing = [col for col in required_columns if col not in df_row.columns]
        raise ValueError(f"Missing required columns: {missing}")

    user_features = pd.DataFrame({
        'sec_id': [df_row['sec_id'].iloc[0]],
        # Clamp to >= 1 so posting_frequency below never divides by zero.
        'create_days_since_creation': [max(df_row['create_days_since_creation'].iloc[0], 1)],
        'topic': [1],  # Assuming topic is encoded (e.g., 'diy' -> 1)
        'post_length': [df_row['post_length'].iloc[0]],
        'sentiment_score': [df_row['sentiment_score'].iloc[0]],
        'lexical_diversity': [df_row['lexical_diversity'].iloc[0]]
    })
    user_features['posting_frequency'] = 1 / user_features['create_days_since_creation']

    user_node_features = user_features[[
        'posting_frequency', 'topic', 'post_length', 'sentiment_score', 'lexical_diversity'
    ]].values
    # Pad with one zero column so user nodes match the 6-dim topic features.
    user_node_features = np.hstack([user_node_features, np.zeros((1, 1))])
    if scaler_user:
        user_node_features = scaler_user.transform(user_node_features)
    else:
        user_node_features = np.nan_to_num(user_node_features)

    # Single-row aggregates: with one post, means collapse to the row's own
    # values and the variance is 0.
    topic_features = pd.DataFrame({
        'topic': [df_row['topic'].iloc[0]],
        'popularity': [1],
        'sentiment_mean': [df_row['sentiment_score'].iloc[0]],
        'sentiment_var': [0],
        'digg_count_mean': [df_row['digg_count'].iloc[0]],
        'comment_count_mean': [df_row['comment_count'].iloc[0]],
        'share_count_mean': [df_row['share_count'].iloc[0]]
    })
    topic_node_features = topic_features[[
        'popularity', 'sentiment_mean', 'sentiment_var',
        'digg_count_mean', 'comment_count_mean', 'share_count_mean'
    ]].values
    if scaler_topic:
        topic_node_features = scaler_topic.transform(topic_node_features)
    else:
        topic_node_features = np.nan_to_num(topic_node_features)

    node_features = np.vstack([user_node_features, topic_node_features])
    node_features = torch.tensor(np.nan_to_num(node_features), dtype=torch.float32)

    edge_columns = ['post_length', 'sentiment_score', 'create_hour',
                    'time_since_prev_post', 'lexical_similarity']
    edge_features = np.array([[df_row[col].iloc[0] for col in edge_columns]])
    # Duplicate the feature row: one copy per edge direction (0->1 and 1->0).
    edge_features = np.repeat(edge_features, 2, axis=0)
    if scaler_edge:
        edge_features = scaler_edge.transform(edge_features)
    else:
        edge_features = np.nan_to_num(edge_features)
    edge_features = torch.tensor(edge_features, dtype=torch.float32)

    edge_index = torch.tensor([[0, 1], [1, 0]], dtype=torch.long).t().contiguous()
    # NOTE(review): positions are random at inference time — presumably the
    # model's pos channel tolerates this; confirm against training code.
    position_vectors = torch.randn(2, 3)
    y = torch.tensor([0], dtype=torch.float32)  # dummy label, unused at inference

    data = Data(
        x=node_features,
        edge_index=edge_index,
        edge_attr=edge_features,
        y=y,
        pos=position_vectors
    )
    data.num_users = 1
    return data


def predict_single_row(df_row,
                       model_path='output_files/model_outputs/model_checkpoint/best_model.pth',
                       scaler_user=None, scaler_topic=None, scaler_edge=None):
    """Run the EnergyMPNN on one preprocessed row.

    Args:
        df_row: One-row DataFrame from ``Preprocessor.run_pipeline``.
        model_path: Kept for backward compatibility; the checkpoint is in
            fact always downloaded from the Hugging Face Hub (see below).
        scaler_user / scaler_topic / scaler_edge: Optional fitted scalers.

    Returns:
        (prob, pred): fake probability in [0, 1] and the 0/1 class decision.

    Raises:
        RuntimeError: On preprocessing, model-loading, or inference failure.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    try:
        data = preprocess_single_row(df_row, scaler_user, scaler_topic, scaler_edge)
        data = data.to(device)
    except Exception as e:
        raise RuntimeError(f"Preprocessing failed: {str(e)}")

    model = EnergyMPNN(
        input_node_dim=6,
        edge_dim=5,
        hidden_dim=64,
        pos_dim=3,
        num_layers=2,
        dropout=0.2
    ).to(device)
    try:
        repo_id = "Askhedi/fake_user_detection"
        # The Hub checkpoint overrides the local `model_path` argument.
        model_path = hf_hub_download(repo_id=repo_id, filename="best_model.pth")
        model.load_state_dict(torch.load(model_path, map_location=device))
    except Exception as e:
        raise RuntimeError(f"Failed to load model from {model_path}: {str(e)}")
    model.eval()

    try:
        with torch.no_grad():
            user_scores, _ = model(data.x, data.edge_index, data.edge_attr,
                                   data.pos, data.num_users)
            prob = torch.sigmoid(user_scores).item()
        # Use the tuned decision threshold from training metrics when
        # available; otherwise fall back to 0.5.
        metrics_path = 'output_files/model_outputs/test_metrics.csv'
        threshold = 0.5
        if os.path.exists(metrics_path):
            try:
                threshold = pd.read_csv(metrics_path)['threshold'].iloc[0]
            except Exception as e:
                logger.warning(f"Could not load threshold: {str(e)}. Using default 0.5.")
        pred = 1 if prob > threshold else 0
    except Exception as e:
        raise RuntimeError(f"Prediction failed: {str(e)}")
    return prob, pred


def predict_fake_user(
    sec_id, create_time, height, width, ratio, duration, digg_count,
    share_count, music_count, play_count, comment_count, forward_count,
    download_count, desc, title, share_title, favoriting_count,
    follower_count, following_count, gender, has_email, is_mute, language,
    mention_status, user_rate, aweme_count, birthday, friends_status,
    signature, total_favorited, id_str, topic
):
    """Gradio callback: validate the form, preprocess, and predict.

    Returns:
        (result_markdown, probability_or_None, warnings_markdown).
        On validation/preprocessing/prediction failure the first element
        is an error message and the probability is None.
    """
    input_dict = {}
    inputs = {
        'sec_id': sec_id, 'create_time': create_time, 'height': height,
        'width': width, 'ratio': ratio, 'duration': duration,
        'digg_count': digg_count, 'share_count': share_count,
        'music_count': music_count, 'play_count': play_count,
        'comment_count': comment_count, 'forward_count': forward_count,
        'download_count': download_count, 'desc': desc, 'title': title,
        'share_title': share_title, 'favoriting_count': favoriting_count,
        'follower_count': follower_count, 'following_count': following_count,
        'gender': gender, 'has_email': has_email, 'is_mute': is_mute,
        'language': language, 'mention_status': mention_status,
        'user_rate': user_rate, 'aweme_count': aweme_count,
        'birthday': birthday, 'friends_status': friends_status,
        'signature': signature, 'total_favorited': total_favorited,
        'id_str': id_str, 'topic': topic
    }
    default_used = []
    type_errors = []

    # Validate types and track which fields fell back to defaults.
    for key, value in inputs.items():
        # Blank / missing widget values -> schema default.
        if value is None or value == "" or (isinstance(value, float) and np.isnan(value)):
            input_dict[key] = default_values[key]
            default_used.append(key)
        else:
            input_dict[key] = value
            expected_type = expected_types[key]
            try:
                if expected_type == str:
                    if not isinstance(value, str):
                        type_errors.append(f"'{key}' has value '{value}', expected string")
                elif expected_type == int:
                    # Dropdown widgets deliver strings ("0", "1", ...).
                    if isinstance(value, str) and value.isdigit():
                        value = int(value)  # Convert valid string to int
                    elif not isinstance(value, (int, float)) or (isinstance(value, float) and not value.is_integer()):
                        type_errors.append(f"'{key}' has value '{value}', expected integer")
                elif expected_type == float:
                    if isinstance(value, str):
                        float(value)  # Try converting string to float
                    elif not isinstance(value, (int, float)):
                        type_errors.append(f"'{key}' has value '{value}', expected float")
                elif expected_type == bool:
                    if not isinstance(value, bool):
                        type_errors.append(f"'{key}' has value '{value}', expected boolean")
            except (ValueError, TypeError):
                type_errors.append(f"'{key}' has value '{value}', expected {expected_type.__name__}")

    if type_errors:
        error_msg = "Input errors:\n" + "\n".join(f"- {err}" for err in type_errors)
        logger.error(error_msg)
        return error_msg, None, None

    # Cast inputs to their schema types. Fields are cast one at a time so a
    # failure can be attributed to the exact field (the previous version
    # reported a stale loop variable here, always naming the wrong field).
    int_fields = [
        'height', 'width', 'digg_count', 'share_count', 'music_count',
        'play_count', 'comment_count', 'forward_count', 'download_count',
        'favoriting_count', 'follower_count', 'following_count', 'gender',
        'is_mute', 'mention_status', 'user_rate', 'aweme_count',
        'friends_status', 'total_favorited'
    ]
    float_fields = ['create_time', 'duration', 'id_str']
    cast_key = None
    try:
        for cast_key in int_fields:
            input_dict[cast_key] = int(input_dict[cast_key])
        for cast_key in float_fields:
            input_dict[cast_key] = float(input_dict[cast_key])
        cast_key = 'has_email'
        input_dict['has_email'] = bool(input_dict['has_email'])
        # String fields (sec_id, ratio, desc, title, share_title, language,
        # birthday, signature, topic) remain as-is.
    except (ValueError, TypeError) as e:
        error_msg = f"Input error: Failed to cast {cast_key}: {str(e)}"
        logger.error(error_msg)
        return error_msg, None, None

    # Prepare warnings shown alongside the prediction.
    warnings = []
    if default_used:
        warnings.append(f"Using default values for: {', '.join(default_used)}")
    warnings.append("*Note*: Prediction may be less accurate because scaler parameters from training are not available.")
    warning_msg = "\n\n".join(warnings)

    try:
        df_ = pd.DataFrame([input_dict])
        logger.info("TEST DATASET")
        logger.info(f"\n{df_}")
        preprocessor = Preprocessor(df_)
        df_row = preprocessor.run_pipeline()
    except Exception as e:
        logger.error(f"Preprocessing error: {str(e)}")
        return f"Preprocessing error: {str(e)}", None, warning_msg

    try:
        prob, pred = predict_single_row(df_row)
        result = f"**Probability of being fake: {prob:.4f}**\n\n"
        result += f"**Predicted Class: {'Fake' if pred == 1 else 'Not Fake'}**"
        logger.info(f"Prediction successful: Probability={prob:.4f}, Class={'Fake' if pred == 1 else 'Not Fake'}")
        return result, prob, warning_msg
    except Exception as e:
        logger.error(f"Prediction error: {str(e)}")
        return f"Prediction error: {str(e)}", None, warning_msg


# Gradio interface
with gr.Blocks(title="Fake User Predictor") as demo:
    gr.Markdown("# Fake User Predictor")
    gr.Markdown("Enter user and post details to predict if the user is fake. Leave fields blank to use default values. Note: Scaler parameters are not available, which may affect accuracy.")
    with gr.Row():
        with gr.Column():
            sec_id = gr.Textbox(label="User ID (sec_id)", placeholder=default_values['sec_id'])
            create_time = gr.Number(label="Create Time (Unix timestamp)", value=None, info=f"Default: {default_values['create_time']}")
            height = gr.Number(label="Video Height", value=None, info=f"Default: {default_values['height']}")
            width = gr.Number(label="Video Width", value=None, info=f"Default: {default_values['width']}")
            ratio = gr.Textbox(label="Video Ratio", placeholder=default_values['ratio'])
            duration = gr.Number(label="Video Duration (seconds)", value=None, info=f"Default: {default_values['duration']}")
            digg_count = gr.Number(label="Digg Count", value=None, info=f"Default: {default_values['digg_count']}")
            share_count = gr.Number(label="Share Count", value=None, info=f"Default: {default_values['share_count']}")
            music_count = gr.Number(label="Music Count", value=None, info=f"Default: {default_values['music_count']}")
            play_count = gr.Number(label="Play Count", value=None, info=f"Default: {default_values['play_count']}")
            comment_count = gr.Number(label="Comment Count", value=None, info=f"Default: {default_values['comment_count']}")
            forward_count = gr.Number(label="Forward Count", value=None, info=f"Default: {default_values['forward_count']}")
            download_count = gr.Number(label="Download Count", value=None, info=f"Default: {default_values['download_count']}")
            desc = gr.Textbox(label="Post Description", placeholder=default_values['desc'])
            title = gr.Textbox(label="Post Title", placeholder=default_values['title'])
            share_title = gr.Textbox(label="Share Title", placeholder=default_values['share_title'])
        with gr.Column():
            favoriting_count = gr.Number(label="Favoriting Count", value=None, info=f"Default: {default_values['favoriting_count']}")
            follower_count = gr.Number(label="Follower Count", value=None, info=f"Default: {default_values['follower_count']}")
            following_count = gr.Number(label="Following Count", value=None, info=f"Default: {default_values['following_count']}")
            gender = gr.Dropdown(label="Gender", choices=["0", "1", "2"], value=None, allow_custom_value=False)
            has_email = gr.Checkbox(label="Has Email", value=False)
            is_mute = gr.Dropdown(label="Is Mute", choices=["0", "1"], value=None, allow_custom_value=False)
            language = gr.Textbox(label="Language", placeholder=default_values['language'])
            mention_status = gr.Dropdown(label="Mention Status", choices=["0", "1"], value=None, allow_custom_value=False)
            user_rate = gr.Number(label="User Rate", value=None, info=f"Default: {default_values['user_rate']}")
            aweme_count = gr.Number(label="Post Count (aweme_count)", value=None, info=f"Default: {default_values['aweme_count']}")
            birthday = gr.Textbox(label="Birthday", placeholder=default_values['birthday'])
            friends_status = gr.Dropdown(label="Friends Status", choices=["0", "1"], value=None, allow_custom_value=False)
            signature = gr.Textbox(label="Signature", placeholder=default_values['signature'])
            total_favorited = gr.Number(label="Total Favorited", value=None, info=f"Default: {default_values['total_favorited']}")
            id_str = gr.Number(label="ID String", value=None, info=f"Default: {default_values['id_str']}")
            topic = gr.Textbox(label="Topic", placeholder=default_values['topic'])

    predict_btn = gr.Button("Predict")
    output_text = gr.Markdown(label="Prediction Result")
    prob_output = gr.Number(label="Probability", visible=False)
    warning_output = gr.Markdown(label="Warnings")

    predict_btn.click(
        fn=predict_fake_user,
        inputs=[
            sec_id, create_time, height, width, ratio, duration, digg_count,
            share_count, music_count, play_count, comment_count,
            forward_count, download_count, desc, title, share_title,
            favoriting_count, follower_count, following_count, gender,
            has_email, is_mute, language, mention_status, user_rate,
            aweme_count, birthday, friends_status, signature,
            total_favorited, id_str, topic
        ],
        outputs=[output_text, prob_output, warning_output]
    )

# Launch the interface immediately
if __name__ == "__main__":
    demo.launch(share=True)