LittleMonkeyLab committed on
Commit
f212919
·
verified ·
1 Parent(s): 85cbb3f

Upload chatbot.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. chatbot.py +428 -88
chatbot.py CHANGED
@@ -1,53 +1,302 @@
1
  """
2
- Chatbot engine with RAG pipeline for the AI Trading Experiment.
3
- Uses HuggingFace Inference API.
 
 
 
 
 
4
  """
5
 
6
  import os
7
  import sys
8
  import random
9
- from typing import Optional, List, Tuple
 
10
  from dataclasses import dataclass
 
11
 
12
- # pysqlite3 workaround for HF Spaces
13
  try:
14
  __import__('pysqlite3')
15
  sys.modules['sqlite3'] = sys.modules.pop('pysqlite3')
16
  except ImportError:
17
- pass
18
 
 
19
  from langchain_community.document_loaders import TextLoader
20
  from langchain_community.vectorstores import Chroma
21
  from langchain_community.embeddings import HuggingFaceEmbeddings
22
  from langchain_text_splitters import CharacterTextSplitter
23
- from huggingface_hub import InferenceClient
24
 
25
  from config import Scenario, ResearcherControlledParams, ParticipantVisibleParams
26
 
27
 
 
28
  KNOWLEDGE_BASE_DIR = "knowledge_base"
29
  VECTOR_DB_DIR = "db/vectorstore"
30
- MODEL = "HuggingFaceH4/zephyr-7b-beta"
31
 
32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  @dataclass
34
  class ChatResponse:
35
  """Response from the chatbot."""
36
  message: str
37
  is_proactive: bool
38
- confidence_level: str
39
  sources_used: List[str]
40
 
41
 
 
 
42
  class TradingChatbot:
43
- """AI Chatbot for the trading experiment."""
 
 
 
44
 
45
- def __init__(self):
46
- self.client = InferenceClient(token=os.getenv("HF_TOKEN"))
47
  self.vectorstore = None
48
  self.chat_history: List[Tuple[str, str]] = []
49
  self._initialize_knowledge_base()
50
- print(f"Using model: {MODEL}")
51
 
52
  def _initialize_knowledge_base(self):
53
  """Load and index the knowledge base documents."""
@@ -67,12 +316,14 @@ class TradingChatbot:
67
  print("Warning: No knowledge base documents found.")
68
  return
69
 
 
70
  splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
71
  split_docs = splitter.split_documents(docs)
72
 
73
  texts = [doc.page_content for doc in split_docs]
74
  metadatas = [{"source": doc.metadata.get("source", "unknown")} for doc in split_docs]
75
 
 
76
  try:
77
  embedding_function = HuggingFaceEmbeddings(
78
  model_name="sentence-transformers/all-MiniLM-L6-v2"
@@ -89,64 +340,62 @@ class TradingChatbot:
89
  except Exception as e:
90
  print(f"Error initializing vectorstore: {e}")
91
 
92
- def _call_llm(self, prompt: str) -> str:
93
- """Call the HuggingFace model."""
94
- try:
95
- messages = [
96
- {"role": "system", "content": "You are an AI trading advisor in the TradeVerse financial ecosystem."},
97
- {"role": "user", "content": prompt}
98
- ]
99
- response = self.client.chat_completion(
100
- model=MODEL,
101
- messages=messages,
102
- max_tokens=512,
103
- temperature=0.7
104
- )
105
- return response.choices[0].message.content.strip()
106
- except Exception as e:
107
- print(f"LLM error: {e}")
108
- return "I'm having trouble processing that request. Please try again."
109
-
110
- def _get_confidence_framing(self, level: int) -> dict:
111
  """Get language framing based on confidence parameter."""
112
  if level < 34:
113
- return {"prefix": "Based on the available information,", "verb": "might consider", "qualifier": "though there is uncertainty", "level": "low"}
 
 
 
 
 
114
  elif level < 67:
115
- return {"prefix": "Looking at the situation,", "verb": "suggests", "qualifier": "while noting some risk factors", "level": "medium"}
 
 
 
 
 
116
  else:
117
- return {"prefix": "Based on my analysis,", "verb": "strongly recommend", "qualifier": "with high confidence", "level": "high"}
 
 
 
 
 
118
 
119
  def _get_depth_instructions(self, level: int) -> str:
120
- """Get explanation depth instructions."""
121
  if level < 34:
122
- return "Provide a brief response (1-2 sentences)."
123
  elif level < 67:
124
- return "Provide a moderate explanation (3-4 sentences)."
125
  else:
126
- return "Provide a detailed analysis covering all relevant factors."
127
 
128
  def _get_risk_framing(self, level: int) -> str:
129
- """Get risk perspective."""
130
  if level < 34:
131
- return "Emphasize potential risks. Favor capital preservation."
132
  elif level < 67:
133
- return "Balance risks and opportunities."
134
  else:
135
- return "Emphasize potential opportunities. Tolerate higher risk."
136
 
137
  def _get_style_instructions(self, level: int) -> str:
138
- """Get communication style."""
139
  if level < 34:
140
- return "Use formal, professional language."
141
  elif level < 67:
142
- return "Use clear, accessible language."
143
  else:
144
- return "Use conversational, friendly language."
145
 
146
  def _retrieve_context(self, query: str, k: int = 4) -> str:
147
  """Retrieve relevant context from the knowledge base."""
148
  if not self.vectorstore:
149
  return ""
 
150
  try:
151
  docs = self.vectorstore.similarity_search(query, k=k)
152
  return "\n\n".join([doc.page_content for doc in docs])
@@ -160,37 +409,70 @@ class TradingChatbot:
160
  visible_params: ParticipantVisibleParams,
161
  hidden_params: ResearcherControlledParams
162
  ) -> Optional[ChatResponse]:
