| |
| import gradio as gr |
| import asyncio |
| import os |
| import subprocess |
| import sys |
| os.environ["GRADIO_SSR_MODE"] = "False" |
|
|
| import pandas as pd |
| from datetime import datetime, timedelta |
| from pathlib import Path |
|
|
| from scraper import NewsScraper |
| from extractor import ContentExtractor |
| from features import Features, DataPipeline |
| from model import StockNewsModel |
| import threading |
| import time |
| import traceback |
| try: |
| from zoneinfo import ZoneInfo |
| except ImportError: |
| from backports.zoneinfo import ZoneInfo |
|
|
|
|
def _env_flag(name: str, default: str = "0") -> bool:
    """Interpret environment variable *name* as a boolean switch.

    Args:
        name: Environment variable to read.
        default: Raw string used when the variable is unset.

    Returns:
        True when the (whitespace-stripped, case-folded) value is one of
        "1", "true", "yes" or "on"; False otherwise.
    """
    truthy_values = ("1", "true", "yes", "on")
    raw_value = os.getenv(name, default)
    return raw_value.strip().lower() in truthy_values
|
|
|
|
| |
# Paths are resolved relative to this file, not the current working directory.
APP_DIR = Path(__file__).resolve().parent
MODEL_PATH = APP_DIR.joinpath("NSEI_model.pkl")  # artifact read by StockNewsModel.load
TRAIN_SCRIPT = APP_DIR.joinpath("train.py")  # run in a subprocess to (re)train the model
MODEL_LOCK = threading.Lock()  # guards reads/writes of the module-global `model`
LAST_RETRAIN_DATE = None  # IST date of the last successful scheduled retrain
BOOTSTRAP_STARTED = False  # True while a bootstrap training thread is in flight
|
|
|
|
def _load_model_from_disk():
    """Load the model stored at MODEL_PATH into the module-global ``model``.

    Returns:
        True if ``StockNewsModel.load`` succeeded and the global now holds the
        loaded model; False if loading raised, in which case the global is
        reset to None so callers see "no model" rather than a stale one.
    """
    global model
    try:
        loaded = StockNewsModel.load(str(MODEL_PATH))
    except Exception as exc:
        # Surface the cause: a missing file and a corrupt/incompatible
        # artifact need very different fixes.
        print(
            f"Warning: Could not load model from {MODEL_PATH} ({exc!r}). "
            "Models need to be trained first."
        )
        with MODEL_LOCK:
            model = None
        return False

    with MODEL_LOCK:
        model = loaded
    print(f"Loaded pre-trained model from {MODEL_PATH}")
    return True
|
|
|
|
def _run_training_job():
    """Run TRAIN_SCRIPT in a subprocess, then reload the model it wrote.

    Returns:
        True when the training process exited with status 0 and the new model
        loaded from disk; False when the process could not be started, exited
        non-zero, or the resulting model failed to load. Training failures are
        logged rather than raised so callers can rely on the boolean result.
    """
    print(f"Starting scheduled retrain at {datetime.now(IST).isoformat()}")
    try:
        subprocess.run(
            [sys.executable, str(TRAIN_SCRIPT)],
            cwd=str(APP_DIR),
            check=True,
        )
    except (subprocess.CalledProcessError, OSError) as exc:
        # CalledProcessError: non-zero exit; OSError: interpreter/cwd missing.
        print(f"Retrain failed: {exc}")
        return False

    if not _load_model_from_disk():
        return False
    print("Reloaded retrained model successfully.")
    return True
|
|
|
|
# Try to load a pre-trained model at import time. On failure the global
# `model` is None and get_predictions() will start a bootstrap training run.
_load_model_from_disk()
|
|
def _scalar_or_default(value, default):
    """Return *value*, or *default* when it is None or a pandas missing scalar.

    ``DataFrame.to_dict('records')`` yields ``NaN`` for absent cells. ``NaN`` is
    truthy and stringifies to ``'nan'``, so it must be normalized before it
    reaches the JSON payload.
    """
    if value is None:
        return default
    if pd.api.types.is_scalar(value) and pd.isna(value):
        return default
    return value


