Spaces:
Runtime error
Runtime error
| """Sentiment analysis screen — interactive FinBERT analysis per symbol.""" | |
| from __future__ import annotations | |
| import threading | |
| from textual.app import ComposeResult | |
| from textual.binding import Binding | |
| from textual.screen import Screen | |
| from textual.widgets import Header, Input, Label, DataTable, Static | |
| from textual.containers import Vertical | |
| from textual.reactive import reactive | |
| from textual import work | |
| from rich.text import Text | |
| from trading_cli.sentiment.aggregator import get_sentiment_summary | |
| from trading_cli.widgets.ordered_footer import OrderedFooter | |
class SentimentScoreDisplay(Static):
    """One-line summary of a symbol's aggregated headline sentiment.

    All fields are reactive, so assigning any of them triggers a repaint.
    Nothing is rendered until ``symbol`` is set.
    """

    # Colour used for each dominant label; any other value renders white.
    _DOMINANT_STYLES = {"POSITIVE": "green", "NEGATIVE": "red", "NEUTRAL": "yellow"}

    score: reactive[float] = reactive(0.0)
    symbol: reactive[str] = reactive("")
    positive_count: reactive[int] = reactive(0)
    negative_count: reactive[int] = reactive(0)
    neutral_count: reactive[int] = reactive(0)
    dominant: reactive[str] = reactive("NEUTRAL")

    def render(self) -> str:
        """Return the summary as markup, or ``""`` when no symbol is set.

        Example output: ``AAPL — POSITIVE (score: +0.412, +7 / −2 / =11)``.
        """
        from rich.markup import escape

        if not self.symbol:
            return ""
        style = self._DOMINANT_STYLES.get(self.dominant, "white")
        # ``symbol`` comes straight from user input: escape it (and the label)
        # so characters such as "[" are shown literally instead of being
        # parsed as markup tags, which could otherwise break rendering.
        symbol = escape(self.symbol)
        dominant = escape(self.dominant)
        return (
            f"[bold]{symbol}[/bold] — "
            f"[{style}]{dominant}[/{style}] "
            f"(score: [bold]{self.score:+.3f}[/bold], "
            f"+{self.positive_count} / −{self.negative_count} / ={self.neutral_count})"
        )
