File size: 8,797 Bytes
852035a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3320861
 
852035a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3320861
 
 
852035a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
"""
Model inference engine.

Supports two execution modes:

1. **HF Spaces mode** -- loads the model onto a ZeroGPU-allocated device using
   the ``@spaces.GPU`` decorator.  The decorator is applied lazily so the
   module can be imported even when the ``spaces`` package is absent.

2. **Local / demo mode** -- if the primary model fails to load, the engine
   falls back to a smaller model (``FALLBACK_MODEL_ID``); with
   ``demo_mode=True`` it loads nothing and returns canned mock completions.
   Useful for development and testing.

The engine applies the chat template expected by the Qwen model family and
injects tool definitions into the conversation so the model can emit
structured tool-call blocks.
"""

from __future__ import annotations

import logging
import os
from typing import Any

# NOTE: torch is imported lazily inside live-mode methods so that demo mode
# (the default on HF Spaces free tier) does not require torch to be installed.

from model.config import (
    DEVICE_MAP,
    FALLBACK_MODEL_ID,
    MAX_NEW_TOKENS,
    MODEL_ID,
    REPETITION_PENALTY,
    TEMPERATURE,
    TOP_P,
    TORCH_DTYPE,
)

logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Optional Hugging Face Spaces integration
# ---------------------------------------------------------------------------
# ``spaces`` may not be installed (e.g. local development), so importing it is
# best-effort: record whether the import worked instead of failing the module.
try:
    import spaces  # type: ignore[import-untyped]
except ImportError:
    _HAS_SPACES = False
else:
    _HAS_SPACES = True