def _format_article(index, row, extractor, ticker):
    """Convert one ranked article record into the JSON-friendly dict the UI shows."""
    import html

    title = str(_scalar_or_default(row.get('title'), ''))
    source_name = str(_scalar_or_default(row.get('source'), 'Unknown'))
    # Fall back to a generated image when the article did not provide one.
    image_url = _scalar_or_default(row.get('image'), '') or extractor._build_dynamic_image_url(
        title=title,
        source=source_name,
        ticker=ticker,
    )
    # The description is scraped third-party text that is wrapped in HTML below,
    # so escape it (after truncating, so no entity is cut in half).
    snippet = html.escape(str(_scalar_or_default(row.get('description'), ''))[:400])

    return {
        "id": index,
        "title": title,
        "source": source_name,
        "date": _scalar_or_default(row.get('pub_date'), ''),
        "url": _scalar_or_default(row.get('link'), ''),
        "image": image_url,
        "impact_score": round(float(_scalar_or_default(row.get('impact'), 0)), 3),
        "sentiment": round(float(_scalar_or_default(row.get('sent_combined'), 0)), 3),
        "content": f"<p>{snippet}</p>",
    }


async def fetch_and_predict(ticker="^NSEI", days_back=7):
    """Scrape recent news for *ticker*, score it with the model, and return top stories.

    Args:
        ticker: Symbol passed to the scraper, feature builder and price pipeline.
        days_back: How many days back to scrape news.

    Returns:
        On success, a list of up to 10 article dicts drawn from the 12
        highest-``impact`` articles. Otherwise, a dict with an ``"error"`` key
        (model not loaded) or a ``"message"`` key (no usable articles).
    """
    with MODEL_LOCK:
        current_model = model
    if not current_model:
        return {"error": "Model not loaded. Please train the model first."}

    scraper = NewsScraper(limit=450)
    extractor = ContentExtractor()
    features = Features(ticker)

    lookback = datetime.now() - timedelta(days=days_back)
    articles = await scraper.scrape(ticker, lookback)
    if not articles:
        return {"message": f"No recent news found for {ticker}."}

    df = pd.DataFrame(articles)
    df['content'] = df['description'].fillna('')
    df['ts'] = pd.to_datetime(df['timestamp'], errors='coerce', utc=True)
    # Reset the index after dropping rows: the model output is joined back to
    # `df` by index below, and gaps left by dropna would misalign the two.
    df = df.dropna(subset=['ts']).reset_index(drop=True)
    if df.empty:
        return {"message": "No valid timestamps found in articles."}
    df['date'] = df['ts'].dt.date

    # First pass: score every scraped article before any full-text extraction.
    df_init_feats = features.build(df)
    pipeline = DataPipeline(ticker, train_days=0, test_days=0)
    price_df = pipeline.get_prices(datetime.now() - timedelta(days=30))
    df_init_feats = pipeline._add_market_context(df_init_feats, price_df)

    # Zero-fill any feature the model expects but the builder did not emit.
    for col in current_model.feature_names:
        if col not in df_init_feats.columns:
            df_init_feats[col] = 0.0
    X_init = (
        df_init_feats[current_model.feature_names]
        .fillna(0)
        .replace([float('inf'), float('-inf')], 0)
    )
    init_results = current_model.predict_new(X_init)

    df_ranked = pd.concat([df, init_results], axis=1)
    df_ranked = df_ranked.sort_values(by='impact', ascending=False)

    # Only the highest-impact articles get (expensive) deep content extraction.
    survivors = df_ranked.head(12).to_dict('records')
    print(f"[Pipeline] High-impact filtering complete. Deep extracting {len(survivors)} survivor(s).")
    survivors = await extractor.extract_all(survivors)

    return [
        _format_article(i, row, extractor, ticker)
        for i, row in enumerate(survivors[:10])
    ]
