Update app.py
Browse files
app.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
| 1 |
import pandas as pd
|
| 2 |
import numpy as np
|
|
|
|
| 3 |
from sklearn.model_selection import train_test_split
|
| 4 |
from sklearn.linear_model import LinearRegression, LogisticRegression
|
| 5 |
from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier
|
|
@@ -19,13 +20,30 @@ import logging
|
|
| 19 |
# Set up logging
|
| 20 |
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
| 21 |
|
| 22 |
-
# Load data
|
| 23 |
-
logging.info("Loading data...")
|
| 24 |
data = pd.read_csv('train_data.csv')
|
| 25 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
# Feature Engineering
|
| 27 |
logging.info("Performing feature engineering...")
|
| 28 |
-
data['posting_time_encoded'] = pd.to_datetime(data['posting_time']).astype(int) / 10**9
|
| 29 |
data['caption_length'] = data['caption'].apply(len)
|
| 30 |
data['hashtag_count'] = data['hashtags'].apply(lambda x: len(eval(x)))
|
| 31 |
data['viral'] = data['engagement_rate'].apply(lambda x: 1 if x > data['engagement_rate'].quantile(0.75) else 0)
|
|
|
|
| 1 |
import pandas as pd
|
| 2 |
import numpy as np
|
| 3 |
+
import json
|
| 4 |
from sklearn.model_selection import train_test_split
|
| 5 |
from sklearn.linear_model import LinearRegression, LogisticRegression
|
| 6 |
from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier
|
|
|
|
| 20 |
# Set up logging
|
| 21 |
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
| 22 |
|
| 23 |
+
# Load Instagram data
|
| 24 |
+
logging.info("Loading Instagram data...")
|
| 25 |
data = pd.read_csv('train_data.csv')
|
| 26 |
|
| 27 |
+
# Debug: Inspect the posting_time column
|
| 28 |
+
logging.info("Inspecting posting_time column...")
|
| 29 |
+
# Emit the preview through logging (not print) so it carries the timestamp/level format configured above.
logging.info("posting_time head:\n%s", data['posting_time'].head())
|
| 30 |
+
|
| 31 |
+
# Parse the posting_time column
|
| 32 |
+
logging.info("Parsing posting_time column...")
|
| 33 |
+
data['posting_time'] = pd.to_datetime(data['posting_time'], format='%Y-%m-%d %H:%M:%S', errors='coerce')
|
| 34 |
+
|
| 35 |
+
# Check for NaT values (invalid datetime entries)
|
| 36 |
+
if data['posting_time'].isna().any():
|
| 37 |
+
logging.warning(f"Found {data['posting_time'].isna().sum()} invalid datetime entries. They will be set to NaT.")
|
| 38 |
+
|
| 39 |
+
# Convert to Unix timestamp
|
| 40 |
+
logging.info("Converting posting_time to Unix timestamp...")
|
| 41 |
+
# Seconds since the Unix epoch, as float. Subtracting the epoch and dividing by a
# one-second Timedelta does not depend on the column's datetime resolution, and
# NaT rows (values that failed to parse with errors='coerce') become NaN instead
# of going through an integer cast.
# NOTE(review): downstream estimators may need these NaN rows dropped or imputed.
data['posting_time_encoded'] = (data['posting_time'] - pd.Timestamp("1970-01-01")) / pd.Timedelta(seconds=1)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
|
| 45 |
# Feature Engineering
|
| 46 |
logging.info("Performing feature engineering...")
|
|
|
|
| 47 |
# Character count of each caption. A missing caption (NaN after read_csv) is
# counted as length 0 rather than letting len() raise TypeError on a float.
data['caption_length'] = data['caption'].fillna('').astype(str).str.len()
|
| 48 |
# Number of hashtags per post: each cell is evaluated as a Python expression and its length taken.
# NOTE(review): eval() executes whatever text is in the CSV cell; if the column only
# holds list literals, ast.literal_eval would be the safe replacement — confirm the format.
data['hashtag_count'] = data['hashtags'].apply(lambda x: len(eval(x)))
|
| 49 |
# Label a post viral (1) when its engagement rate is strictly above the column's
# 75th percentile, else 0. The quantile is evaluated once for the whole column
# instead of once per row.
data['viral'] = (data['engagement_rate'] > data['engagement_rate'].quantile(0.75)).astype('int64')
|