Spaces:
Running
Running
| import gradio as gr | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torch_geometric.data import Data | |
| from torch_geometric.nn import MessagePassing | |
| import pandas as pd | |
| import numpy as np | |
| from sklearn.preprocessing import StandardScaler | |
| import os | |
| from datetime import datetime | |
| import re | |
| from textblob import TextBlob | |
| from preprocessing_test import Preprocessor | |
| from src.model import EnergyMPNN | |
| from loguru import logger | |
| from huggingface_hub import HfApi, HfFolder, upload_folder, Repository,hf_hub_download | |
# Fallback value for every form field; predict_fake_user substitutes these
# whenever the corresponding input is left blank.
# NOTE(review): 'sec_id' and 'desc' end in '...' and look like truncated
# pandas display output rather than real values — confirm before relying on them.
default_values = dict(
    sec_id='MS4wLjABAAAAUfldJaS79jt92MrYh5qLtoGwq7okyY7wAB...',
    create_time=1692246654.0,
    height=720,
    width=720,
    ratio='720p',
    duration=29.0,
    digg_count=192897,
    share_count=4180,
    music_count=0,
    play_count=2155824,
    comment_count=2496,
    forward_count=0,
    download_count=39,
    desc='☠️☠️☠️ #aleksandrsorokin #sorokin #ultrarunner...',
    title='original sound - musterpoint77',
    share_title='Check out Fad.Run’s video! #TikTok >',
    favoriting_count=1126,
    follower_count=22741,
    following_count=142,
    gender=0,
    has_email=False,
    is_mute=0,
    language='id',
    mention_status=1,
    user_rate=1,
    aweme_count=50,
    birthday='1900-01-01',
    friends_status=0,
    signature='Run Enthusiast',
    total_favorited=2051954,
    id_str=7268144018282580992.0,  # stored as float, so the id is not exact
    topic='diy',
)
# Python type each form field must have after validation/casting, keyed by
# field name; the order mirrors the Gradio inputs.
expected_types = dict(
    sec_id=str,
    create_time=float,
    height=int,
    width=int,
    ratio=str,
    duration=float,
    digg_count=int,
    share_count=int,
    music_count=int,
    play_count=int,
    comment_count=int,
    forward_count=int,
    download_count=int,
    desc=str,
    title=str,
    share_title=str,
    favoriting_count=int,
    follower_count=int,
    following_count=int,
    gender=int,
    has_email=bool,
    is_mute=int,
    language=str,
    mention_status=int,
    user_rate=int,
    aweme_count=int,
    birthday=str,
    friends_status=int,
    signature=str,
    total_favorited=int,
    id_str=float,
    topic=str,
)
def _scale_or_clean(features, scaler):
    """Optionally scale a 2-D feature array, then replace NaN/inf with finite numbers.

    Args:
        features: 2-D numpy array of raw feature rows.
        scaler: object exposing ``transform`` (e.g. a fitted sklearn scaler), or None.

    Returns:
        ``np.nan_to_num`` of the (possibly transformed) array.
    """
    if scaler is not None:
        features = scaler.transform(features)
    return np.nan_to_num(features)