|
|
| |
# Shared insights cache and scheduler state. CACHE_LOCK guards the cache and
# the two "thread started" flags below.
cached_headlines = None  # last payload stored by update_cache(), None until then
last_refresh_date = None  # IST date of the last cache update
CACHE_LOCK = threading.Lock()
IST = ZoneInfo("Asia/Kolkata")  # timezone used for scheduling and timestamps
_REFRESH_THREAD_STARTED = False  # set once the scheduler thread is running
_INITIAL_REFRESH_STARTED = False  # set while/after the first cache fill is launched
|
|
def update_cache(ticker="^NSEI"):
    """Fetch fresh insights for *ticker* and store them in the shared cache.

    The async pipeline is run to completion with ``asyncio.run``, so this must
    be called from a thread without a running event loop.

    Error payloads (a dict with an ``"error"`` key, e.g. "model not loaded")
    are not cached. Otherwise that error would be served as the day's result
    even after a model becomes available. Exceptions are logged, not raised.
    """
    global cached_headlines, last_refresh_date
    print(f"Fetching new daily market insights for {ticker}...")
    try:
        data = asyncio.run(fetch_and_predict(ticker, days_back=3))
    except Exception as e:
        print(f"Error fetching insights: {e}")
        traceback.print_exc()
        return

    if isinstance(data, dict) and "error" in data:
        print(f"Skipping cache update: {data['error']}")
        return

    with CACHE_LOCK:
        cached_headlines = data
        last_refresh_date = datetime.now(IST).date()
    print("Market insights cache successfully updated.")
|
|
|
|
def _refresh_cache_loop():
    """Scheduler loop: retrain once per weekday after 15:30 IST, then refresh the cache.

    Polls every 300 seconds. On a weekday at or after 15:30 IST, if no
    successful retrain has been recorded for today, it runs the training job.
    On success it records the date and refreshes the insights cache. Any
    exception is logged and the loop keeps going after the same pause.
    """
    global LAST_RETRAIN_DATE
    poll_seconds = 300
    while True:
        try:
            now = datetime.now(IST)
            is_weekday = now.weekday() < 5
            # Lexicographic tuple compare == "hour > 15 or (hour == 15 and minute >= 30)".
            after_close = (now.hour, now.minute) >= (15, 30)
            already_retrained = LAST_RETRAIN_DATE == now.date()
            if is_weekday and after_close and not already_retrained:
                if _run_training_job():
                    LAST_RETRAIN_DATE = now.date()
                    update_cache()
            time.sleep(poll_seconds)
        except Exception as e:
            print(f"Refresh loop error: {e}")
            time.sleep(poll_seconds)
|
|
|
|
def _ensure_refresh_thread_started():
    """Start the background scheduler thread (``_refresh_cache_loop``) at most once."""
    global _REFRESH_THREAD_STARTED
    if not _REFRESH_THREAD_STARTED:
        with CACHE_LOCK:
            # Re-check under the lock: another caller may have started it meanwhile.
            if not _REFRESH_THREAD_STARTED:
                scheduler = threading.Thread(
                    target=_refresh_cache_loop,
                    daemon=True,
                    name="insights-refresh-scheduler",
                )
                scheduler.start()
                _REFRESH_THREAD_STARTED = True
|
|
|
|
def _start_initial_refresh(ticker="^NSEI"):
    """Launch a one-off background cache fill for *ticker*, at most one at a time.

    If the cache is still empty when the worker finishes, the worker clears
    the started flag so a later call can retry.
    """
    global _INITIAL_REFRESH_STARTED
    if _INITIAL_REFRESH_STARTED:
        return  # fast path: skip the lock once a refresh has been launched
    with CACHE_LOCK:
        if _INITIAL_REFRESH_STARTED:
            return

        def _run_first_refresh():
            global _INITIAL_REFRESH_STARTED
            try:
                update_cache(ticker)
            finally:
                with CACHE_LOCK:
                    # Keep the flag on success; clear it so callers can retry otherwise.
                    if cached_headlines is None:
                        _INITIAL_REFRESH_STARTED = False

        refresher = threading.Thread(
            target=_run_first_refresh,
            daemon=True,
            name="insights-initial-refresh",
        )
        refresher.start()
        # CACHE_LOCK is still held here, so the worker's reset above cannot
        # run before this assignment.
        _INITIAL_REFRESH_STARTED = True
