| |
import inspect
import logging
import re
import time
from collections import Counter
from functools import lru_cache
from io import BytesIO
from typing import Dict, List

import joblib
import pandas as pd
import pytesseract
import requests
import torch
from fastapi import Depends, FastAPI, HTTPException, status
from PIL import Image
from pydantic import BaseModel, constr
from sqlalchemy.orm import Session
from textblob import TextBlob
from torchvision import transforms
from transformers import ResNetForImageClassification

from utils.database import fetch_posts_from_db, get_db, init_db, save_to_db
from utils.instaloader_utils import fetch_competitors_posts, fetch_user_posts
|
|
| |
# Configure the root logger for the process: INFO and above, each record
# prefixed with its timestamp and level name.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
)

# The application object; the route handlers in this module register on it
# through their @app.post decorators.
app = FastAPI()

# Database initialisation runs as an import-time side effect.
init_db()

# Load the three persisted models, in this order, at import time. The
# /analyze-post handler calls predict / predict_proba on them.
# NOTE: joblib.load unpickles its input, which can execute arbitrary code;
# only point these paths at trusted, locally produced model files.
viral_model, engagement_model, promotion_model = (
    joblib.load(model_path)
    for model_path in (
        "models/viral_potential_model.pkl",
        "models/engagement_rate_model.pkl",
        "models/promotion_strategy_model.pkl",
    )
)
|
|
class UserRequest(BaseModel):
    """Request body naming the Instagram account an endpoint operates on."""

    # Account name; any string is accepted (no format validation is applied).
    username: str
|
|
class AnalyzePostRequest(BaseModel):
    """Request body describing a single post to be scored by /analyze-post."""

    # Caption text; its length and sentiment become model features.
    caption: str
    # All hashtags in one string; the analysis splits it on commas.
    hashtags: str
    # Location the post image is downloaded from.
    image_url: str
|
|
| |
# Module-level logger used by the request handlers. Root logging is configured
# once by the basicConfig call at the top of this module; calling basicConfig a
# second time would be a no-op because the root logger already has handlers.
logger = logging.getLogger(__name__)
|
|
| |
# Minimum gap, in seconds, between two accepted /fetch-posts calls. It is
# compared against time.time() differences and shared by every caller.
RATE_LIMIT_DELAY = 5
# time.time() of the last accepted /fetch-posts call; 0 means "none yet".
LAST_REQUEST_TIME = 0
|
|
|
|
@app.post("/fetch-posts")
async def fetch_posts(user: UserRequest):
    """
    Fetch posts from a given Instagram profile (public data only).

    The posts returned by ``fetch_user_posts`` are persisted with
    ``save_to_db`` and echoed back in the response.

    Rate limiting uses one process-wide window: a call made less than
    ``RATE_LIMIT_DELAY`` seconds after the previous accepted call is rejected,
    whichever username or client it comes from.

    Args:
        user: request body carrying the username.

    Returns:
        dict with ``status``, ``data`` (the fetched posts) and ``message``.

    Raises:
        HTTPException: 429 inside the rate-limit window, 404 when no posts are
            returned, 500 when saving fails or any unexpected error occurs.
    """
    global LAST_REQUEST_TIME

    username = user.username
    logger.info(f"Fetching posts for user: {username}")

    current_time = time.time()
    if current_time - LAST_REQUEST_TIME < RATE_LIMIT_DELAY:
        raise HTTPException(
            status_code=status.HTTP_429_TOO_MANY_REQUESTS,
            detail="Please wait a few seconds before making another request.",
        )
    # No await happens between the check above and this assignment, so on a
    # single event loop another request cannot interleave with the check-and-set.
    LAST_REQUEST_TIME = current_time

    try:
        # Resolve the result before testing it: if fetch_user_posts returns a
        # coroutine, the coroutine object itself is always truthy; if it returns
        # a list, awaiting that list would raise TypeError.
        all_posts = fetch_user_posts(username)
        if inspect.isawaitable(all_posts):
            all_posts = await all_posts
        # NOTE(review): if fetch_user_posts is synchronous it blocks the event
        # loop while it runs; consider fastapi.concurrency.run_in_threadpool.

        if not all_posts:
            logger.warning(f"No posts found for user: {username}")
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="No posts found for the user.",
            )

        if not save_to_db(all_posts):
            logger.error("Failed to save data to the database.")
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail="Failed to save data to the database.",
            )

        return {
            "status": "success",
            "data": all_posts,
            "message": f"Successfully fetched {len(all_posts)} posts.",
        }
    except HTTPException:
        # Deliberate HTTP errors raised above propagate unchanged.
        raise
    except Exception as e:
        logger.exception(f"Unexpected error fetching posts: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="An unexpected error occurred. Please try again later.",
        ) from e
