# Insta-AI / app.py
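"""Insta-AI engagement pipeline.

Loads the company's and competitors' Instagram post data from JSON, extracts
text from post images via OCR (with an on-disk cache), engineers caption,
hashtag, and sentiment features, scores posts, and trains viral-potential,
engagement-rate, and promotion models.
"""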
import os
import pandas as pd
import numpy as np
import json
import logging
import re
import requests
from io import BytesIO
from PIL import Image
import pytesseract
from textblob import TextBlob
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
from xgboost import XGBRegressor
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import mean_absolute_error, accuracy_score
from sklearn.preprocessing import LabelEncoder
import torch
from torchvision import models, transforms
import matplotlib.pyplot as plt
import seaborn as sns
from collections import Counter
from torchvision.models import ResNet50_Weights
import pickle
# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# Set the working directory to a writable location
WORKING_DIR = "/app" # Use /app or /tmp
if not os.path.exists(WORKING_DIR):
    os.makedirs(WORKING_DIR)
os.chdir(WORKING_DIR)
# Verify the current directory
logging.info(f"Current working directory: {os.getcwd()}")
# Cache file to store extracted text
CACHE_FILE = os.path.join(WORKING_DIR, "image_text_cache.pkl")
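# The cache maps each image URL to its extracted OCR text, so repeated runs
# skip re-downloading and re-processing images already seen.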
# Load cache if it exists
if os.path.exists(CACHE_FILE):
    with open(CACHE_FILE, "rb") as f:
        cache = pickle.load(f)
else:
    cache = {}
# Function to resize images
def resize_image(image, max_size=(800, 600)):
    """Resize an image in place to fit within max_size, preserving aspect ratio."""
    image.thumbnail(max_size)
    return image
# Function to extract text from an image
def extract_text_from_image(image_url):
    """Extract text from an image using OCR (requires the Tesseract binary)."""
    if image_url in cache:
        return cache[image_url]  # Return cached text if available
    if not image_url or not isinstance(image_url, str) or not image_url.startswith(('http://', 'https://')):
        return ""  # Skip invalid URLs
    try:
        response = requests.get(image_url, timeout=10)  # Timeout avoids hanging on slow hosts
        response.raise_for_status()  # Raise an error for bad responses (4xx, 5xx)
        image = Image.open(BytesIO(response.content))
        image = resize_image(image)  # Resize the image before OCR
        text = pytesseract.image_to_string(image)
        cache[image_url] = text  # Cache the extracted text
        return text
    except Exception as e:
        logging.error(f"Error extracting text from image: {e}")
        return ""
# Function to process new images
def process_new_images(new_images):
    """Process new images and return a DataFrame with extracted text."""
    new_images['image_text'] = new_images['image_url'].apply(extract_text_from_image)
    return new_images
# Load engagement_metrics.json (your company's data)
logging.info("Loading your company's engagement metrics...")
try:
    with open('engagement_metrics.json', 'r') as f:
        engagement_metrics = json.load(f)
    your_df = pd.json_normalize(engagement_metrics)
except FileNotFoundError:
    logging.error("engagement_metrics.json not found. Please ensure the file exists.")
    exit()
# Load solved.json (your company's hashtags and captions)
logging.info("Loading your company's solved data...")
try:
    with open('solved.json', 'r') as f:
        solved_data = json.load(f)
    your_solved_df = pd.json_normalize(solved_data)
except FileNotFoundError:
    logging.error("solved.json not found. Please ensure the file exists.")
    exit()
# Load competitor data from JSON
logging.info("Loading competitor data from JSON...")
try:
    with open('competitors_data.json', 'r') as f:
        competitor_data = json.load(f)
    # Combine all competitors' data into a single DataFrame
    competitor_dfs = []
    for competitor, posts in competitor_data.items():
        df = pd.json_normalize(posts)
        df['competitor'] = competitor  # Add competitor name for tracking
        competitor_dfs.append(df)
    competitor_df = pd.concat(competitor_dfs, ignore_index=True)
except FileNotFoundError:
    logging.error("competitors_data.json not found. Please ensure the file exists.")
    exit()
# Ensure required columns exist in your company's data
required_columns = ['likes', 'comments', 'shares', 'posting_time', 'caption', 'hashtags']
missing_columns = [col for col in required_columns if col not in your_df.columns]
if missing_columns:
    logging.warning(f"Missing required columns in your company's data: {missing_columns}")
    for col in missing_columns:
        if col in ['likes', 'comments', 'shares']:
            your_df[col] = 0  # Fill with default value (integer)
        elif col == 'caption':
            your_df[col] = ''  # Fill with default value (empty string)
        elif col == 'hashtags':
            your_df[col] = [[] for _ in range(len(your_df))]  # Fill with default value (list of empty lists)
        elif col == 'posting_time':
            your_df[col] = pd.NaT  # Fill with NaT; such rows are dropped after datetime parsing below
    logging.info("Default values added for missing columns.")
# Ensure required columns exist in competitor data
required_columns = ['caption', 'hashtags', 'likes', 'comments', 'date', 'image_url']
missing_columns = [col for col in required_columns if col not in competitor_df.columns]
if missing_columns:
    logging.warning(f"Missing required columns in competitor data: {missing_columns}")
    for col in missing_columns:
        if col == 'caption':
            competitor_df[col] = ''  # Fill with default value (empty string)
        elif col == 'hashtags':
            competitor_df[col] = [[] for _ in range(len(competitor_df))]  # Fill with default value (list of empty lists)
        elif col == 'image_url':
            competitor_df[col] = ''  # Fill with default value (empty string)
        else:
            competitor_df[col] = 0  # Fill with default value (integer)
    logging.info("Default values added for missing columns.")
# Process your company's data
logging.info("Processing your company's data...")
your_df['posting_time'] = pd.to_datetime(your_df['posting_time'], format='%Y-%m-%d %H:%M:%S', errors='coerce')
your_df = your_df[your_df['posting_time'].notna()]
your_df['engagement_rate'] = your_df['likes'] + your_df['comments'] + your_df['shares']
your_df['caption'] = your_df['caption'].fillna('')  # Guard against missing captions before length/sentiment features
your_df['caption_length'] = your_df['caption'].apply(len)
your_df['hashtag_count'] = your_df['hashtags'].apply(len)
your_df['caption_sentiment'] = your_df['caption'].apply(lambda x: TextBlob(x).sentiment.polarity)
your_df['sentiment'] = your_df['caption_sentiment']
# Process competitor data
logging.info("Processing competitor data...")
competitor_df['posting_time'] = pd.to_datetime(competitor_df['date'], format='%Y-%m-%d %H:%M:%S', errors='coerce')
competitor_df = competitor_df[competitor_df['posting_time'].notna()]
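# Competitor posts carry no 'shares' field, so this engagement definition
# omits shares and differs slightly from the one used for your_df above.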
competitor_df['engagement_rate'] = competitor_df['likes'] + competitor_df['comments']
competitor_df['caption'] = competitor_df['caption'].fillna('')  # Guard against missing captions before length/sentiment features
competitor_df['caption_length'] = competitor_df['caption'].apply(len)
competitor_df['hashtag_count'] = competitor_df['hashtags'].apply(len)
competitor_df['caption_sentiment'] = competitor_df['caption'].apply(lambda x: TextBlob(x).sentiment.polarity)
competitor_df['sentiment'] = competitor_df['caption_sentiment']
# Combine your company's data and competitor data for model training
logging.info("Combining your company's data and competitor data for model training...")
combined_df = pd.concat([your_df, competitor_df], ignore_index=True)
# Handle missing or invalid image URLs
combined_df['image_url'] = combined_df['image_url'].fillna('')
# Process only new images (those not in the cache)
logging.info("Extracting text from new images...")
new_images = combined_df[~combined_df['image_url'].isin(cache.keys())].copy()
process_new_images(new_images)  # Populates the cache for any URLs not seen before
# DataFrame.update() only modifies columns that already exist, so it cannot add
# the new 'image_text' column; map the refreshed cache onto every row instead.
combined_df['image_text'] = combined_df['image_url'].map(cache).fillna('')
# Save the updated dataset
combined_df.to_csv(os.path.join(WORKING_DIR, "data_with_extracted_text.csv"), index=False)
# Save the cache
with open(CACHE_FILE, "wb") as f:
    pickle.dump(cache, f)
logging.info("Incremental processing complete!")
def analyze_image(image_url):
    """Analyze image content using a pre-trained model."""
    if not image_url or not isinstance(image_url, str) or not image_url.startswith(('http://', 'https://')):
        return None  # Skip invalid URLs
    try:
        response = requests.get(image_url, timeout=10)
        response.raise_for_status()  # Raise an error for bad responses (4xx, 5xx)
        image = Image.open(BytesIO(response.content)).convert('RGB')  # Normalize expects 3 channels
        preprocess = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        image_tensor = preprocess(image).unsqueeze(0)
        # Load ResNet50 weights from the local cache (reloading per call is slow;
        # loading the model once at module level would be faster)
        weights_path = "/app/models/resnet50-0676ba61.pth"
        model = models.resnet50()
        model.load_state_dict(torch.load(weights_path))
        model.eval()
        with torch.no_grad():
            output = model(image_tensor)
        return output
    except Exception as e:
        logging.error(f"Error analyzing image: {e}")
        return None
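# analyze_image returns raw ImageNet logits (a [1, 1000] tensor); mapping them
# to labels or a visual-appeal score is not done here, and rate_post below
# falls back to a fixed placeholder for visual appeal.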
logging.info("Analyzing image content...")
combined_df['image_analysis'] = combined_df['image_url'].apply(analyze_image)
# Define scoring criteria for posts
def rate_post(row):
    """Rate a post based on visual appeal, text quality, and predictive metrics."""
    # Visual appeal (image clarity, colors, and elements)
    visual_appeal = 0.5  # Placeholder for image analysis score
    # Text quality (engagement-focused, sentiment-aligned, or informative)
    text_quality = 0.3 * row['caption_sentiment'] + 0.2 * len(row['caption'])
    # Predictive metrics (virality potential, relevance to trends)
    predictive_metrics = 0.2 * row['engagement_rate']
    # Combine factors into a weighted score
    score = 0.4 * visual_appeal + 0.3 * text_quality + 0.3 * predictive_metrics
    return score
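# Note: the terms above are not normalized, so raw caption length and
# engagement counts can dominate the weighted score; rescaling each component
# to [0, 1] before combining would make the weights meaningful.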
logging.info("Rating posts...")
combined_df['post_score'] = combined_df.apply(rate_post, axis=1)
# Log the results
logging.info("Post Ratings:")
print(combined_df[['post_id', 'post_score']].head())
# Train models using all features
logging.info("Training models using all features...")
# Features for model training. OCR text is free-form, so use its length as a
# simple numeric proxy: sklearn and XGBoost estimators require numeric feature
# matrices and would raise on a raw string column.
combined_df['image_text_length'] = combined_df['image_text'].apply(len)
features = [
    'caption_length', 'hashtag_count', 'sentiment', 'engagement_rate', 'image_text_length'
]
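# Caveat: 'engagement_rate' is in the feature list while every target below
# (viral, engagement_rate itself, promote) is derived from it, so the reported
# scores are inflated by target leakage.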
# Viral Potential Prediction
logging.info("Training viral potential prediction model...")
combined_viral_threshold = combined_df['engagement_rate'].quantile(0.9)
combined_df['viral'] = combined_df['engagement_rate'].apply(lambda x: 1 if x >= combined_viral_threshold else 0)
X = combined_df[features]
y = combined_df['viral']
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
viral_model = RandomForestClassifier(random_state=42)
viral_model.fit(X_train, y_train)
y_pred = viral_model.predict(X_test)
accuracy = accuracy_score(y_test, y_pred)
logging.info(f"Viral Potential Model Accuracy: {accuracy:.4f}")
# Engagement Rate Prediction
logging.info("Training engagement rate prediction model...")
X = combined_df[features]
y = combined_df['engagement_rate']
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
engagement_model = XGBRegressor(random_state=42)
engagement_model.fit(X_train, y_train)
y_pred = engagement_model.predict(X_test)
mae = mean_absolute_error(y_test, y_pred)
logging.info(f"Engagement Rate Prediction Model - MAE: {mae:.4f}")
# Promotion Strategy
logging.info("Training promotion prediction model...")
promotion_threshold = combined_df['engagement_rate'].quantile(0.8)
combined_df['promote'] = combined_df['engagement_rate'].apply(lambda x: 1 if x >= promotion_threshold else 0)
X = combined_df[features]
y = combined_df['promote']
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
promotion_model = LogisticRegression(random_state=42, max_iter=1000)  # Raised iteration cap helps convergence on unscaled features
promotion_model.fit(X_train, y_train)
y_pred = promotion_model.predict(X_test)
accuracy = accuracy_score(y_test, y_pred)
logging.info(f"Promotion Prediction Model Accuracy: {accuracy:.4f}")
logging.info("Analysis complete!")