amitbhatt6075 commited on
Commit
da61bf8
Β·
1 Parent(s): fd09d9b

feat: Implement real data sources for Pulse page

Browse files
Files changed (2) hide show
  1. core/thunderbird_engine.py +59 -69
  2. requirements.txt +0 -0
core/thunderbird_engine.py CHANGED
@@ -5,122 +5,112 @@ import json
5
  from datetime import datetime
6
  from newsapi import NewsApiClient
7
  from pytrends.request import TrendReq
 
8
  from typing import Dict, Any, Optional
9
- from core.utils import get_supabase_client # Assuming this helper exists
10
 
11
  # --- CONFIGURATION ---
12
  MODEL_PATH = os.path.join(os.path.dirname(__file__), '..', 'models', 'thunderbird_market_predictor_v1.joblib')
13
  NEWS_API_KEY = os.getenv("NEWS_API_KEY")
14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  def get_external_trends() -> dict:
16
- """
17
- Attempts to fetch REAL data. Returns None for fields where data is unavailable.
18
- """
19
  print("πŸš€ [Thunderbird Engine] Fetching REAL external trends...")
20
  results = { "news_headlines": [], "breakout_keyword": None, "trending_audio": None, "platform_shifts": None }
21
 
22
- # NewsAPI
23
  if NEWS_API_KEY:
24
  try:
25
  newsapi = NewsApiClient(api_key=NEWS_API_KEY)
26
  top_headlines = newsapi.get_everything(q='("influencer marketing")', language='en', sort_by='relevancy', page_size=5)
27
  if top_headlines.get('articles'):
28
  results["news_headlines"] = [{"title": a['title'], "url": a['url']} for a in top_headlines['articles']]
29
- except Exception: pass
30
 
31
- # Google Trends for Breakout Keyword
32
  try:
33
  pytrends = TrendReq(hl='en-US', tz=360)
34
  trending_df = pytrends.trending_searches(pn='united_states')
35
  if not trending_df.empty:
36
  results["breakout_keyword"] = trending_df.iloc[0,0]
37
- except Exception: pass
38
-
39
- # Platform Shifts (Real DB Call)
40
- results["platform_shifts"] = get_platform_shifts()
41
 
 
 
42
  return results
43
 
44
- def get_platform_shifts() -> Dict[str, str]:
45
- """
46
- Calculates REAL 7-day shift in platform usage from Supabase.
47
- """
48
- print(" - Calculating REAL platform shifts from DB...")
49
- shifts = { "instagram_shift": "0%", "tiktok_shift": "0%", "youtube_shift": "0%" }
50
- try:
51
- supabase = get_supabase_client()
52
- response = supabase.rpc('get_platform_trend_data_last_14_days').execute()
53
- if not response.data: return shifts
54
-
55
- df = pd.DataFrame(response.data); df['date'] = pd.to_datetime(df['date'])
56
- seven_days_ago = datetime.now() - pd.Timedelta(days=7)
57
- recent = df[df['date'] >= seven_days_ago]; prev = df[df['date'] < seven_days_ago]
58
- if prev.empty or recent.empty: return shifts
59
-
60
- avg_recent = recent.groupby('platform')['usage_count'].mean()
61
- avg_prev = prev.groupby('platform')['usage_count'].mean()
62
-
63
- for p in ['instagram', 'tiktok', 'youtube']:
64
- if p in avg_recent and p in avg_prev and avg_prev[p] > 0:
65
- change = ((avg_recent[p] - avg_prev[p]) / avg_prev[p]) * 100
66
- shifts[f"{p}_shift"] = f"{'+' if change > 0 else ''}{round(change)}%"
67
- return shifts
68
- except Exception:
69
- return shifts
70
-
71
  def predict_niche_trends() -> dict:
72
- """
73
- Uses the REAL trained ML pipeline to predict future interest.
74
- Handles feature name mismatch error.
75
- """
76
  print("\nπŸš€ [Thunderbird Engine] Using REAL ML pipeline for predictions...")
77
-
78
  try:
79
- # 1. Load the ENTIRE pipeline (preprocessor + model)
80
  pipeline = joblib.load(MODEL_PATH)
81
-
82
- # 2. Get niche names from the encoder inside the pipeline
83
  encoder = pipeline.named_steps['preprocessor'].named_transformers_['cat']
84
  all_niches = [cat.replace('niche_', '') for cat in encoder.get_feature_names_out(['niche'])]
85
-
86
- # 3. Prepare future dates for prediction
87
- future_dates = pd.to_datetime(pd.date_range(start=datetime.now(), periods=12, freq='M'))
88
-
89
  predictions = {}