# Preprocess single-row DataFrame
def preprocess_single_row(df_row, scaler_user=None, scaler_topic=None, scaler_edge=None, pos_seed=0):
    """Build a two-node graph (user node 0, topic node 1) from the first row of *df_row*.

    The two nodes are joined by one edge in each direction; both edges carry the
    same post-level features.

    Args:
        df_row: DataFrame containing every column in ``required_columns``;
            only its first row is used.
        scaler_user: optional fitted scaler for the 6 user-node features.
        scaler_topic: optional fitted scaler for the 6 topic-node features.
        scaler_edge: optional fitted scaler for the 5 edge features.
        pos_seed: seed for the random node positions, so identical inputs build
            identical graphs. Pass None to draw unseeded positions.

    Returns:
        torch_geometric ``Data`` with ``x`` (2 nodes x 6 features), ``edge_index``
        (0->1 and 1->0), ``edge_attr`` (2 edges x 5 features), a placeholder
        ``y``, ``pos`` (2 x 3), and ``num_users = 1``.

    Raises:
        ValueError: if required columns are missing or *df_row* has no rows.
    """
    required_columns = [
        'sec_id', 'topic', 'create_days_since_creation', 'post_length',
        'sentiment_score', 'lexical_diversity', 'create_hour',
        'time_since_prev_post', 'lexical_similarity', 'digg_count',
        'comment_count', 'share_count'
    ]
    missing = [col for col in required_columns if col not in df_row.columns]
    if missing:
        raise ValueError(f"Missing required columns: {missing}")
    if df_row.empty:
        raise ValueError("df_row contains no rows")

    def first(col):
        # Only the first row is ever used.
        return df_row[col].iloc[0]

    # User node: posting_frequency, topic, post_length, sentiment, lexical diversity, pad.
    days = max(first('create_days_since_creation'), 1)  # avoid dividing by zero
    user_node_features = np.array([[
        1 / days,
        1,  # NOTE(review): topic is hard-coded as 1 (assumed encoding, e.g. 'diy' -> 1); not derived from the input
        first('post_length'),
        first('sentiment_score'),
        first('lexical_diversity'),
        0.0,  # zero pad so user and topic nodes share the same 6-dim width
    ]], dtype=float)

    # Topic node: single-post placeholders for popularity and sentiment variance.
    topic_node_features = np.array([[
        1,  # popularity
        first('sentiment_score'),  # sentiment_mean
        0,  # sentiment_var
        first('digg_count'),  # digg_count_mean
        first('comment_count'),  # comment_count_mean
        first('share_count'),  # share_count_mean
    ]], dtype=float)

    node_features = np.vstack([
        _scale_or_clean(user_node_features, scaler_user),
        _scale_or_clean(topic_node_features, scaler_topic),
    ])
    x = torch.tensor(node_features, dtype=torch.float32)

    edge_columns = ['post_length', 'sentiment_score', 'create_hour', 'time_since_prev_post', 'lexical_similarity']
    edge_row = np.array([[first(col) for col in edge_columns]], dtype=float)
    # One row per directed edge; NaNs are cleaned on the scaler path too.
    edge_features = _scale_or_clean(np.repeat(edge_row, 2, axis=0), scaler_edge)
    edge_attr = torch.tensor(edge_features, dtype=torch.float32)

    # Row 0 = sources, row 1 = targets: edges 0 -> 1 and 1 -> 0.
    edge_index = torch.tensor([[0, 1], [1, 0]], dtype=torch.long)

    # Seeded so repeated predictions on the same input agree (if the model uses pos).
    generator = torch.Generator().manual_seed(pos_seed) if pos_seed is not None else None
    pos = torch.randn(2, 3, generator=generator)

    data = Data(
        x=x,
        edge_index=edge_index,
        edge_attr=edge_attr,
        y=torch.tensor([0], dtype=torch.float32),  # placeholder label; predict_single_row does not read it
        pos=pos,
    )
    data.num_users = 1
    return data
def _load_threshold(metrics_path, default=0.5):
    """Return the decision threshold stored in *metrics_path*, or *default* if unavailable.

    The file is read as CSV and the first value of its ``threshold`` column is used.
    """
    if not os.path.exists(metrics_path):
        return default
    try:
        return float(pd.read_csv(metrics_path)['threshold'].iloc[0])
    except Exception as e:
        logger.warning(f"Could not load threshold: {str(e)}. Using default {default}.")
        return default