class ModelEngine:
    """Thin wrapper around a causal-LM for chat completion with tool support.

    The tokenizer and model are loaded lazily on the first :meth:`generate`
    call.  With ``demo_mode=True`` nothing is loaded and canned responses are
    returned instead, so callers can be exercised without a real model.
    """

    def __init__(self, model_id: str | None = None, demo_mode: bool = False) -> None:
        self.demo_mode = demo_mode
        # ``None`` (or an empty string) selects the configured default model.
        self.model_id = model_id or MODEL_ID
        self._model: Any | None = None
        self._tokenizer: Any | None = None
        # Becomes True only after a successful load (or immediately in demo mode).
        self._loaded = False

    # --------------------------------------------------------------------- #
    # Lazy loading
    # --------------------------------------------------------------------- #
    def _ensure_loaded(self) -> None:
        """Load the model and tokenizer on first use.

        Tries ``self.model_id`` first.  If that raises, the error is logged
        together with its traceback and ``FALLBACK_MODEL_ID`` is loaded
        instead; ``self.model_id`` is switched to the fallback only once that
        load has succeeded.  An exception from the fallback load propagates
        and leaves the engine unloaded with its ``model_id`` unchanged, so a
        later call starts over from the original model.
        """
        if self._loaded:
            return

        if self.demo_mode:
            logger.info("Running in demo mode -- no model will be loaded.")
            self._loaded = True
            return

        # Imported lazily so demo mode does not need transformers.  Importing
        # before the ``try`` makes a missing package fail immediately instead
        # of being treated as a model-loading failure.
        from transformers import AutoModelForCausalLM, AutoTokenizer

        def load(model_id: str) -> tuple[Any, Any]:
            """Return ``(tokenizer, model)`` for *model_id*."""
            tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
            model = AutoModelForCausalLM.from_pretrained(
                model_id,
                device_map=DEVICE_MAP,
                torch_dtype=TORCH_DTYPE,
                trust_remote_code=True,
            )
            return tokenizer, model

        try:
            logger.info("Loading model %s ...", self.model_id)
            # Assign both at once so a half-finished load never leaves a
            # tokenizer from one checkpoint paired with nothing (or another).
            self._tokenizer, self._model = load(self.model_id)
            logger.info("Model %s loaded successfully.", self.model_id)
        except Exception:
            # Include the traceback so the root cause is not lost when the
            # fallback load succeeds.
            logger.warning(
                "Failed to load %s, falling back to %s",
                self.model_id,
                FALLBACK_MODEL_ID,
                exc_info=True,
            )
            self._tokenizer, self._model = load(FALLBACK_MODEL_ID)
            self.model_id = FALLBACK_MODEL_ID
            logger.info("Fallback model %s loaded.", self.model_id)

        self._loaded = True

    # --------------------------------------------------------------------- #
    # Generation
    # --------------------------------------------------------------------- #
    def generate(
        self,
        messages: list[dict[str, str]],
        tools: list[dict] | None = None,
        max_new_tokens: int = MAX_NEW_TOKENS,
        temperature: float = TEMPERATURE,
    ) -> str:
        """Generate a single completion given a chat-style message list.

        Parameters
        ----------
        messages:
            List of ``{"role": ..., "content": ...}`` dicts.
        tools:
            Optional list of tool JSON-schema dicts to inject into the chat
            template so the model can emit ``<tool_call>`` blocks.
        max_new_tokens:
            Maximum tokens to generate.
        temperature:
            Sampling temperature.  Values ``<= 0`` select greedy decoding.

        Returns
        -------
        str
            The assistant's response text (decoded).

        Notes
        -----
        In demo mode ``tools``, ``max_new_tokens`` and ``temperature`` are
        ignored and a canned response is returned.
        """
        self._ensure_loaded()

        if self.demo_mode:
            return self._demo_generate(messages)

        return self._model_generate(messages, tools, max_new_tokens, temperature)

    # --------------------------------------------------------------------- #
    # Internal generation paths
    # --------------------------------------------------------------------- #
    def _model_generate(
        self,
        messages: list[dict[str, str]],
        tools: list[dict] | None,
        max_new_tokens: int,
        temperature: float,
    ) -> str:
        """Run the loaded model on *messages* and return only the new text."""
        # Lazy import torch only in live generation path
        import torch

        tokenizer = self._tokenizer
        model = self._model

        # Apply the chat template.  Qwen models accept a ``tools`` kwarg.
        try:
            prompt = tokenizer.apply_chat_template(
                messages,
                tools=tools,
                tokenize=False,
                add_generation_prompt=True,
            )
        except TypeError:
            # Older template without tool support -- fall back to plain chat.
            if tools:
                logger.warning(
                    "Chat template of %s rejected the tools argument; "
                    "generating without tool definitions.",
                    self.model_id,
                )
            prompt = tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,
            )

        # The rendered template already contains the model's special tokens;
        # letting the tokenizer add them again (e.g. a second BOS) would
        # corrupt the prompt.
        inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(
            model.device
        )

        with torch.no_grad():
            output_ids = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_p=TOP_P,
                repetition_penalty=REPETITION_PENALTY,
                do_sample=temperature > 0,
            )

        # Decode only the newly generated tokens (skip the echoed prompt).
        prompt_len = inputs["input_ids"].shape[1]
        generated = output_ids[0][prompt_len:]
        return tokenizer.decode(generated, skip_special_tokens=True).strip()

    @staticmethod
    def _demo_generate(messages: list[dict[str, str]]) -> str:
        """Return a canned response for demo / test mode.

        The response mimics the Thought / Action / Answer pattern so the parser
        and orchestrator can be exercised without a real model.  The branch is
        picked by case-insensitive substring matches on the most recent
        ``"user"`` message; if none matches, a generic ``Answer:`` is returned.
        """
        user_msg = ""
        for m in reversed(messages):
            if m.get("role") == "user":
                # ``content`` may be missing or None; treat both as empty.
                user_msg = m.get("content") or ""
                break

        user_lower = user_msg.lower()

        # NOTE(review): these are plain substring checks, so e.g. "rate" also
        # matches "generate" and "rsi" matches "version" -- acceptable for a
        # demo, but not a real intent classifier.
        if "monte carlo" in user_lower or "simulation" in user_lower:
            return (
                "Thought: The user wants a Monte Carlo simulation. "
                "I will call the run_monte_carlo tool.\n\n"
                '<tool_call>{"name": "run_monte_carlo", '
                '"arguments": {"ticker": "TSLA", "days_forward": 30, '
                '"num_simulations": 1000}}</tool_call>'
            )
        if "correlat" in user_lower:
            return (
                "Thought: The user wants a correlation analysis. "
                "I will use correlate_assets.\n\n"
                '<tool_call>{"name": "correlate_assets", '
                '"arguments": {"tickers": ["NVDA", "AMD"], '
                '"period": "6mo"}}</tool_call>'
            )
        if any(k in user_lower for k in ("rsi", "macd", "overbought", "momentum", "technical")):
            return (
                "Thought: I need technical indicators for this ticker. "
                "Let me first fetch market data, then compute indicators.\n\n"
                '<tool_call>{"name": "fetch_market_data", '
                '"arguments": {"ticker": "AAPL", "period": "3mo", '
                '"interval": "1d"}}</tool_call>'
            )
        if any(k in user_lower for k in ("fed", "rate", "inflation", "economic", "macro")):
            return (
                "Thought: The user is asking about macroeconomic conditions. "
                "I will fetch the federal funds rate.\n\n"
                '<tool_call>{"name": "fetch_economic_data", '
                '"arguments": {"indicator": "federal_funds_rate"}}</tool_call>'
            )

        return (
            "Thought: I have enough information to answer the user's question.\n\n"
            "Answer: Based on the available data, here is a summary of the "
            "market analysis. The current market conditions suggest a mixed "
            "outlook with moderate volatility."
        )