class SentimentScreen(Screen):
    """Screen ID 5 — on-demand FinBERT sentiment analysis.

    The user submits a ticker, or an autocomplete entry of the form
    ``"SYMBOL — Company Name"``. Headlines for it are fetched and scored in a
    background thread worker. The result is shown as a summary line plus one
    table row per headline. ``r`` re-runs the analysis for the current symbol.
    """

    BINDINGS = [
        Binding("r", "refresh_symbol", "Refresh", show=False),
    ]

    # Colour used for each per-headline label; any other value renders white.
    _LABEL_STYLES = {"positive": "green", "negative": "red", "neutral": "yellow"}

    _current_symbol: str = ""  # last submitted symbol; re-run by the "r" binding
    _analysis_task: str = ""  # symbol of the most recent analysis request

    def compose(self) -> ComposeResult:
        """Build the search input, status line, summary line and results table."""
        yield Header(show_clock=True)
        with Vertical():
            app = self.app
            # Use the asset autocomplete only when the app's search index is ready.
            if hasattr(app, "asset_search") and app.asset_search.is_ready:
                from trading_cli.widgets.asset_autocomplete import create_asset_autocomplete

                input_widget, autocomplete_widget = create_asset_autocomplete(
                    app.asset_search,
                    placeholder="Search by symbol or company name… (Tab to complete)",
                    id="sent-input",
                )
                yield input_widget
                yield autocomplete_widget
            else:
                yield Input(placeholder="Search by symbol or company name…", id="sent-input")
            yield Label("", id="sent-loading-status")
            yield SentimentScoreDisplay(id="sent-summary")
            yield DataTable(id="sent-table", cursor_type="row")
        yield OrderedFooter()

    def on_mount(self) -> None:
        """Create the table columns, focus the search box and blank the status line."""
        table = self.query_one("#sent-table", DataTable)
        for title, key in (("Headline", "headline"), ("Label", "label"), ("Score", "score")):
            table.add_column(title, key=key)
        self.query_one("#sent-input", Input).focus()
        self._clear_loading_status()

    # ------------------------------------------------------------------
    # Thread / status helpers
    # ------------------------------------------------------------------

    def _call_on_ui(self, callback, *args, **kwargs):
        """Run ``callback(*args, **kwargs)`` on the app's UI thread and return its result.

        When already on the UI thread the callback is invoked directly, because
        ``call_from_thread`` is only meant to be used from other threads.
        Otherwise the call is marshalled with ``App.call_from_thread``.
        """
        if threading.get_ident() == self.app._thread_id:
            return callback(*args, **kwargs)
        return self.app.call_from_thread(callback, *args, **kwargs)

    def _write_status(self, markup: str) -> None:
        """Replace the status label's content with *markup* (best-effort, any thread)."""

        def update() -> None:
            try:
                self.query_one("#sent-loading-status", Label).update(markup)
            except Exception:
                # Deliberately best-effort, e.g. the label may no longer be
                # queryable if the screen went away while a worker still runs.
                pass

        self._call_on_ui(update)

    def _set_loading_status(self, text: str) -> None:
        """Show *text* (plain text, not markup) dimmed in the status label."""
        from rich.markup import escape

        # Escape so e.g. "[Errno 2] …" from an error message is displayed
        # literally instead of being parsed as a markup tag.
        self._write_status(f"[dim]{escape(text)}[/dim]")

    def _clear_loading_status(self) -> None:
        """Blank the status label."""
        self._write_status("")

    # ------------------------------------------------------------------
    # Event handlers
    # ------------------------------------------------------------------

    def on_input_submitted(self, event: Input.Submitted) -> None:
        """Start an analysis for the submitted symbol.

        An autocomplete value ``"SYMBOL — Company Name"`` is reduced to
        ``SYMBOL``. The result is upper-cased; blank input is ignored.
        """
        value = event.value.strip()
        if not value:
            return
        symbol = value.partition(" — ")[0].strip().upper()
        if symbol:
            self._current_symbol = symbol
            self._run_analysis(symbol)

    def action_refresh_symbol(self) -> None:
        """Re-run the analysis for the last submitted symbol, if there is one."""
        if self._current_symbol:
            self._run_analysis(self._current_symbol)

    # ------------------------------------------------------------------
    # Analysis (background thread)
    # ------------------------------------------------------------------

    def _run_analysis(self, symbol: str) -> None:
        """Reset the results view and start a background analysis for *symbol*.

        Must be called on the UI thread.
        """
        # Record the latest request; older workers compare against this.
        self._analysis_task = symbol
        self.query_one("#sent-table", DataTable).clear()
        summary = self.query_one("#sent-summary", SentimentScoreDisplay)
        summary.symbol = ""  # an empty symbol renders the summary blank
        summary.score = 0.0
        # Starts a thread worker and returns immediately.
        self._do_analysis(symbol)

    @work(thread=True, exclusive=True, group="sentiment-analysis")
    def _do_analysis(self, symbol: str) -> None:
        """Fetch, score and display headlines for *symbol* in a thread worker.

        Running as a thread worker keeps the network fetch and FinBERT
        inference off the UI thread. ``exclusive=True`` cancels the previous
        analysis whenever a new one starts, including a refresh of the same
        symbol. The worker checks for that between slow steps and discards
        superseded results. Failures are reported through a notification and
        the status line instead of escaping the worker.
        """
        from rich.markup import escape
        from textual.worker import get_current_worker

        worker = get_current_worker()
        analyzer = getattr(self.app, "finbert", None)
        db_conn = getattr(self.app, "db_conn", None)

        def is_cancelled() -> bool:
            # Superseded by a newer worker in this group, or by a request for
            # a different symbol.
            return worker.is_cancelled or self._analysis_task != symbol

        def notify(message: str, severity: str) -> None:
            # ``message`` is markup; callers escape any untrusted parts.
            self._call_on_ui(self.app.notify, message, severity=severity)

        def report_progress(message: str) -> None:
            # Keep a superseded model load from overwriting a newer status line.
            if not is_cancelled():
                self._set_loading_status(message)

        try:
            # (Re)load FinBERT if it is not loaded yet.
            if analyzer and not analyzer.is_loaded:
                self._set_loading_status("Loading FinBERT model…")
                if not analyzer.reload(progress_callback=report_progress):
                    error_msg = analyzer.load_error or "Unknown error"
                    notify(f"FinBERT failed to load: {escape(str(error_msg))}", "error")
                    self._set_loading_status(f"Failed: {error_msg}")
                    return
            if is_cancelled():
                return

            self._set_loading_status(f"Fetching headlines for {symbol}…")
            from trading_cli.data.news import fetch_headlines

            headlines = fetch_headlines(symbol, max_articles=20)
            if is_cancelled():
                return
            if not headlines:
                notify(f"No headlines found for {escape(symbol)}", "warning")
                self._clear_loading_status()
                return

            self._set_loading_status("Running sentiment analysis…")
            results = self._score_headlines(analyzer, db_conn, headlines)
        except Exception as exc:
            # Worker boundary: surface the failure instead of letting it
            # propagate out of the worker.
            if not is_cancelled():
                notify(
                    f"Sentiment analysis failed for {escape(symbol)}: {escape(str(exc))}",
                    "error",
                )
                self._set_loading_status(f"Failed: {exc}")
            return

        def show_results() -> None:
            # Runs on the UI thread. Re-check here: a newer request may have
            # been handled (clearing the table) while this call was queued.
            if is_cancelled():
                return
            self._clear_loading_status()
            self._display_results(symbol, headlines, results)

        self._call_on_ui(show_results)

    @staticmethod
    def _score_headlines(analyzer, db_conn, headlines: list[str]) -> list[dict]:
        """Score *headlines* with FinBERT, or return neutral placeholders.

        Uses ``analyze_with_cache`` when a DB connection is available and
        ``analyze_batch`` otherwise. If no loaded analyzer exists, every
        headline gets ``{"label": "neutral", "score": 0.5}``.
        """
        if analyzer and analyzer.is_loaded:
            if db_conn:
                return analyzer.analyze_with_cache(headlines, db_conn)
            return analyzer.analyze_batch(headlines)
        # One dict per headline: ``[d] * n`` would alias a single dict n times.
        return [{"label": "neutral", "score": 0.5} for _ in headlines]

    # ------------------------------------------------------------------
    # Display
    # ------------------------------------------------------------------

    def _display_results(self, symbol: str, headlines: list[str], results: list[dict]) -> None:
        """Show the aggregate summary and one table row per scored headline.

        Must run on the UI thread. ``headlines`` and ``results`` are paired
        positionally; items beyond the shorter list are ignored (``zip``).
        """
        summary = get_sentiment_summary(results)

        display = self.query_one("#sent-summary", SentimentScoreDisplay)
        display.symbol = symbol
        display.score = summary["score"]
        display.positive_count = summary["positive_count"]
        display.negative_count = summary["negative_count"]
        display.neutral_count = summary["neutral_count"]
        display.dominant = summary["dominant"].upper()

        table = self.query_one("#sent-table", DataTable)
        table.clear()
        for headline, result in zip(headlines, results):
            label = result.get("label", "neutral")
            score_val = result.get("score", 0.5)
            style = self._LABEL_STYLES.get(label, "white")
            table.add_row(
                # Wrap in Text so brackets in a headline are shown literally
                # rather than possibly being interpreted as markup.
                Text(headline[:80]),
                Text(label.upper(), style=f"bold {style}"),
                Text(f"{score_val:.3f}", style=style),
            )