163
- """Generate proactive advice for a scenario."""
164
- if random.random() > hidden_params.proactivity_level / 100:
 
 
 
 
 
165
  return None
166
 
 
167
  confidence = self._get_confidence_framing(hidden_params.confidence_framing)
168
- context = self._retrieve_context(f"{scenario.company_name} {scenario.sector}")
 
 
 
 
 
 
 
169
 
170
- factors = scenario.red_flags[:2] if hidden_params.risk_bias < 50 else scenario.positive_signals[:2]
171
- if not factors:
172
- factors = scenario.key_factors[:2]
 
 
 
 
 
 
 
 
173
 
174
- prompt = f"""{self._get_style_instructions(visible_params.communication_style)}
175
- {self._get_depth_instructions(visible_params.explanation_depth)}
176
 
177
  Company: {scenario.company_name} ({scenario.company_symbol})
178
  Sector: {scenario.sector}
179
- Price: {scenario.current_price} credits
 
180
 
181
- Situation: {scenario.situation_description}
 
182
 
183
- Key factors: {', '.join(factors)}
184
 
185
- Context: {context}
 
186
 
187
- Generate a brief proactive observation about this situation. Don't say BUY/SELL/HOLD yet - just note what's interesting."""
 
 
 
 
 
 
 
 
 
 
188
 
189
  return ChatResponse(
190
- message=self._call_llm(prompt),
191
  is_proactive=True,
192
  confidence_level=confidence["level"],
193
- sources_used=["market_context"]
194
  )
195
 
196
  def generate_ai_recommendation(
@@ -199,35 +481,55 @@ Generate a brief proactive observation about this situation. Don't say BUY/SELL/
199
  visible_params: ParticipantVisibleParams,
200
  hidden_params: ResearcherControlledParams
201
  ) -> ChatResponse:
202
- """Generate the AI's recommendation for a scenario."""
 
 
 
203
  confidence = self._get_confidence_framing(hidden_params.confidence_framing)
204
- context = self._retrieve_context(f"{scenario.company_name} {scenario.sector}")
 
 
 
 
 
 
 
 
 
 
 
205
 
206
- prompt = f"""{self._get_style_instructions(visible_params.communication_style)}
207
- {self._get_depth_instructions(visible_params.explanation_depth)}
208
- {self._get_risk_framing(hidden_params.risk_bias)}
209
 
210
  Company: {scenario.company_name} ({scenario.company_symbol})
211
  Sector: {scenario.sector}
212
- Price: {scenario.current_price} credits
 
213
 
214
- Situation: {scenario.situation_description}
 
215
 
216
  Key factors: {', '.join(scenario.key_factors)}
217
- Warnings: {', '.join(scenario.red_flags) if scenario.red_flags else 'None'}
218
- Positives: {', '.join(scenario.positive_signals) if scenario.positive_signals else 'None'}
219
 
220
- Context: {context}
 
221
 
222
- {confidence['prefix']} I {confidence['verb']} to {scenario.ai_recommendation} {confidence['qualifier']}.
223
 
224
- Give your recommendation clearly stating {scenario.ai_recommendation}. Explain your reasoning."""
 
 
 
 
 