90
  for niche in all_niches:
91
- # 4. Create a DataFrame WITH THE ORIGINAL FEATURE NAMES ('niche', 'month_of_year')
92
  future_df = pd.DataFrame({
93
- 'month_of_year': future_dates.month,
94
- 'niche': [niche] * 12,
95
- 'trend_score': 50 # Assume an average trend score for future prediction
96
  })
97
-
98
- # 5. Use pipeline.predict(). It will handle the one-hot encoding internally.
99
- predicted_values = pipeline.predict(future_df)
100
-
101
- predictions[niche] = [
102
- {"date": dt.strftime('%Y-%m'), "value": max(0, int(val))}
103
- for dt, val in zip(future_dates, predicted_values)
104
- ]
105
-
106
- print(f" - βœ… Successfully generated REAL predictions for niches: {list(predictions.keys())}")
107
  return {"trend_predictions": predictions}
108
-
109
  except Exception as e:
110
  print(f" - ❌ REAL Prediction Failed: {e}. Chart will be empty.")
111
  return {"trend_predictions": {}}
112
 
113
- # --- LLM FUNCTION (No changes needed here) ---
114
  def decode_market_trend(topic: str, llm_instance) -> Dict[str, str]:
115
- print(f"🧠 [Thunderbird] Decoding Trend with REAL AI: {topic}")
116
- offline_response = {"summary": "AI Analyst is currently offline.", "impact": "Could not get real-time analysis.", "strategy": "Please try again in a few moments."}
117
  if not llm_instance: return offline_response
118
  today_date = datetime.now().strftime("%Y-%m-%d")
119
  prompt = f"[INST]You are PulseAI, a Strategy Director. Today is {today_date}. Analyze trend: \"{topic}\". Provide JSON with keys: \"summary\", \"impact\", \"strategy\".[/INST]"
120
  try:
121
  response = llm_instance(prompt, max_tokens=256, temperature=0.6, echo=False)
122
  text = response['choices'][0]['text'].strip()
123
- start = text.find('{'); end = text.rfind('}') + 1
124
  if start != -1 and end != 0: return json.loads(text[start:end])
125
  else: raise ValueError("Invalid JSON from LLM")
126
  except Exception as e:
 
5
  from datetime import datetime
6
  from newsapi import NewsApiClient
7
  from pytrends.request import TrendReq
8
+ from tiktok_scraper_without_watermark.scraper import Scraper
9
  from typing import Dict, Any, Optional
10
+ from core.utils import get_supabase_client
11
 
12
  # --- CONFIGURATION ---
13
  MODEL_PATH = os.path.join(os.path.dirname(__file__), '..', 'models', 'thunderbird_market_predictor_v1.joblib')
14
  NEWS_API_KEY = os.getenv("NEWS_API_KEY")
15
 
16
+ def get_platform_shifts() -> Optional[Dict[str, str]]:
17
+ """Calculates REAL 7-day shift from Supabase."""
18
+ print(" - Calculating REAL platform shifts from DB...")
19
+ try:
20
+ supabase = get_supabase_client()
21
+ response = supabase.rpc('get_platform_trend_data_last_14_days').execute()
22
+ if not response.data or len(response.data) < 2: return None
23
+
24
+ df = pd.DataFrame(response.data); df['date'] = pd.to_datetime(df['date'])
25
+ seven_days_ago = datetime.now() - pd.Timedelta(days=7)
26
+ recent = df[df['date'] >= seven_days_ago]; prev = df[df['date'] < seven_days_ago]
27
+ if prev.empty or recent.empty: return None
28
+
29
+ avg_recent = recent.groupby('platform')['usage_count'].mean()
30
+ avg_prev = prev.groupby('platform')['usage_count'].mean()
31
+ shifts = {}
32
+ for p in ['instagram', 'tiktok', 'youtube']:
33
+ if p in avg_recent and p in avg_prev and avg_prev[p] > 0:
34
+ change = ((avg_recent[p] - avg_prev[p]) / avg_prev[p]) * 100
35
+ shifts[f"{p}_shift"] = f"{'+' if change > 0 else ''}{round(change)}%"
36
+ return shifts
37
+ except Exception as e:
38
+ print(f" - ❌ DB Error calculating shifts: {e}")
39
+ return None
40
+
41
+ def get_trending_audio_from_tiktok() -> Optional[Dict[str, str]]:
42
+ """BEST EFFORT: Scrapes TikTok to find a trending audio."""
43
+ print(" - Attempting to scrape REAL trending audio from TikTok...")
44
+ try:
45
+ scraper = Scraper()
46
+ trending_posts = scraper.trend(count=5)
47
+ for post in trending_posts:
48
+ if post.get('music'):
49
+ music = post['music']
50
+ return {
51
+ "name": f"{music.get('title', 'Unknown')} - {music.get('author', 'Unknown')}",
52
+ "cover_art_url": music.get('cover', 'https://via.placeholder.com/150')
53
+ }
54
+ return None
55
+ except Exception as e:
56
+ print(f" - ⚠️ TikTok scraping failed: {e}")
57
+ return None
58
+
59
  def get_external_trends() -> dict:
60
+ """Fetches REAL data only."""
 
 
61
  print("πŸš€ [Thunderbird Engine] Fetching REAL external trends...")
62
  results = { "news_headlines": [], "breakout_keyword": None, "trending_audio": None, "platform_shifts": None }
63
 
 
64
  if NEWS_API_KEY:
65
  try:
66
  newsapi = NewsApiClient(api_key=NEWS_API_KEY)
67
  top_headlines = newsapi.get_everything(q='("influencer marketing")', language='en', sort_by='relevancy', page_size=5)
68
  if top_headlines.get('articles'):
69
  results["news_headlines"] = [{"title": a['title'], "url": a['url']} for a in top_headlines['articles']]
70
+ except: pass
71
 
 
72
  try:
73
  pytrends = TrendReq(hl='en-US', tz=360)
74
  trending_df = pytrends.trending_searches(pn='united_states')
75
  if not trending_df.empty:
76
  results["breakout_keyword"] = trending_df.iloc[0,0]
77
+ except Exception as e:
78
+ print(f" - ⚠️ Google Trends failed: {e}")
 
 
79
 
80
+ results["trending_audio"] = get_trending_audio_from_tiktok()
81
+ results["platform_shifts"] = get_platform_shifts()
82
  return results
83
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
  def predict_niche_trends() -> dict:
85
+ """Uses the REAL trained ML pipeline to predict future interest."""
 
 
 
86
  print("\nπŸš€ [Thunderbird Engine] Using REAL ML pipeline for predictions...")
 
87
  try:
 
88
  pipeline = joblib.load(MODEL_PATH)
 
 
89
  encoder = pipeline.named_steps['preprocessor'].named_transformers_['cat']
90
  all_niches = [cat.replace('niche_', '') for cat in encoder.get_feature_names_out(['niche'])]
91
+ future_dates = pd.to_datetime(pd.date_range(start=datetime.now(), periods=12, freq='ME')) # Fixed 'M' to 'ME'
 
 
 
92
  predictions = {}
93
  for niche in all_niches:
 
94
  future_df = pd.DataFrame({
95
+ 'month_of_year': future_dates.month, 'niche': [niche] * 12, 'trend_score': 50
 
 
96
  })
97
+ predicted_values = pipeline.predict(future_df[['niche', 'trend_score', 'month_of_year']])
98
+ predictions[niche] = [{"date": dt.strftime('%Y-%m'), "value": max(0, int(val))} for dt, val in zip(future_dates, predicted_values)]
 
 
 
 
 
 
 
 
99
  return {"trend_predictions": predictions}
 
100
  except Exception as e:
101
  print(f" - ❌ REAL Prediction Failed: {e}. Chart will be empty.")
102
  return {"trend_predictions": {}}
103
 
 
104
  def decode_market_trend(topic: str, llm_instance) -> Dict[str, str]:
105
+ """Decodes a keyword into a strategy with a clear failure message."""
106
+ offline_response = {"summary": "AI Analyst is currently offline.", "impact": "Could not get real-time analysis.", "strategy": "Please try again later."}
107
  if not llm_instance: return offline_response
108
  today_date = datetime.now().strftime("%Y-%m-%d")
109
  prompt = f"[INST]You are PulseAI, a Strategy Director. Today is {today_date}. Analyze trend: \"{topic}\". Provide JSON with keys: \"summary\", \"impact\", \"strategy\".[/INST]"
110
  try:
111
  response = llm_instance(prompt, max_tokens=256, temperature=0.6, echo=False)
112
  text = response['choices'][0]['text'].strip()
113
+ start, end = text.find('{'), text.rfind('}') + 1
114
  if start != -1 and end != 0: return json.loads(text[start:end])
115
  else: raise ValueError("Invalid JSON from LLM")
116
  except Exception as e:
requirements.txt CHANGED
Binary files a/requirements.txt and b/requirements.txt differ