|
|
@app.post("/analyze")
async def analyze(user: UserRequest, db: Session = Depends(get_db)):
    """
    Analyze user and competitor data.

    Reads the user's stored posts with ``fetch_posts_from_db``. Returns the
    per-post scores from ``predict_viral_potential``, the hashtags chosen by
    ``recommend_hashtags``, and the mean like and comment counts.

    Args:
        user: request body carrying the username.
        db: session injected by ``get_db``; not used by this handler.

    Returns:
        dict with ``status`` and ``results``.

    Raises:
        HTTPException: 404 when the user has no stored posts; 500 for any
            other failure, with the exception text as the detail.
    """
    username = user.username
    logger.info(f"Analyzing data for user: {username}")

    try:
        user_posts = fetch_posts_from_db(username)
        if not user_posts:
            raise HTTPException(status_code=404, detail="No posts found for the user.")

        post_count = len(user_posts)
        analysis_results = {
            "viral_potential": predict_viral_potential(user_posts),
            "top_hashtags": recommend_hashtags(user_posts),
            "engagement_stats": {
                "mean_likes": sum(post['likes'] for post in user_posts) / post_count,
                "mean_comments": sum(post['comments'] for post in user_posts) / post_count,
            },
        }

        return {"status": "success", "results": analysis_results}
    except HTTPException:
        # Let the 404 above through; the generic handler below would otherwise
        # turn it into a 500.
        raise
    except Exception as e:
        logger.exception(f"Error analyzing data: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
def _download_image(url, timeout=10.0):
    """
    Download ``url`` and open the payload as a PIL image.

    Args:
        url: client-supplied image location.
        timeout: seconds passed to ``requests.get``.

    Returns:
        The opened PIL image.

    Raises:
        HTTPException: 400 if the request fails, times out, returns an error
            status, or the payload is not an image PIL can identify.
    """
    # SECURITY NOTE(review): the URL comes from the client and is fetched by the
    # server, so it can target internal hosts (SSRF). Add a host allow-list
    # before exposing this endpoint publicly.
    try:
        # requests has no default timeout; without one a stalled host would
        # hang this request indefinitely.
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()
        # Image.open raises UnidentifiedImageError (an OSError subclass) when
        # it cannot identify the payload as an image.
        return Image.open(BytesIO(response.content))
    except (requests.RequestException, OSError) as e:
        logger.warning(f"Could not load image from {url}: {e}")
        raise HTTPException(status_code=400, detail=f"Could not load image: {e}") from e


def _count_hashtags(hashtags):
    """Count the non-blank entries of a comma-separated hashtag string."""
    # ''.split(',') is [''], and a trailing comma yields a trailing '', so
    # blank entries are skipped rather than counted as hashtags.
    return sum(1 for tag in hashtags.split(",") if tag.strip())


@app.post("/analyze-post")
async def analyze_post(post: AnalyzePostRequest, db: Session = Depends(get_db)):
    """
    Analyze a single post (caption, hashtags, and image).

    The steps are:

    1. Download the image.
    2. Run OCR and the image model on it.
    3. Score caption/hashtag features with the three loaded models.
    4. Store the result with ``save_to_db`` and return it.

    Args:
        post: caption, comma-separated hashtags and image URL.
        db: session injected by ``get_db``; not used by this handler.

    Returns:
        dict with ``extracted_text``, ``image_analysis``, ``viral_score``,
        ``engagement_rate`` and ``promote``.

    Raises:
        HTTPException: 400 when the image cannot be downloaded or opened;
            500 for any other failure.
    """
    # NOTE(review): download, OCR and inference are all blocking and run on the
    # event loop because this is an ``async def`` handler.
    try:
        image = _download_image(post.image_url)

        extracted_text = extract_text_from_image(image)
        image_analysis = analyze_image(image)

        features = {
            'caption_length': len(post.caption),
            # NOTE(review): confirm the models were trained with blank hashtag
            # entries excluded from this count.
            'hashtag_count': _count_hashtags(post.hashtags),
            'sentiment': TextBlob(post.caption).sentiment.polarity,
        }
        features_df = pd.DataFrame([features])

        # predict / predict_proba presumably return numpy arrays; coerce the
        # scalars to built-in types so the JSON response and the database layer
        # receive plain values.
        viral_score = float(viral_model.predict_proba(features_df)[0][1])
        engagement_rate = float(engagement_model.predict(features_df)[0])
        promote = bool(promotion_model.predict(features_df)[0])

        post_data = {
            "caption": post.caption,
            "hashtags": post.hashtags,
            "image_url": post.image_url,
            "engagement_rate": engagement_rate,
            "viral_score": viral_score,
            "promote": promote,
        }
        if not save_to_db([post_data]):
            # Persisting is best-effort here: the analysis is still returned,
            # but the failure is recorded instead of ignored.
            logger.warning("Failed to save analyzed post to the database.")

        return {
            "extracted_text": extracted_text,
            "image_analysis": image_analysis,
            "viral_score": viral_score,
            "engagement_rate": engagement_rate,
            "promote": promote,
        }
    except HTTPException:
        # The 400 from _download_image must not be rewritten into a 500.
        raise
    except Exception as e:
        logger.exception(f"Error analyzing post: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
| |
def resize_image(image, max_size=(800, 600)):
    """
    Return a copy of ``image`` shrunk to fit within ``max_size``.

    Uses PIL's ``thumbnail``, which keeps the aspect ratio and never enlarges.
    ``thumbnail`` works in place, so it is applied to a copy. This leaves the
    caller's image untouched; ``analyze_post`` passes the same image object to
    ``analyze_image`` after OCR.

    Args:
        image: a PIL image (anything with ``copy()`` and ``thumbnail()``).
        max_size: (width, height) bounding box.

    Returns:
        A new, resized image object.
    """
    resized = image.copy()
    resized.thumbnail(max_size)
    return resized
|
|
def extract_text_from_image(image):
    """
    Return the text Tesseract recognises in ``image``, or "" on any failure.

    The image is first passed through ``resize_image`` (bounded to 800x600 by
    default). OCR here is best-effort: every exception is logged and swallowed.
    """
    try:
        shrunk = resize_image(image)
        recognised = pytesseract.image_to_string(shrunk)
    except Exception as err:
        logging.error(f"Error extracting text from image: {err}")
        recognised = ""
    return recognised
|
|
@lru_cache(maxsize=1)
def _get_resnet_model():
    """Load the ResNet-50 classifier once, switch it to eval mode, and reuse it."""
    # lru_cache does not cache exceptions, so a failed load is retried next call.
    model = ResNetForImageClassification.from_pretrained("microsoft/resnet-50")
    model.eval()
    return model


@lru_cache(maxsize=1)
def _get_image_preprocess():
    """Build the resize / center-crop / normalise pipeline once."""
    return transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        # ImageNet channel mean/std; expects exactly three channels.
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])