225
 
226
  return ChatResponse(
227
- message=self._call_llm(prompt),
228
  is_proactive=False,
229
  confidence_level=confidence["level"],
230
- sources_used=["market_context", "company_profile"]
231
  )
232
 
233
  def answer_query(
@@ -237,38 +539,76 @@ Give your recommendation clearly stating {scenario.ai_recommendation}. Explain y
237
  visible_params: ParticipantVisibleParams,
238
  hidden_params: ResearcherControlledParams
239
  ) -> ChatResponse:
240
- """Answer a participant's question."""
 
 
 
 
 
241
  confidence = self._get_confidence_framing(hidden_params.confidence_framing)
 
 
242
  context = self._retrieve_context(query)
243
 
244
- scenario_info = ""
 
245
  if scenario:
246
- scenario_info = f"Current scenario: {scenario.company_name} ({scenario.company_symbol}) - {scenario.situation_description}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
247
 
248
- prompt = f"""{self._get_style_instructions(visible_params.communication_style)}
249
- {self._get_depth_instructions(visible_params.explanation_depth)}
250
 
251
- {scenario_info}
252
 
253
- Knowledge: {context}
 
254
 
255
- Question: {query}
 
 
 
 
 
 
 
 
 
256
 
257
- Answer helpfully. Only use TradeVerse information (fictional universe)."""
258
 
259
- response = self._call_llm(prompt)
260
- self.chat_history.append((query, response))
261
 
262
  return ChatResponse(
263
- message=response,
264
  is_proactive=False,
265
  confidence_level=confidence["level"],
266
  sources_used=["knowledge_base"]
267
  )
268
 
269
  def clear_history(self):
270
- """Clear the chat history."""
271
  self.chat_history = []
272
 
273
 
 
274
  chatbot = TradingChatbot()
 
1
  """
2
+ Chatbot engine with RAG pipeline and proactive/reactive logic.
3
+ Adapted for the AI Trading Experiment with parameter-aware responses.
4
+
5
+ Supports multiple LLM providers:
6
+ - HuggingFace Inference API (free tier available)
7
+ - DeepSeek API
8
+ - Fallback rule-based responses
9
  """
10
 
11
import os
import random
import sys
import time
from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple

import requests
18
 
19
+ # Attempt pysqlite3 workaround for HF Spaces
20
  try:
21
  __import__('pysqlite3')
22
  sys.modules['sqlite3'] = sys.modules.pop('pysqlite3')
23
  except ImportError:
24
+ pass # Not in HF Spaces, use default sqlite3
25
 
26
+ # LangChain imports - using langchain_community for newer versions
27
  from langchain_community.document_loaders import TextLoader
28
  from langchain_community.vectorstores import Chroma
29
  from langchain_community.embeddings import HuggingFaceEmbeddings
30
  from langchain_text_splitters import CharacterTextSplitter
31
+ from langchain_core.language_models.llms import LLM
32
 
33
  from config import Scenario, ResearcherControlledParams, ParticipantVisibleParams
34
 
35
 
36
+ # Configuration
37
  KNOWLEDGE_BASE_DIR = "knowledge_base"
38
  VECTOR_DB_DIR = "db/vectorstore"
 
39
 
40
 
41
class LLMProvider(Enum):
    """Closed set of LLM backends the chatbot can run against."""
    HUGGINGFACE = "huggingface"
    DEEPSEEK = "deepseek"
    FALLBACK = "fallback"
46
+
47
+
48
# ==================== LLM Provider Selection ====================