# Predict function
def predict_single_row(df_row, model_path='output_files/model_outputs/model_checkpoint/best_model.pth', scaler_user=None, scaler_topic=None, scaler_edge=None, repo_id="Askhedi/fake_user_detection", metrics_path='output_files/model_outputs/test_metrics.csv'):
    """Score one preprocessed row with EnergyMPNN and apply the decision threshold.

    Args:
        df_row: preprocessed DataFrame accepted by ``preprocess_single_row``.
        model_path: local checkpoint (state dict). It is used when the file exists;
            otherwise ``best_model.pth`` is downloaded from the Hub repo *repo_id*.
        scaler_user, scaler_topic, scaler_edge: optional fitted scalers forwarded
            to ``preprocess_single_row``.
        repo_id: Hugging Face Hub repo used when *model_path* does not exist.
        metrics_path: CSV whose ``threshold`` column holds the decision threshold;
            0.5 is used if it is missing or unreadable.

    Returns:
        ``(prob, pred)``: *prob* is ``sigmoid`` of the model's user score as a
        Python float (assumes a single user score — num_users is 1);
        *pred* is 1 when ``prob > threshold`` else 0.

    Raises:
        RuntimeError: if preprocessing, model loading, or the forward pass fails.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    try:
        data = preprocess_single_row(df_row, scaler_user, scaler_topic, scaler_edge).to(device)
    except Exception as e:
        raise RuntimeError(f"Preprocessing failed: {str(e)}") from e

    model = EnergyMPNN(
        input_node_dim=6,
        edge_dim=5,
        hidden_dim=64,
        pos_dim=3,
        num_layers=2,
        dropout=0.2
    ).to(device)

    use_local = bool(model_path) and os.path.exists(model_path)
    source = model_path if use_local else f"{repo_id}/best_model.pth"
    try:
        weights_path = model_path if use_local else hf_hub_download(repo_id=repo_id, filename="best_model.pth")
        # NOTE(review): torch.load unpickles the file, which can run arbitrary code;
        # the checkpoint may come from a remote Hub repo, so that repo must be
        # trusted (consider weights_only=True if the installed torch supports it).
        model.load_state_dict(torch.load(weights_path, map_location=device))
    except Exception as e:
        raise RuntimeError(f"Failed to load model from {source}: {str(e)}") from e
    model.eval()

    try:
        with torch.no_grad():
            user_scores, _ = model(data.x, data.edge_index, data.edge_attr, data.pos, data.num_users)
        prob = torch.sigmoid(user_scores).item()
        threshold = _load_threshold(metrics_path)
        pred = 1 if prob > threshold else 0
    except Exception as e:
        raise RuntimeError(f"Prediction failed: {str(e)}") from e
    return prob, pred
def _is_blank(value):
    """Return True if a form value should be replaced by its default.

    Blank means None, an empty or whitespace-only string, or a float NaN.
    """
    if value is None:
        return True
    if isinstance(value, str):
        return not value.strip()
    return isinstance(value, float) and np.isnan(value)


def _type_error(key, value, expected_type):
    """Return an error message if *value* is not usable as *expected_type*, else None.

    Digit strings are accepted for int fields (Gradio dropdowns yield "0", "1", ...);
    numeric strings are accepted for float fields.
    """
    try:
        if expected_type == str:
            if not isinstance(value, str):
                return f"'{key}' has value '{value}', expected string"
        elif expected_type == int:
            if isinstance(value, str) and value.isdigit():
                int(value)  # some isdigit() strings (e.g. "²") still fail int()
            elif not isinstance(value, (int, float)) or (isinstance(value, float) and not value.is_integer()):
                return f"'{key}' has value '{value}', expected integer"
        elif expected_type == float:
            if isinstance(value, str):
                float(value)
            elif not isinstance(value, (int, float)):
                return f"'{key}' has value '{value}', expected float"
        elif expected_type == bool:
            if not isinstance(value, bool):
                return f"'{key}' has value '{value}', expected boolean"
    except (ValueError, TypeError):
        return f"'{key}' has value '{value}', expected {expected_type.__name__}"
    return None


# Gradio prediction function
def predict_fake_user(
    sec_id, create_time, height, width, ratio, duration, digg_count,
    share_count, music_count, play_count, comment_count, forward_count,
    download_count, desc, title, share_title, favoriting_count,
    follower_count, following_count, gender, has_email, is_mute,
    language, mention_status, user_rate, aweme_count, birthday,
    friends_status, signature, total_favorited, id_str, topic
):
    """Validate the form inputs, fill in defaults, and run the fake-user prediction.

    Each argument is the raw value of the Gradio component with the same name.
    Blank values (None, empty or whitespace-only strings, NaN) are replaced with
    ``default_values[key]``. The remaining values are type-checked against
    ``expected_types`` and then cast to those types.

    Returns:
        A 3-tuple for the (result markdown, probability, warnings markdown) outputs:
        ``(message, None, None)`` on validation/cast errors,
        ``(message, None, warnings)`` on preprocessing/prediction errors, and
        ``(result, prob, warnings)`` on success.
    """
    inputs = {
        'sec_id': sec_id,
        'create_time': create_time,
        'height': height,
        'width': width,
        'ratio': ratio,
        'duration': duration,
        'digg_count': digg_count,
        'share_count': share_count,
        'music_count': music_count,
        'play_count': play_count,
        'comment_count': comment_count,
        'forward_count': forward_count,
        'download_count': download_count,
        'desc': desc,
        'title': title,
        'share_title': share_title,
        'favoriting_count': favoriting_count,
        'follower_count': follower_count,
        'following_count': following_count,
        'gender': gender,
        'has_email': has_email,
        'is_mute': is_mute,
        'language': language,
        'mention_status': mention_status,
        'user_rate': user_rate,
        'aweme_count': aweme_count,
        'birthday': birthday,
        'friends_status': friends_status,
        'signature': signature,
        'total_favorited': total_favorited,
        'id_str': id_str,
        'topic': topic
    }

    input_dict = {}
    default_used = []
    type_errors = []
    for key, value in inputs.items():
        if _is_blank(value):
            input_dict[key] = default_values[key]
            default_used.append(key)
            continue
        input_dict[key] = value
        error = _type_error(key, value, expected_types[key])
        if error is not None:
            type_errors.append(error)

    if type_errors:
        error_msg = "Input errors:\n" + "\n".join(f"- {err}" for err in type_errors)
        logger.error(error_msg)
        return error_msg, None, None

    # Cast numeric/boolean fields to their schema type; string fields stay as entered.
    for key in input_dict:
        target_type = expected_types[key]
        if target_type is str:
            continue
        try:
            input_dict[key] = target_type(input_dict[key])
        except (ValueError, TypeError) as e:
            # Name the field that actually failed to cast.
            error_msg = f"Input error: Failed to cast {key}: {str(e)}"
            logger.error(error_msg)
            return error_msg, None, None

    warnings = []
    if default_used:
        warnings.append(f"Using default values for: {', '.join(default_used)}")
    warnings.append("*Note*: Prediction may be less accurate because scaler parameters from training are not available.")
    warning_msg = "\n\n".join(warnings)

    try:
        df_ = pd.DataFrame([input_dict])
        logger.info("TEST DATASET")
        logger.info(f"\n{df_}")
        df_row = Preprocessor(df_).run_pipeline()
    except Exception as e:
        logger.error(f"Preprocessing error: {str(e)}")
        return f"Preprocessing error: {str(e)}", None, warning_msg

    try:
        prob, pred = predict_single_row(df_row)
    except Exception as e:
        logger.error(f"Prediction error: {str(e)}")
        return f"Prediction error: {str(e)}", None, warning_msg

    label = 'Fake' if pred == 1 else 'Not Fake'
    result = f"**Probability of being fake: {prob:.4f}**\n\n**Predicted Class: {label}**"
    logger.info(f"Prediction successful: Probability={prob:.4f}, Class={label}")
    return result, prob, warning_msg
# Gradio interface
def _number_input(key, label):
    """Create an empty gr.Number whose info text shows the default used for *key* when left blank."""
    return gr.Number(label=label, value=None, info=f"Default: {default_values[key]}")


def _code_dropdown(key, label, choices):
    """Create an unset gr.Dropdown of string codes whose info text shows the default for *key*."""
    return gr.Dropdown(
        label=label,
        choices=choices,
        value=None,
        allow_custom_value=False,
        info=f"Default: {default_values[key]}",
    )


with gr.Blocks(title="Fake User Predictor") as demo:
    gr.Markdown("# Fake User Predictor")
    gr.Markdown("Enter user and post details to predict if the user is fake. Leave fields blank to use default values. Note: Scaler parameters are not available, which may affect accuracy.")
    with gr.Row():
        with gr.Column():
            sec_id = gr.Textbox(label="User ID (sec_id)", placeholder=default_values['sec_id'])
            create_time = _number_input('create_time', "Create Time (Unix timestamp)")
            height = _number_input('height', "Video Height")
            width = _number_input('width', "Video Width")
            ratio = gr.Textbox(label="Video Ratio", placeholder=default_values['ratio'])
            duration = _number_input('duration', "Video Duration (seconds)")
            digg_count = _number_input('digg_count', "Digg Count")
            share_count = _number_input('share_count', "Share Count")
            music_count = _number_input('music_count', "Music Count")
            play_count = _number_input('play_count', "Play Count")
            comment_count = _number_input('comment_count', "Comment Count")
            forward_count = _number_input('forward_count', "Forward Count")
            download_count = _number_input('download_count', "Download Count")
            desc = gr.Textbox(label="Post Description", placeholder=default_values['desc'])
            title = gr.Textbox(label="Post Title", placeholder=default_values['title'])
            share_title = gr.Textbox(label="Share Title", placeholder=default_values['share_title'])
        with gr.Column():
            favoriting_count = _number_input('favoriting_count', "Favoriting Count")
            follower_count = _number_input('follower_count', "Follower Count")
            following_count = _number_input('following_count', "Following Count")
            gender = _code_dropdown('gender', "Gender", ["0", "1", "2"])
            # NOTE(review): a checkbox always yields a bool, so it is never treated as blank.
            has_email = gr.Checkbox(label="Has Email", value=False)
            is_mute = _code_dropdown('is_mute', "Is Mute", ["0", "1"])
            language = gr.Textbox(label="Language", placeholder=default_values['language'])
            mention_status = _code_dropdown('mention_status', "Mention Status", ["0", "1"])
            user_rate = _number_input('user_rate', "User Rate")
            aweme_count = _number_input('aweme_count', "Post Count (aweme_count)")
            birthday = gr.Textbox(label="Birthday", placeholder=default_values['birthday'])
            friends_status = _code_dropdown('friends_status', "Friends Status", ["0", "1"])
            signature = gr.Textbox(label="Signature", placeholder=default_values['signature'])
            total_favorited = _number_input('total_favorited', "Total Favorited")
            id_str = _number_input('id_str', "ID String")
            topic = gr.Textbox(label="Topic", placeholder=default_values['topic'])
    predict_btn = gr.Button("Predict")
    output_text = gr.Markdown(label="Prediction Result")
    prob_output = gr.Number(label="Probability", visible=False)
    warning_output = gr.Markdown(label="Warnings")
    predict_btn.click(
        fn=predict_fake_user,
        inputs=[
            sec_id, create_time, height, width, ratio, duration, digg_count,
            share_count, music_count, play_count, comment_count, forward_count,
            download_count, desc, title, share_title, favoriting_count,
            follower_count, following_count, gender, has_email, is_mute,
            language, mention_status, user_rate, aweme_count, birthday,
            friends_status, signature, total_favorited, id_str, topic
        ],
        outputs=[output_text, prob_output, warning_output]
    )
def _launch():
    """Start the Gradio app, requesting a public share link (share=True)."""
    demo.launch(share=True)


# Launch the interface when run as a script.
if __name__ == "__main__":
    _launch()