Spaces:
Runtime error
Runtime error
"""Asset search with embedding-based semantic autocomplete.

Provides symbol / company-name matching over a cached asset list, with an
optional sentence-transformers model for semantic (fuzzy) matches.
"""

from __future__ import annotations

import json
import logging
import os
import threading
from pathlib import Path
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Only needed for type annotations; avoids importing the adapter at runtime.
    from trading_cli.execution.adapters.alpaca import AlpacaAdapter

logger = logging.getLogger(__name__)
class AssetSearchEngine:
    """Searchable asset index with optional semantic embeddings.

    Supports:
    - Symbol search (e.g., "AAPL")
    - Company name search (e.g., "Apple")
    - Fuzzy/partial matching (e.g., "appl" → Apple)
    - Semantic search via embeddings (optional, requires sentence-transformers)

    Assets are dicts with at least a ``"symbol"`` key. A missing or ``None``
    ``"name"`` is treated as an empty string.
    """

    def __init__(self, cache_dir: Path | None = None):
        """Create an empty engine.

        Args:
            cache_dir: Directory holding the JSON asset cache. Defaults to
                ``~/.cache/trading_cli``.
        """
        self._assets: list[dict[str, str]] = []
        # Upper-cased symbol -> asset dict, for O(1) exact-symbol lookups.
        self._symbol_index: dict[str, dict[str, str]] = {}
        # Guards replacement of _assets / _symbol_index / _embeddings together.
        self._lock = threading.Lock()
        self._cache_dir = cache_dir or Path.home() / ".cache" / "trading_cli"
        self._cache_file = self._cache_dir / "assets.json"
        # Row i is the normalized embedding of self._assets[i] (or None).
        self._embeddings = None
        self._embedding_model = None
        self._initialized = False

    def load_assets(self, adapter: AlpacaAdapter) -> int:
        """Load assets, preferring the on-disk cache over the adapter.

        On a cache miss the assets are fetched with
        ``adapter.get_all_assets()`` and written back to the cache. Replacing
        the asset list discards any precomputed embeddings, because they are
        matched to assets by position. Call ``load_embedding_model()`` again
        to re-enable semantic search.

        Returns:
            Number of assets loaded (0 on failure).
        """
        # Try cache first
        if self._load_from_cache():
            logger.info("Loaded %d assets from cache", len(self._assets))
            self._initialized = True
            return len(self._assets)

        # Fetch from adapter
        try:
            assets = adapter.get_all_assets()
            if not assets:
                logger.warning("Adapter returned no assets")
                return 0
            self._set_assets(assets)
        except Exception as exc:
            logger.warning("Failed to load assets: %s", exc)
            return 0

        self._save_to_cache()
        logger.info("Loaded %d assets from adapter", len(assets))
        self._initialized = True
        return len(assets)

    def _set_assets(self, assets: list[dict[str, str]]) -> None:
        """Install a new asset list and its symbol index as one unit."""
        # Build the index before touching state so a malformed asset
        # (KeyError) leaves the previous list and index intact.
        index = {asset["symbol"].upper(): asset for asset in assets}
        with self._lock:
            self._assets = assets
            self._symbol_index = index
            # Embedding rows line up with assets by position, so vectors
            # computed for the previous list would map to the wrong assets.
            self._embeddings = None

    def _load_from_cache(self) -> bool:
        """Load cached assets. Returns True if a non-empty list was loaded."""
        if not self._cache_file.exists():
            return False
        try:
            data = json.loads(self._cache_file.read_text(encoding="utf-8"))
            assets = data["assets"]
            if not isinstance(assets, list) or not assets:
                # Treating an empty cache as a hit would pin the engine to
                # zero assets forever; fall back to the adapter instead.
                logger.info("Asset cache is empty or malformed; ignoring it")
                return False
            self._set_assets(assets)
        except Exception as exc:
            logger.warning("Cache load failed: %s", exc)
            return False
        return True

    def _save_to_cache(self) -> None:
        """Save assets to cache (best-effort; failures are only logged)."""
        tmp_file = self._cache_file.with_name(self._cache_file.name + ".tmp")
        try:
            self._cache_dir.mkdir(parents=True, exist_ok=True)
            tmp_file.write_text(
                json.dumps({"assets": self._assets}, indent=2),
                encoding="utf-8",
            )
            # Atomic swap: a crash mid-write never leaves a truncated cache.
            os.replace(tmp_file, self._cache_file)
        except Exception as exc:
            logger.warning("Cache save failed: %s", exc)

    def search(
        self,
        query: str,
        max_results: int = 10,
        use_semantic: bool = True,
    ) -> list[dict[str, str | float]]:
        """Search assets by symbol or company name.

        Results are gathered in tiers, and each symbol appears at most once:
        1. exact symbol match (score 1.0),
        2. symbol prefix match (0.8, or 0.9 for a case-insensitive exact
           symbol not already taken from the index),
        3. company-name substring match for queries of 2+ characters (0.7),
        4. semantic matches (cosine similarity), if a model is loaded and
           ``use_semantic`` is true.

        Args:
            query: Search query (symbol fragment or company name).
            max_results: Maximum number of results to return.
            use_semantic: Whether to use semantic embeddings if available.

        Returns:
            At most ``max_results`` dicts with 'symbol', 'name' and 'score'.
            Empty for a blank query or a non-positive ``max_results``.
        """
        stripped = query.strip()
        if not stripped or max_results <= 0:
            return []
        query_upper = stripped.upper()
        query_lower = stripped.lower()

        results: list[dict[str, str | float]] = []
        seen: set[str] = set()

        def add(symbol: str, name: str, score: float) -> bool:
            """Record a hit unless its symbol is already listed; True once full."""
            if symbol not in seen:
                seen.add(symbol)
                results.append({"symbol": symbol, "name": name, "score": score})
            return len(results) >= max_results

        # Take a consistent snapshot. _set_assets() replaces these objects
        # rather than mutating them, so they stay coherent after release.
        with self._lock:
            assets = self._assets
            exact = self._symbol_index.get(query_upper)

        # Exact symbol match (highest priority)
        if exact is not None and add(exact["symbol"], exact.get("name") or "", 1.0):
            return results

        # Symbol starts with query
        for asset in assets:
            symbol = asset["symbol"]
            symbol_upper = symbol.upper()
            if symbol_upper.startswith(query_upper):
                score = 0.9 if symbol_upper == query_upper else 0.8
                if add(symbol, asset.get("name") or "", score):
                    return results

        # Name contains query (case-insensitive); 2+ chars to limit noise
        if len(query_lower) >= 2:
            for asset in assets:
                name = asset.get("name") or ""
                if query_lower in name.lower() and add(asset["symbol"], name, 0.7):
                    return results

        # Semantic search fills any remaining slots. Ask for a full page so
        # hits already listed above do not leave the result short.
        if use_semantic:
            for hit in self._search_semantic(query, max_results):
                if add(hit["symbol"], hit["name"], hit["score"]):
                    break

        return results

    def _search_semantic(
        self,
        query: str,
        max_results: int,
        min_similarity: float = 0.3,
    ) -> list[dict[str, str | float]]:
        """Rank assets by embedding similarity to ``query``.

        Args:
            query: Free-text query to embed.
            max_results: Maximum number of hits to return.
            min_similarity: Hits scoring below this threshold are dropped.

        Returns:
            Hits sorted by descending similarity. The list is empty when no
            model or embeddings are loaded, when the embeddings do not line
            up with the current assets, or when encoding fails.
        """
        with self._lock:
            model = self._embedding_model
            embeddings = self._embeddings
            assets = self._assets
        if model is None or embeddings is None or max_results <= 0:
            return []

        try:
            import numpy as np

            matrix = np.asarray(embeddings)  # no copy if already an ndarray
            # Rows are matched to assets by position; refuse stale vectors.
            if matrix.ndim != 2 or matrix.shape[0] != len(assets) or not assets:
                return []

            query_vec = model.encode([query], normalize_embeddings=True)[0]
            # Both sides are encoded with normalize_embeddings=True, so the
            # dot product is the cosine similarity.
            similarities = matrix @ query_vec

            results: list[dict[str, str | float]] = []
            for idx in np.argsort(similarities)[::-1][:max_results]:
                score = float(similarities[idx])
                if score < min_similarity:
                    break  # sorted descending: every later hit scores lower
                asset = assets[idx]
                results.append({
                    "symbol": asset["symbol"],
                    "name": asset.get("name") or "",
                    "score": score,
                })
            return results
        except Exception as exc:
            logger.warning("Semantic search failed: %s", exc)
            return []

    def load_embedding_model(self, model_name: str = "all-MiniLM-L6-v2") -> None:
        """Load a sentence transformer model for semantic search.

        This is optional and will only be used if successfully loaded.
        Falls back to text-based matching if unavailable. The model and
        embeddings are stored together only when encoding succeeds, so
        semantic search never sees a model without matching vectors.

        Args:
            model_name: Name of the sentence-transformers model to use.
                Default is 'all-MiniLM-L6-v2' (80MB, fast, good quality).
        """
        try:
            import numpy as np
            from sentence_transformers import SentenceTransformer
        except ImportError:
            logger.info(
                "sentence-transformers not installed. "
                "Install with: uv add sentence-transformers (optional)"
            )
            return

        try:
            logger.info("Loading embedding model '%s'...", model_name)
            model = SentenceTransformer(model_name)
            with self._lock:
                assets = self._assets
            # Precompute embeddings for all assets
            texts = [
                f"{asset['symbol']} {asset.get('name') or ''}"
                for asset in assets
            ]
            embeddings = np.asarray(
                model.encode(
                    texts,
                    normalize_embeddings=True,
                    show_progress_bar=False,
                )
            )
        except Exception as exc:
            logger.warning("Failed to load embedding model: %s", exc)
            return

        with self._lock:
            if self._assets is not assets:
                # Assets were reloaded while encoding; these rows are stale.
                logger.warning(
                    "Assets changed while embedding; "
                    "call load_embedding_model() again"
                )
                return
            self._embedding_model = model
            self._embeddings = embeddings
        logger.info(
            "Loaded embedding model: %d assets embedded",
            len(embeddings),
        )

    def is_ready(self) -> bool:
        """Whether the search engine has assets loaded."""
        return self._initialized

    def has_semantic_search(self) -> bool:
        """Whether semantic search is available."""
        return self._embedding_model is not None and self._embeddings is not None