def get_llm_provider() -> LLMProvider:
    """Pick the LLM backend from whichever API credentials are present.

    HuggingFace tokens take precedence over a DeepSeek key; with no key at
    all, the rule-based fallback is selected and a warning is printed.
    """
    hf_token_vars = ("HF_TOKEN", "HUGGINGFACE_TOKEN", "HF_API_KEY")
    if any(os.getenv(var) for var in hf_token_vars):
        return LLMProvider.HUGGINGFACE
    if os.getenv("DEEPSEEK_API_KEY"):
        return LLMProvider.DEEPSEEK
    print("Warning: No LLM API key found. Using fallback responses.")
    print("Set HF_TOKEN for HuggingFace or DEEPSEEK_API_KEY for DeepSeek.")
    return LLMProvider.FALLBACK
61
+
62
+
63
# ==================== HuggingFace LLM ====================

# Recommended free/cheap models (smallest to largest):
# - "microsoft/Phi-3-mini-4k-instruct"     # 3.8B params, very fast
# - "Qwen/Qwen2-1.5B-Instruct"             # 1.5B params, smallest
# - "HuggingFaceH4/zephyr-7b-beta"         # 7B params, good quality
# - "mistralai/Mistral-7B-Instruct-v0.2"   # 7B params, popular
# - "meta-llama/Llama-2-7b-chat-hf"        # 7B params, requires approval

DEFAULT_HF_MODEL = "HuggingFaceH4/zephyr-7b-beta"


class HuggingFaceLLM(LLM):
    """LLM wrapper for the HuggingFace Inference API (free tier available).

    Degrades gracefully: when no token is configured, when the API call
    fails, or when the API returns no usable text, the rule-based
    ``FallbackLLM`` supplies the response instead.
    """

    # Class-level field defaults; overridable via keyword arguments through
    # the LLM base class (NOTE(review): assignment in __init__ relies on the
    # base class accepting these as declared fields — confirm against the
    # langchain version in use).
    api_key: str = ""
    model_id: str = DEFAULT_HF_MODEL
    temperature: float = 0.7
    max_tokens: int = 512

    def __init__(self, model_id: Optional[str] = None, **kwargs):
        """Create the wrapper.

        Args:
            model_id: Hub model to query; defaults to ``DEFAULT_HF_MODEL``.
            **kwargs: forwarded to the ``LLM`` base class.
        """
        super().__init__(**kwargs)
        # Try multiple possible env var names for the token.
        self.api_key = (
            os.getenv("HF_TOKEN")
            or os.getenv("HUGGINGFACE_TOKEN")
            or os.getenv("HF_API_KEY")
            or ""
        )
        if model_id:
            self.model_id = model_id

        if self.api_key:
            print(f"Using HuggingFace model: {self.model_id}")
        else:
            print("Warning: No HuggingFace token found.")

    def _call(self, prompt: str, stop: Optional[List[str]] = None, **kwargs) -> str:
        """Send *prompt* to the Inference API and return the generated text.

        ``stop`` is accepted for LLM-interface compatibility but is not
        forwarded to the API. Any failure path (no token, HTTP error,
        unexpected payload shape, empty generation) returns the rule-based
        fallback text instead of raising.
        """
        if not self.api_key:
            return self._fallback_response(prompt)

        # HuggingFace Inference API endpoint
        api_url = f"https://api-inference.huggingface.co/models/{self.model_id}"

        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json"
        }

        # Format prompt for instruction-tuned models (zephyr-style chat tags).
        formatted_prompt = f"""<|system|>
You are an AI trading advisor in the TradeVerse financial ecosystem. Provide helpful, concise advice.</s>
<|user|>
{prompt}</s>
<|assistant|>
"""

        payload = {
            "inputs": formatted_prompt,
            "parameters": {
                "max_new_tokens": self.max_tokens,
                "temperature": self.temperature,
                "do_sample": True,
                # Only the completion, not the echoed prompt.
                "return_full_text": False
            }
        }

        try:
            response = requests.post(api_url, headers=headers, json=payload, timeout=60)

            # 503 means the free tier is still loading the model: wait the
            # advertised time (capped at 30s) and retry once.
            if response.status_code == 503:
                data = response.json()
                wait_time = data.get("estimated_time", 20)
                print(f"Model loading, waiting {wait_time}s...")
                time.sleep(min(wait_time, 30))
                response = requests.post(api_url, headers=headers, json=payload, timeout=60)

            response.raise_for_status()
            data = response.json()

            # The API returns either a list of generations or a single dict.
            text = ""
            if isinstance(data, list) and len(data) > 0:
                text = data[0].get("generated_text", "").strip()
            elif isinstance(data, dict):
                text = data.get("generated_text", "").strip()
            if text:
                return text

            # Unexpected shape or empty generation: never hand back "".
            return self._fallback_response(prompt)

        except Exception as e:
            print(f"HuggingFace API error: {e}")
            return self._fallback_response(prompt)

    def _fallback_response(self, prompt: str) -> str:
        """Generate a basic response when API is unavailable."""
        return FallbackLLM()._call(prompt)

    @property
    def _llm_type(self) -> str:
        return "huggingface_inference"