|
|
|
|
# Start the background retrain/cache-refresh scheduler once, at import time.
_ensure_refresh_thread_started()
|
|
|
|
def _refresh_loop():
    """Run the scheduler loop in the calling thread (thin wrapper around ``_refresh_cache_loop``).

    Blocks indefinitely; ``_refresh_cache_loop`` loops forever and logs ordinary exceptions.
    """
    return _refresh_cache_loop()
|
|
def get_predictions(ticker="^NSEI"):
    """Return the cached insights, or start background work and return a placeholder.

    Resolution order:
      1. Serve the cache if it is populated.
      2. If ``HF_EAGER_FIRST_REFRESH`` is truthy, refresh synchronously and
         serve the cache if that populated it.
      3. Otherwise, with no model loaded, start at most one bootstrap thread
         that trains a model and then fills the cache. With a model loaded,
         start the one-off initial cache refresh.

    Args:
        ticker: Symbol to generate insights for.

    Returns:
        The cached payload, or a one-element list holding a "check back" message.
    """
    global BOOTSTRAP_STARTED

    with CACHE_LOCK:
        cached = cached_headlines
    if cached is not None:
        return cached

    if _env_flag("HF_EAGER_FIRST_REFRESH"):
        update_cache(ticker)
        with CACHE_LOCK:
            cached = cached_headlines
        if cached is not None:
            return cached

    start_bootstrap = False
    with MODEL_LOCK:
        current_model = model
        if current_model is None and not BOOTSTRAP_STARTED:
            # Claim the flag before the thread exists. If it were set after
            # start(), a fast-failing worker could reset it first and leave it
            # stuck True, so bootstrap would never be retried.
            BOOTSTRAP_STARTED = True
            start_bootstrap = True

    if start_bootstrap:
        def _bootstrap_worker():
            global BOOTSTRAP_STARTED
            try:
                if _run_training_job():
                    update_cache(ticker)
            finally:
                with MODEL_LOCK:
                    BOOTSTRAP_STARTED = False

        threading.Thread(
            target=_bootstrap_worker,
            daemon=True,
            name="model-bootstrap",
        ).start()
    elif current_model is not None:
        _start_initial_refresh(ticker)

    return [{"message": "Generating insights for the day... Check back in a minute."}]
| |
|
|
def demo():
    """Build the Gradio UI: a ticker textbox, a fetch button, and a JSON result view.

    The button calls ``get_predictions`` and is also exposed as the ``predict`` API endpoint.
    """
    with gr.Blocks(title="Miscellaneous News Impact Analyzer") as ui:
        gr.Markdown("# Miscellaneous Model Backend")

        with gr.Row():
            ticker_box = gr.Textbox(label="Ticker Symbol", value="^NSEI")

        fetch_button = gr.Button("Fetch Latest Impactful News")
        results_view = gr.JSON(label="Top Articles")

        fetch_button.click(
            fn=get_predictions,
            inputs=[ticker_box],
            outputs=[results_view],
            api_name="predict",
        )

    return ui
|
|
# Built at import time so the UI object exists whether this module is run
# directly (see the __main__ guard) or imported.
app = demo()
|
|
if __name__ == "__main__":
    # Enable the request queue, then serve on all interfaces; the PORT
    # environment variable (default 7860) selects the port.
    queued_app = app.queue()
    listen_port = int(os.environ.get("PORT", "7860"))
    queued_app.launch(
        server_name="0.0.0.0",
        server_port=listen_port,
        ssr_mode=False,
        show_error=True,
    )
|
|