Spaces:
Sleeping
Sleeping
| """ | |
| Model inference engine. | |
| Supports two execution modes: | |
| 1. **HF Spaces mode** -- loads the model onto a ZeroGPU-allocated device using | |
| the ``@spaces.GPU`` decorator. The decorator is applied lazily so the | |
| module can be imported even when the ``spaces`` package is absent. | |
| 2. **Local / demo mode** -- falls back to a smaller model or returns mock | |
| completions when no GPU is available. Useful for development and testing. | |
| The engine applies the chat template expected by the Qwen model family and | |
| injects tool definitions into the conversation so the model can emit | |
| structured tool-call blocks. | |
| """ | |
| from __future__ import annotations | |
| import logging | |
| import os | |
| from typing import Any | |
| # NOTE: torch is imported lazily inside live-mode methods so that demo mode | |
| # (the default on HF Spaces free tier) does not require torch to be installed. | |
| from model.config import ( | |
| DEVICE_MAP, | |
| FALLBACK_MODEL_ID, | |
| MAX_NEW_TOKENS, | |
| MODEL_ID, | |
| REPETITION_PENALTY, | |
| TEMPERATURE, | |
| TOP_P, | |
| TORCH_DTYPE, | |
| ) | |
logger: logging.Logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Conditional import of HF Spaces helpers
# ---------------------------------------------------------------------------
# The ``spaces`` package may be absent outside Hugging Face Spaces. Try to
# import it and remember the outcome in ``_HAS_SPACES`` instead of letting an
# ImportError abort the import of this module.
try:
    import spaces  # type: ignore[import-untyped]
except ImportError:
    _HAS_SPACES = False
else:
    _HAS_SPACES = True