163
+
164
+
165
# ==================== DeepSeek LLM ====================

DEEPSEEK_API_URL = "https://api.deepseek.com/v1/chat/completions"


class DeepSeekLLM(LLM):
    """LLM wrapper for DeepSeek API."""

    # Field defaults; overridable via keyword arguments.
    api_key: str = ""
    temperature: float = 0.7
    max_tokens: int = 512

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.api_key = os.getenv("DEEPSEEK_API_KEY", "")
        if self.api_key:
            print("Using DeepSeek API")

    def _call(self, prompt: str, stop: Optional[List[str]] = None, **kwargs) -> str:
        """Run one chat completion; on any failure, return rule-based text."""
        if not self.api_key:
            return FallbackLLM()._call(prompt)

        request_headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json"
        }
        chat_messages = [
            {"role": "system", "content": "You are an AI trading advisor in the TradeVerse financial ecosystem."},
            {"role": "user", "content": prompt}
        ]
        request_body = {
            "model": "deepseek-chat",
            "messages": chat_messages,
            "temperature": self.temperature,
            "max_tokens": self.max_tokens
        }

        try:
            reply = requests.post(DEEPSEEK_API_URL, headers=request_headers, json=request_body, timeout=30)
            reply.raise_for_status()
            return reply.json()["choices"][0]["message"]["content"].strip()
        except Exception as e:
            print(f"DeepSeek API error: {e}")
            return FallbackLLM()._call(prompt)

    @property
    def _llm_type(self) -> str:
        return "deepseek_api"
212
+
213
+
214
# ==================== Fallback LLM (Rule-based) ====================

class FallbackLLM(LLM):
    """
    Rule-based fallback when no API is available.
    Generates responses based on scenario context and parameters.
    """

    def _call(self, prompt: str, stop: Optional[List[str]] = None, **kwargs) -> str:
        """Produce a canned, keyword-driven answer without any LLM."""
        lowered = prompt.lower()

        # Recommendation requests: first matching action word wins,
        # checked in the fixed order buy -> sell -> hold.
        if "recommend" in lowered:
            for action, handler in (
                ("buy", self._generate_buy_response),
                ("sell", self._generate_sell_response),
                ("hold", self._generate_hold_response),
            ):
                if action in lowered:
                    return handler(prompt)

        # Topic-specific canned answers.
        if "risk" in lowered:
            return "When evaluating risk, consider the company's debt levels, market volatility, and any red flags like insider selling or unusual trading volume. The current scenario presents factors that warrant careful consideration."

        if "insider" in lowered or "trading volume" in lowered:
            return "Unusual insider activity or trading volume can signal that informed parties have information not yet public. This is often a warning sign that warrants caution."

        if "sector" in lowered or "industry" in lowered:
            return "Sector trends significantly impact individual companies. Consider broader market conditions, regulatory environment, and competitive dynamics when making your decision."

        # Generic analytical response when nothing matched.
        return "Based on the available information, I'd encourage you to weigh the key factors mentioned in the scenario. Consider both the potential opportunities and the risk factors before making your decision."

    def _generate_buy_response(self, prompt: str) -> str:
        """Canned BUY-leaning advice."""
        return "Based on my analysis, buying could be appropriate here. The positive signals suggest potential upside, though you should consider your risk tolerance and the size of your position carefully."

    def _generate_sell_response(self, prompt: str) -> str:
        """Canned SELL-leaning advice."""
        return "Based on my analysis, selling may be prudent. The risk factors present suggest potential downside that could outweigh staying invested. Consider protecting your capital."

    def _generate_hold_response(self, prompt: str) -> str:
        """Canned HOLD-leaning advice."""
        return "Based on my analysis, holding your position seems reasonable. The situation shows mixed signals, and waiting for more clarity before acting could be the wisest approach."

    @property
    def _llm_type(self) -> str:
        return "fallback_rules"