def analyze_image(image):
    """
    Analyze image content using a pre-trained model.

    Runs ``microsoft/resnet-50`` on a 224x224 center crop of ``image`` and
    returns the raw classification logits. The model and the preprocessing
    pipeline are built on first use and cached, rather than reloaded per call.

    Args:
        image: a PIL image in any mode. It is converted to RGB first, because
            the normalisation step needs three channels (RGBA, L and P images
            would otherwise fail).

    Returns:
        The logits as nested lists (batch dimension of one first), or ``None``
        if loading the model or running inference fails.
    """
    try:
        image_tensor = _get_image_preprocess()(image.convert("RGB")).unsqueeze(0)
        model = _get_resnet_model()

        with torch.no_grad():
            output = model(image_tensor)
        return output.logits.tolist()
    except Exception as e:
        logger.exception(f"Error analyzing image: {e}")
        return None
|
|
| |
def predict_viral_potential(posts: List[Dict]) -> List[Dict]:
    """
    Predict viral potential for posts.

    Placeholder implementation: every post gets the same fixed score of 0.8,
    and the loaded ``viral_model`` is not consulted.

    Args:
        posts: post mappings; only the ``"caption"`` key is read.

    Returns:
        One ``{"caption": ..., "viral_score": 0.8}`` dict per input post, in order.
    """
    # TODO: replace the constant with a real model prediction.
    scored = []
    for entry in posts:
        scored.append({"caption": entry["caption"], "viral_score": 0.8})
    return scored
|
|
def recommend_hashtags(posts: List[Dict], top_n: int = 10) -> List[str]:
    """
    Recommend trending hashtags.

    Counts how often each hashtag appears across ``posts`` and returns the
    ``top_n`` most frequent ones, most frequent first. Ties keep first-seen
    order, as ``Counter.most_common`` does.

    A post's ``"hashtags"`` value may be an iterable of tags or a single
    string of tags separated by commas and/or whitespace; ``analyze_post``
    stores the raw comma-separated string. Iterating such a string directly
    would count single characters, so strings are split first. Blank tags,
    and posts whose ``"hashtags"`` is missing or empty, are skipped.

    Args:
        posts: post mappings.
        top_n: maximum number of hashtags to return (default 10; ``None``
            returns all of them).

    Returns:
        Up to ``top_n`` hashtag strings, stripped of surrounding whitespace.
    """
    counts = Counter()
    for post in posts:
        raw = post.get("hashtags") or []
        tags = re.split(r"[,\s]+", raw) if isinstance(raw, str) else raw
        counts.update(tag for tag in (str(t).strip() for t in tags) if tag)
    return [tag for tag, _ in counts.most_common(top_n)]
|
|
| |
if __name__ == "__main__":
    # Direct-run entry point: serve the app with uvicorn on every network
    # interface (0.0.0.0), port 8000.
    from uvicorn import run as serve_app

    serve_app(app, host="0.0.0.0", port=8000)