class ModelEngine:
    """Thin wrapper around a causal-LM for chat completion with tool support.

    The model and tokenizer are loaded lazily on the first :meth:`generate`
    call. In demo mode nothing is loaded and canned responses are returned
    instead, so the rest of the pipeline can run without transformers/torch.
    """

    def __init__(self, model_id: str | None = None, demo_mode: bool = False) -> None:
        """Create an engine without loading anything yet.

        Parameters
        ----------
        model_id:
            Model identifier to load; falsy values fall back to ``MODEL_ID``.
        demo_mode:
            When true, never load a model and return canned responses.
        """
        self.demo_mode = demo_mode
        self.model_id = model_id or MODEL_ID
        # Populated by _ensure_loaded() in live mode; remain None in demo mode.
        self._model: Any | None = None
        self._tokenizer: Any | None = None
        self._loaded = False

    # --------------------------------------------------------------------- #
    # Lazy loading
    # --------------------------------------------------------------------- #
    def _ensure_loaded(self) -> None:
        """Load the model and tokenizer on first use.

        If loading ``self.model_id`` fails, ``FALLBACK_MODEL_ID`` is loaded
        instead and ``self.model_id`` is updated to it. Engine state is only
        modified once a tokenizer/model pair has loaded completely.

        Raises
        ------
        Exception
            Whatever ``transformers`` raises when the fallback model also
            fails to load, or when the configured model already *is* the
            fallback (there is nothing left to fall back to).
        """
        if self._loaded:
            return
        if self.demo_mode:
            logger.info("Running in demo mode -- no model will be loaded.")
            self._loaded = True
            return

        # Imported outside the try below so a missing ``transformers`` install
        # surfaces directly instead of being reported as a model-load failure.
        from transformers import AutoModelForCausalLM, AutoTokenizer

        def load(model_id: str) -> tuple[Any, Any]:
            """Load and return ``(tokenizer, model)`` for *model_id*."""
            tokenizer = AutoTokenizer.from_pretrained(
                model_id, trust_remote_code=True
            )
            model = AutoModelForCausalLM.from_pretrained(
                model_id,
                device_map=DEVICE_MAP,
                torch_dtype=TORCH_DTYPE,
                trust_remote_code=True,
            )
            return tokenizer, model

        try:
            logger.info("Loading model %s ...", self.model_id)
            tokenizer, model = load(self.model_id)
        except Exception:
            if self.model_id == FALLBACK_MODEL_ID:
                # Retrying the same model would only fail again; surface it.
                raise
            logger.warning(
                "Failed to load %s, falling back to %s",
                self.model_id,
                FALLBACK_MODEL_ID,
                exc_info=True,
            )
            tokenizer, model = load(FALLBACK_MODEL_ID)
            self.model_id = FALLBACK_MODEL_ID
            logger.info("Fallback model %s loaded.", self.model_id)
        else:
            logger.info("Model %s loaded successfully.", self.model_id)

        self._tokenizer = tokenizer
        self._model = model
        self._loaded = True

    # --------------------------------------------------------------------- #
    # Generation
    # --------------------------------------------------------------------- #
    def generate(
        self,
        messages: list[dict[str, str]],
        tools: list[dict] | None = None,
        max_new_tokens: int = MAX_NEW_TOKENS,
        temperature: float = TEMPERATURE,
    ) -> str:
        """Generate a single completion given a chat-style message list.

        Parameters
        ----------
        messages:
            List of ``{"role": ..., "content": ...}`` dicts.
        tools:
            Optional list of tool JSON-schema dicts to inject into the chat
            template so the model can emit ``<tool_call>`` blocks. Ignored in
            demo mode.
        max_new_tokens:
            Maximum tokens to generate. Ignored in demo mode.
        temperature:
            Sampling temperature; values ``<= 0`` select greedy decoding.
            Ignored in demo mode.

        Returns
        -------
        str
            The assistant's response text (decoded, whitespace-stripped).
        """
        self._ensure_loaded()
        if self.demo_mode:
            return self._demo_generate(messages)
        return self._model_generate(messages, tools, max_new_tokens, temperature)

    # --------------------------------------------------------------------- #
    # Internal generation paths
    # --------------------------------------------------------------------- #
    def _model_generate(
        self,
        messages: list[dict[str, str]],
        tools: list[dict] | None,
        max_new_tokens: int,
        temperature: float,
    ) -> str:
        """Run the loaded model and return only the newly generated text."""
        # Lazy import torch only in live generation path
        import torch

        tokenizer = self._tokenizer
        model = self._model

        # Apply the chat template. Qwen models accept a ``tools`` kwarg.
        try:
            prompt = tokenizer.apply_chat_template(
                messages,
                tools=tools,
                tokenize=False,
                add_generation_prompt=True,
            )
        except TypeError:
            # Older template without tool support -- fall back to plain chat.
            prompt = tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,
            )

        # The rendered template already contains the special tokens it needs;
        # letting the tokenizer add them again can duplicate them.
        inputs = tokenizer(
            prompt, return_tensors="pt", add_special_tokens=False
        ).to(model.device)

        do_sample = temperature > 0
        gen_kwargs: dict[str, Any] = {
            "max_new_tokens": max_new_tokens,
            "repetition_penalty": REPETITION_PENALTY,
            "do_sample": do_sample,
        }
        if do_sample:
            # temperature / top_p only affect sampling; passing them to greedy
            # decoding has no effect beyond transformers warnings.
            gen_kwargs["temperature"] = temperature
            gen_kwargs["top_p"] = TOP_P

        with torch.no_grad():
            output_ids = model.generate(**inputs, **gen_kwargs)

        # Decode only the newly generated tokens (skip the echoed prompt).
        prompt_len = inputs["input_ids"].shape[1]
        generated = output_ids[0][prompt_len:]
        return tokenizer.decode(generated, skip_special_tokens=True).strip()

    @staticmethod
    def _demo_generate(messages: list[dict[str, str]]) -> str:
        """Return a canned response for demo / test mode.

        The response mimics the Thought / Action / Answer pattern so the parser
        and orchestrator can be exercised without a real model. The branch is
        chosen from keywords in the most recent ``user`` message. Keywords are
        matched only at the start of a word, so e.g. "strategy" does not hit
        the "rate" branch and "version" does not hit the "rsi" branch.
        """
        import re

        user_msg = ""
        for m in reversed(messages):
            if m.get("role") == "user":
                # ``content`` may be present but None; treat that as empty.
                user_msg = m.get("content") or ""
                break
        user_lower = user_msg.lower()

        def mentions(*keywords: str) -> bool:
            """Return True if any keyword starts a word in the user message."""
            return any(
                re.search(r"\b" + re.escape(k), user_lower) for k in keywords
            )

        if mentions("monte carlo", "simulation"):
            return (
                "Thought: The user wants a Monte Carlo simulation. "
                "I will call the run_monte_carlo tool.\n\n"
                '<tool_call>{"name": "run_monte_carlo", '
                '"arguments": {"ticker": "TSLA", "days_forward": 30, '
                '"num_simulations": 1000}}</tool_call>'
            )
        if mentions("correlat"):
            return (
                "Thought: The user wants a correlation analysis. "
                "I will use correlate_assets.\n\n"
                '<tool_call>{"name": "correlate_assets", '
                '"arguments": {"tickers": ["NVDA", "AMD"], '
                '"period": "6mo"}}</tool_call>'
            )
        if mentions("rsi", "macd", "overbought", "momentum", "technical"):
            return (
                "Thought: I need technical indicators for this ticker. "
                "Let me first fetch market data, then compute indicators.\n\n"
                '<tool_call>{"name": "fetch_market_data", '
                '"arguments": {"ticker": "AAPL", "period": "3mo", '
                '"interval": "1d"}}</tool_call>'
            )
        if mentions("fed", "rate", "inflation", "economic", "macro"):
            return (
                "Thought: The user is asking about macroeconomic conditions. "
                "I will fetch the federal funds rate.\n\n"
                '<tool_call>{"name": "fetch_economic_data", '
                '"arguments": {"indicator": "federal_funds_rate"}}</tool_call>'
            )
        return (
            "Thought: I have enough information to answer the user's question.\n\n"
            "Answer: Based on the available data, here is a summary of the "
            "market analysis. The current market conditions suggest a mixed "
            "outlook with moderate volatility."
        )