259
+
260
+
261
# ==================== LLM Factory ====================

def create_llm(provider: LLMProvider = None, model_id: str = None) -> LLM:
    """Build the LLM instance for *provider* (auto-detected when None)."""
    resolved = get_llm_provider() if provider is None else provider
    if resolved == LLMProvider.HUGGINGFACE:
        return HuggingFaceLLM(model_id=model_id)
    if resolved == LLMProvider.DEEPSEEK:
        return DeepSeekLLM()
    return FallbackLLM()
274
+
275
+
276
# ==================== Chat Response ====================

@dataclass
class ChatResponse:
    """One chatbot reply plus metadata about how it was produced."""
    message: str             # text shown to the participant
    is_proactive: bool       # True when the bot initiated the exchange
    confidence_level: str    # "low", "medium", "high"
    sources_used: List[str]  # knowledge-source labels attributed to the reply
285
 
286
 
287
+ # ==================== Trading Chatbot ====================
288
+
289
  class TradingChatbot:
290
+ """
291
+ AI Chatbot for the trading experiment.
292
+ Supports both proactive advice and reactive queries.
293
+ """
294
 
295
+ def __init__(self, llm_provider: LLMProvider = None, model_id: str = None):
296
+ self.llm = create_llm(llm_provider, model_id)
297
  self.vectorstore = None
298
  self.chat_history: List[Tuple[str, str]] = []
299
  self._initialize_knowledge_base()
 
300
 
301
  def _initialize_knowledge_base(self):
302
  """Load and index the knowledge base documents."""
 
316
  print("Warning: No knowledge base documents found.")
317
  return
318
 
319
+ # Split documents into chunks
320
  splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
321
  split_docs = splitter.split_documents(docs)
322
 
323
  texts = [doc.page_content for doc in split_docs]
324
  metadatas = [{"source": doc.metadata.get("source", "unknown")} for doc in split_docs]
325
 
326
+ # Create embeddings and vectorstore
327
  try:
328
  embedding_function = HuggingFaceEmbeddings(
329
  model_name="sentence-transformers/all-MiniLM-L6-v2"
 
340
  except Exception as e:
341
  print(f"Error initializing vectorstore: {e}")
342
 
343
+ def _get_confidence_framing(self, level: int) -> Dict[str, str]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
344
  """Get language framing based on confidence parameter."""
345
  if level < 34:
346
+ return {
347
+ "prefix": "Based on the available information, one possibility is that",
348
+ "verb": "might consider",
349
+ "qualifier": "though there is considerable uncertainty",
350
+ "level": "low"
351
+ }
352
  elif level < 67:
353
+ return {
354
+ "prefix": "Looking at the situation,",
355
+ "verb": "suggests",
356
+ "qualifier": "while noting some risk factors",
357
+ "level": "medium"
358
+ }
359
  else:
360
+ return {
361
+ "prefix": "Based on my analysis,",
362
+ "verb": "strongly recommend",
363
+ "qualifier": "with high confidence",
364
+ "level": "high"
365
+ }
366
 
367
  def _get_depth_instructions(self, level: int) -> str:
368
+ """Get explanation depth instructions based on parameter."""
369
  if level < 34:
370
+ return "Provide a very brief response (1-2 sentences maximum). Focus only on the key point."
371
  elif level < 67:
372
+ return "Provide a moderate explanation (3-4 sentences). Include the main reasoning and key factors."
373
  else:
374
+ return "Provide a detailed analysis. Cover all relevant factors, risks, and opportunities comprehensively."
375
 
376
  def _get_risk_framing(self, level: int) -> str:
377
+ """Get risk perspective based on parameter."""
378
  if level < 34:
379
+ return "Emphasize potential risks and downside scenarios. Favor capital preservation over potential gains."
380
  elif level < 67:
381
+ return "Balance potential risks and opportunities. Present a measured risk-reward analysis."
382
  else:
383
+ return "Emphasize potential opportunities and upside. Be willing to tolerate higher risk for potential gains."
384
 
385
  def _get_style_instructions(self, level: int) -> str:
386
+ """Get communication style instructions based on parameter."""
387
  if level < 34:
388
+ return "Use formal, professional language. Be precise and measured in your statements."
389
  elif level < 67:
390
+ return "Use clear, accessible language. Be professional but approachable."
391
  else:
392
+ return "Use conversational, friendly language. Be direct and engaging."
393
 
394
  def _retrieve_context(self, query: str, k: int = 4) -> str:
395
  """Retrieve relevant context from the knowledge base."""
396
  if not self.vectorstore:
397
  return ""
398
+
399
  try:
400
  docs = self.vectorstore.similarity_search(query, k=k)
401
  return "\n\n".join([doc.page_content for doc in docs])
 
409
  visible_params: ParticipantVisibleParams,
410
  hidden_params: ResearcherControlledParams
411
  ) -> Optional[ChatResponse]:
412
+ """
413
+ Generate proactive advice for a scenario.
414
+ Returns None if proactive advice should not be shown.
415
+ """
416
+ # Check if we should show proactive advice based on proactivity level
417
+ proactive_threshold = hidden_params.proactivity_level / 100
418
+ if random.random() > proactive_threshold:
419
  return None
420
 
421
+ # Build the prompt
422
  confidence = self._get_confidence_framing(hidden_params.confidence_framing)
423
+ depth = self._get_depth_instructions(visible_params.explanation_depth)
424
+ risk = self._get_risk_framing(hidden_params.risk_bias)
425
+ style = self._get_style_instructions(visible_params.communication_style)
426
+
427
+ # Retrieve relevant context
428
+ context = self._retrieve_context(
429
+ f"{scenario.company_name} {scenario.sector} trading analysis"
430
+ )
431
 
432
+ # Determine what factors to highlight
433
+ factors_to_mention = []
434
+ if hidden_params.risk_bias < 50:
435
+ factors_to_mention = scenario.red_flags[:2] if scenario.red_flags else scenario.key_factors[:2]
436
+ else:
437
+ factors_to_mention = scenario.positive_signals[:2] if scenario.positive_signals else scenario.key_factors[:2]
438
+
439
+ prompt = f"""
440
+ {style}
441
+ {depth}
442
+ {risk}
443
 
444
+ You are an AI trading advisor. A participant is viewing this trading scenario:
 
445
 
446
  Company: {scenario.company_name} ({scenario.company_symbol})
447
  Sector: {scenario.sector}
448
+ Country: {scenario.country}
449
+ Current Price: {scenario.current_price} credits
450
 
451
+ Situation:
452
+ {scenario.situation_description}
453
 
454
+ Key factors to consider: {', '.join(factors_to_mention)}
455
 
456
+ Relevant knowledge:
457
+ {context}
458
 
459
+ You should proactively offer some initial observations about this situation.
460
+ {confidence['prefix']} the situation {confidence['verb']} careful attention {confidence['qualifier']}.
461
+
462
+ Your recommendation should lean toward: {scenario.ai_recommendation}
463
+
464
+ Generate a brief proactive message offering your initial take on this situation.
465
+ Do NOT explicitly tell them to BUY, SELL, or HOLD yet - this is an initial observation.
466
+ Keep it natural, as if you're an advisor noticing something they should be aware of.
467
+ """
468
+
469
+ response_text = self.llm._call(prompt)
470
 
471
  return ChatResponse(
472
+ message=response_text,
473
  is_proactive=True,
474
  confidence_level=confidence["level"],
475
+ sources_used=["market_context", "company_profile"]
476
  )
477
 
478
  def generate_ai_recommendation(
 
481
  visible_params: ParticipantVisibleParams,
482
  hidden_params: ResearcherControlledParams
483
  ) -> ChatResponse:
484
+ """
485
+ Generate the AI's recommendation for a scenario.
486
+ This is the main advice given before the participant decides.
487
+ """
488
  confidence = self._get_confidence_framing(hidden_params.confidence_framing)
489
+ depth = self._get_depth_instructions(visible_params.explanation_depth)
490
+ risk = self._get_risk_framing(hidden_params.risk_bias)
491
+ style = self._get_style_instructions(visible_params.communication_style)
492
+
493
+ context = self._retrieve_context(
494
+ f"{scenario.company_name} {scenario.sector} {scenario.ai_recommendation}"
495
+ )
496
+
497
+ prompt = f"""
498
+ {style}
499
+ {depth}
500
+ {risk}
501
 
502
+ You are an AI trading advisor. Analyze this situation and provide your recommendation:
 
 
503
 
504
  Company: {scenario.company_name} ({scenario.company_symbol})
505
  Sector: {scenario.sector}
506
+ Country: {scenario.country}
507
+ Current Price: {scenario.current_price} credits
508
 
509
+ Situation:
510
+ {scenario.situation_description}
511
 
512
  Key factors: {', '.join(scenario.key_factors)}
513
+ Warning signs: {', '.join(scenario.red_flags) if scenario.red_flags else 'None identified'}
514
+ Positive signals: {', '.join(scenario.positive_signals) if scenario.positive_signals else 'None identified'}
515
 
516
+ Relevant market knowledge:
517
+ {context}
518
 
519
+ {confidence['prefix']} I {confidence['verb']} the participant to {scenario.ai_recommendation} {confidence['qualifier']}.
520
 
521
+ Generate your recommendation. Be clear about your suggested action ({scenario.ai_recommendation}).
522
+ Explain your reasoning according to the depth level specified.
523
+ Frame risks according to the risk perspective specified.
524
+ """
525
+
526
+ response_text = self.llm._call(prompt)
527
 
528
  return ChatResponse(
529
+ message=response_text,
530
  is_proactive=False,
531
  confidence_level=confidence["level"],
532
+ sources_used=["market_context", "company_profile", "trading_basics"]
533
  )
534
 
535
  def answer_query(
 
539
  visible_params: ParticipantVisibleParams,
540
  hidden_params: ResearcherControlledParams
541
  ) -> ChatResponse:
542
+ """
543
+ Answer a participant's question (reactive query).
544
+ """
545
+ depth = self._get_depth_instructions(visible_params.explanation_depth)
546
+ style = self._get_style_instructions(visible_params.communication_style)
547
+ risk = self._get_risk_framing(hidden_params.risk_bias)
548
  confidence = self._get_confidence_framing(hidden_params.confidence_framing)
549
+
550
+ # Retrieve context based on the query
551
  context = self._retrieve_context(query)
552
 
553
+ # Build scenario context if available
554
+ scenario_context = ""
555
  if scenario:
556
+ scenario_context = f"""
557
+ Current scenario:
558
+ Company: {scenario.company_name} ({scenario.company_symbol})
559
+ Sector: {scenario.sector}
560
+ Situation: {scenario.situation_description}
561
+ """
562
+
563
+ # Include chat history for context
564
+ history_context = ""
565
+ if self.chat_history:
566
+ recent_history = self.chat_history[-3:] # Last 3 exchanges
567
+ history_context = "Recent conversation:\n" + "\n".join(
568
+ [f"User: {q}\nAI: {a}" for q, a in recent_history]
569
+ )
570
+
571
+ prompt = f"""
572
+ {style}
573
+ {depth}
574
+ {risk}
575
+
576
+ You are an AI trading advisor in the TradeVerse. Answer the participant's question.
577
 
578
+ {scenario_context}
 
579
 
580
+ {history_context}
581
 
582
+ Relevant knowledge from your database:
583
+ {context}
584
 
585
+ User question: {query}
586
+
587
+ Guidelines:
588
+ - Only use information from the TradeVerse (fictional universe)
589
+ - If asked about real-world companies or markets, politely redirect to TradeVerse
590
+ - {confidence['prefix'].lower()} frame your response {confidence['qualifier']}
591
+ - Be helpful but don't make decisions for the participant
592
+
593
+ Provide your response:
594
+ """
595
 
596
+ response_text = self.llm._call(prompt)
597
 
598
+ # Update chat history
599
+ self.chat_history.append((query, response_text))
600
 
601
  return ChatResponse(
602
+ message=response_text,
603
  is_proactive=False,
604
  confidence_level=confidence["level"],
605
  sources_used=["knowledge_base"]
606
  )
607
 
608
  def clear_history(self):
609
+ """Clear the chat history for a new session."""
610
  self.chat_history = []
611
 
612
 
613
# Singleton instance (uses auto-detected provider)
# NOTE(review): constructed at import time, so importing this module runs
# provider detection and knowledge-base indexing as a side effect — confirm
# this is intended for all entry points that import `chatbot`.
chatbot = TradingChatbot()