frdel committed
Commit f3f8ca5 · 1 Parent(s): 551c95a

Merge branch 'pr/491' into development
agent.py CHANGED

@@ -343,19 +343,28 @@ class Agent:
         # create log message right away, more responsive
         self.loop_data.params_temporary["log_item_generating"] = (
             self.context.log.log(
-                type="agent", heading=f"{self.agent_name}: Thinking..."
+                type="agent", heading=f"{self.agent_name}: Generating..."
             )
         )
 
+        async def reasoning_callback(chunk: str, full: str):
+            if chunk == full:
+                printer.print("Reasoning: ")  # start of reasoning
+            printer.stream(chunk)
+            await self.handle_reasoning_stream(full)
+
         async def stream_callback(chunk: str, full: str):
             # output the agent response stream
-            if chunk:
-                printer.stream(chunk)
-            await self.handle_response_stream(full)
-
-        agent_response = await self.call_chat_model(
-            prompt, callback=stream_callback
-        )  # type: ignore
+            if chunk == full:
+                printer.print("Response: ")  # start of response
+            printer.stream(chunk)
+            await self.handle_response_stream(full)
+
+        agent_response, _reasoning = await self.call_chat_model(
+            messages=prompt,
+            response_callback=stream_callback,
+            reasoning_callback=reasoning_callback,
+        )
 
         await self.handle_intervention(agent_response)
 
@@ -409,7 +418,7 @@ class Agent:
         # call monologue_end extensions
         await self.call_extensions("monologue_end", loop_data=self.loop_data)  # type: ignore
 
-    async def prepare_prompt(self, loop_data: LoopData) -> ChatPromptTemplate:
+    async def prepare_prompt(self, loop_data: LoopData) -> list[BaseMessage]:
         self.context.log.set_progress("Building prompt")
 
         # call extensions before setting prompts
@@ -422,12 +431,10 @@ class Agent:
         # and allow extensions to edit them
         await self.call_extensions("message_loop_prompts_after", loop_data=loop_data)
 
-        # extras (memory etc.)
-        # extras: list[history.OutputMessage] = []
-        # for extra in loop_data.extras_persistent.values():
-        #     extras += history.Message(False, content=extra).output()
-        # for extra in loop_data.extras_temporary.values():
-        #     extras += history.Message(False, content=extra).output()
+        # concatenate system prompt
+        system_text = "\n\n".join(loop_data.system)
+
+        # join extras
         extras = history.Message(
             False,
             content=self.read_prompt(
@@ -444,28 +451,23 @@ class Agent:
             loop_data.history_output + extras
         )
 
-        # build chain from system prompt, message history and model
-        system_text = "\n\n".join(loop_data.system)
-        prompt = ChatPromptTemplate.from_messages(
-            [
-                SystemMessage(content=system_text),
-                *history_langchain,
-                # AIMessage(content="JSON:"), # force the LLM to start with json
-            ]
-        )
+        # build full prompt from system prompt, message history and extras
+        full_prompt: list[BaseMessage] = [
+            SystemMessage(content=system_text),
+            *history_langchain,
+        ]
+        full_text = ChatPromptTemplate.from_messages(full_prompt).format()
 
         # store as last context window content
         self.set_data(
             Agent.DATA_NAME_CTX_WINDOW,
             {
-                "text": prompt.format(),
-                "tokens": self.history.get_tokens()
-                + tokens.approximate_tokens(system_text)
-                + tokens.approximate_tokens(history.output_text(extras)),
+                "text": full_text,
+                "tokens": tokens.approximate_tokens(full_text),
            },
        )
 
-        return prompt
+        return full_prompt
 
     def handle_critical_exception(self, exception: Exception):
         if isinstance(exception, HandledException):
@@ -586,24 +588,21 @@ class Agent:
         return self.history.output_text(human_label="user", ai_label="assistant")
 
     def get_chat_model(self):
-        return models.get_model(
-            models.ModelType.CHAT,
+        return models.get_chat_model(
             self.config.chat_model.provider,
             self.config.chat_model.name,
             **self.config.chat_model.kwargs,
         )
 
     def get_utility_model(self):
-        return models.get_model(
-            models.ModelType.CHAT,
+        return models.get_chat_model(
             self.config.utility_model.provider,
             self.config.utility_model.name,
             **self.config.utility_model.kwargs,
         )
 
     def get_embedding_model(self):
-        return models.get_model(
-            models.ModelType.EMBEDDING,
+        return models.get_embedding_model(
            self.config.embeddings_model.provider,
            self.config.embeddings_model.name,
            **self.config.embeddings_model.kwargs,
@@ -616,36 +615,37 @@ class Agent:
         callback: Callable[[str], Awaitable[None]] | None = None,
         background: bool = False,
     ):
-        prompt = ChatPromptTemplate.from_messages(
-            [SystemMessage(content=system), HumanMessage(content=message)]
-        )
-
-        response = ""
-
-        # model class
         model = self.get_utility_model()
 
         # rate limiter
         limiter = await self.rate_limiter(
-            self.config.utility_model, prompt.format(), background
+            self.config.utility_model, f"SYSTEM: {system}\nUSER: {message}", background
         )
 
-        async for chunk in (prompt | model).astream({}):
-            await self.handle_intervention()  # wait for intervention and handle it, if paused
-
-            content = models.parse_chunk(chunk)
-            limiter.add(output=tokens.approximate_tokens(content))
-            response += content
+        # add output tokens to rate limiter in tokens callback
+        async def tokens_callback(delta: str, tokens: int):
+            await self.handle_intervention()
+            limiter.add(output=tokens)
 
+        # propagate stream to callback if set
+        async def stream_callback(chunk: str, total: str):
             if callback:
-                await callback(content)
+                await callback(chunk)
+
+        response, _reasoning = await model.unified_call(
+            system_message=system,
+            user_message=message,
+            response_callback=stream_callback,
+            tokens_callback=tokens_callback,
+        )
 
         return response
 
     async def call_chat_model(
         self,
-        prompt: ChatPromptTemplate,
-        callback: Callable[[str, str], Awaitable[None]] | None = None,
+        messages: list[BaseMessage],
+        response_callback: Callable[[str, str], Awaitable[None]] | None = None,
+        reasoning_callback: Callable[[str, str], Awaitable[None]] | None = None,
     ):
         response = ""
 
@@ -653,19 +653,24 @@ class Agent:
         model = self.get_chat_model()
 
         # rate limiter
-        limiter = await self.rate_limiter(self.config.chat_model, prompt.format())
-
-        async for chunk in (prompt | model).astream({}):
-            await self.handle_intervention()  # wait for intervention and handle it, if paused
-
-            content = models.parse_chunk(chunk)
-            limiter.add(output=tokens.approximate_tokens(content))
-            response += content
+        limiter = await self.rate_limiter(
+            self.config.chat_model, ChatPromptTemplate.from_messages(messages).format()
+        )
 
-        if callback:
-            await callback(content, response)
+        # add output tokens to rate limiter in tokens callback
+        async def tokens_callback(delta: str, tokens: int):
+            await self.handle_intervention()
+            limiter.add(output=tokens)
+
+        # call model
+        response, reasoning = await model.unified_call(
+            messages=messages,
+            reasoning_callback=reasoning_callback,
+            response_callback=response_callback,
+            tokens_callback=tokens_callback,
+        )
 
-        return response
+        return response, reasoning
 
     async def rate_limiter(
         self, model_config: ModelConfig, input: str, background: bool = False
@@ -786,6 +791,13 @@ class Agent:
             content=f"{self.agent_name}: Message misformat, no valid tool request found.",
         )
 
+    async def handle_reasoning_stream(self, stream: str):
+        await self.call_extensions(
+            "reasoning_stream",
+            loop_data=self.loop_data,
+            text=stream,
+        )
+
     async def handle_response_stream(self, stream: str):
         try:
             if len(stream) < 25:
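
Note: the commit replaces the single `callback` on `call_chat_model` with separate `response_callback` and `reasoning_callback` hooks, each receiving the latest delta and the accumulated text, and the method now returns a `(response, reasoning)` tuple. A minimal sketch of how a caller might drive the new signature; the `agent` instance and message contents here are illustrative, not part of the commit:

~~~python
from langchain_core.messages import HumanMessage, SystemMessage

async def ask(agent, question: str) -> str:
    async def reasoning_callback(chunk: str, full: str):
        print(chunk, end="")  # streamed reasoning delta; `full` is the accumulated reasoning

    async def response_callback(chunk: str, full: str):
        print(chunk, end="")  # streamed response delta; `full` is the accumulated response

    # call_chat_model now takes a plain list[BaseMessage] instead of a ChatPromptTemplate
    response, _reasoning = await agent.call_chat_model(
        messages=[SystemMessage(content="Answer briefly."), HumanMessage(content=question)],
        response_callback=response_callback,
        reasoning_callback=reasoning_callback,
    )
    return response
~~~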
initialize.py CHANGED

@@ -7,6 +7,24 @@ from python.helpers.print_style import PrintStyle
 def initialize_agent():
     current_settings = settings.get_settings()
 
+    def _normalize_model_kwargs(kwargs: dict) -> dict:
+        # convert string values that represent valid Python numbers to numeric types
+        result = {}
+        for key, value in kwargs.items():
+            if isinstance(value, str):
+                # try to convert string to number if it's a valid Python number
+                try:
+                    # try int first, then float
+                    result[key] = int(value)
+                except ValueError:
+                    try:
+                        result[key] = float(value)
+                    except ValueError:
+                        result[key] = value
+            else:
+                result[key] = value
+        return result
+
     # chat model from user settings
     chat_llm = ModelConfig(
         provider=models.ModelProvider[current_settings["chat_model_provider"]],
@@ -16,7 +34,7 @@ def initialize_agent():
         limit_requests=current_settings["chat_model_rl_requests"],
         limit_input=current_settings["chat_model_rl_input"],
         limit_output=current_settings["chat_model_rl_output"],
-        kwargs=current_settings["chat_model_kwargs"],
+        kwargs=_normalize_model_kwargs(current_settings["chat_model_kwargs"]),
     )
 
     # utility model from user settings
@@ -27,21 +45,21 @@ def initialize_agent():
         limit_requests=current_settings["util_model_rl_requests"],
         limit_input=current_settings["util_model_rl_input"],
         limit_output=current_settings["util_model_rl_output"],
-        kwargs=current_settings["util_model_kwargs"],
+        kwargs=_normalize_model_kwargs(current_settings["util_model_kwargs"]),
     )
     # embedding model from user settings
     embedding_llm = ModelConfig(
         provider=models.ModelProvider[current_settings["embed_model_provider"]],
         name=current_settings["embed_model_name"],
         limit_requests=current_settings["embed_model_rl_requests"],
-        kwargs=current_settings["embed_model_kwargs"],
+        kwargs=_normalize_model_kwargs(current_settings["embed_model_kwargs"]),
     )
     # browser model from user settings
     browser_llm = ModelConfig(
         provider=models.ModelProvider[current_settings["browser_model_provider"]],
         name=current_settings["browser_model_name"],
         vision=current_settings["browser_model_vision"],
-        kwargs=current_settings["browser_model_kwargs"],
+        kwargs=_normalize_model_kwargs(current_settings["browser_model_kwargs"]),
     )
     # agent configuration
     config = AgentConfig(
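
Note: `_normalize_model_kwargs` coerces numeric strings coming from the settings UI into real numbers, trying `int` first and then `float`, and passes everything else through unchanged. An illustrative trace (the helper is local to `initialize_agent`; the values are hypothetical):

~~~python
raw = {"temperature": "0.7", "num_ctx": "8192", "stop": "###", "seed": 42}
# int("8192") succeeds                      -> 8192
# int("0.7") raises ValueError, float works -> 0.7
# int("###") and float("###") both fail     -> "###" kept as a string
# non-strings (42) pass through untouched
# result: {"temperature": 0.7, "num_ctx": 8192, "stop": "###", "seed": 42}
~~~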
models.py CHANGED

@@ -1,38 +1,37 @@
 from enum import Enum
 import os
-from typing import Any
-from langchain_openai import (
-    ChatOpenAI,
-    OpenAI,
-    OpenAIEmbeddings,
-    AzureChatOpenAI,
-    AzureOpenAIEmbeddings,
-    AzureOpenAI,
+from typing import (
+    Any,
+    Awaitable,
+    Callable,
+    List,
+    Optional,
+    Iterator,
+    AsyncIterator,
+    Tuple,
+    TypedDict,
 )
-from langchain_community.llms.ollama import Ollama
-from langchain_ollama import ChatOllama
-from langchain_community.embeddings import OllamaEmbeddings
-from langchain_anthropic import ChatAnthropic
-from langchain_groq import ChatGroq
-from langchain_huggingface import (
-    HuggingFaceEmbeddings,
-    ChatHuggingFace,
-    HuggingFaceEndpoint,
-)
-from langchain_google_genai import (
-    ChatGoogleGenerativeAI,
-    HarmBlockThreshold,
-    HarmCategory,
-    embeddings as google_embeddings,
-)
-from langchain_mistralai import ChatMistralAI
 
-# from pydantic.v1.types import SecretStr
-from python.helpers import dotenv, runtime
+from litellm import completion, acompletion, embedding
+from python.helpers import dotenv
 from python.helpers.dotenv import load_dotenv
 from python.helpers.rate_limiter import RateLimiter
+from python.helpers.tokens import approximate_tokens
+
+from langchain_core.language_models.chat_models import SimpleChatModel
+from langchain_core.outputs.chat_generation import ChatGenerationChunk
+from langchain_core.callbacks.manager import (
+    CallbackManagerForLLMRun,
+    AsyncCallbackManagerForLLMRun,
+)
+from langchain_core.messages import (
+    BaseMessage,
+    AIMessageChunk,
+    HumanMessage,
+    SystemMessage,
+)
+from langchain.embeddings.base import Embeddings
 
-# environment variables
 load_dotenv()
 
 
@@ -52,40 +51,71 @@ class ModelProvider(Enum):
     MISTRALAI = "Mistral AI"
     OLLAMA = "Ollama"
     OPENAI = "OpenAI"
-    OPENAI_AZURE = "OpenAI Azure"
+    AZURE = "OpenAI Azure"
     OPENROUTER = "OpenRouter"
     SAMBANOVA = "Sambanova"
     OTHER = "Other"
 
 
+class ChatChunk(TypedDict):
+    """Simplified response chunk for chat models."""
+
+    response_delta: str
+    reasoning_delta: str
+
+
 rate_limiters: dict[str, RateLimiter] = {}
 
 
-# Utility function to get API keys from environment variables
-def get_api_key(service):
+def configure_litellm_environment():
+    env_mappings = {
+        "API_KEY_OPENAI": "OPENAI_API_KEY",
+        "API_KEY_ANTHROPIC": "ANTHROPIC_API_KEY",
+        "API_KEY_GROQ": "GROQ_API_KEY",
+        "API_KEY_GOOGLE": "GOOGLE_API_KEY",
+        "API_KEY_MISTRAL": "MISTRAL_API_KEY",
+        "API_KEY_OLLAMA": "OLLAMA_API_KEY",
+        "API_KEY_HUGGINGFACE": "HUGGINGFACE_API_KEY",
+        "API_KEY_OPENAI_AZURE": "AZURE_API_KEY",
+        "API_KEY_DEEPSEEK": "DEEPSEEK_API_KEY",
+        "API_KEY_SAMBANOVA": "SAMBANOVA_API_KEY",
+    }
+    base_url_mappings = {
+        "OPENAI_BASE_URL": "OPENAI_API_BASE",
+        "ANTHROPIC_BASE_URL": "ANTHROPIC_API_BASE",
+        "GROQ_BASE_URL": "GROQ_API_BASE",
+        "GOOGLE_BASE_URL": "GOOGLE_API_BASE",
+        "MISTRAL_BASE_URL": "MISTRAL_API_BASE",
+        "OLLAMA_BASE_URL": "OLLAMA_API_BASE",
+        "HUGGINGFACE_BASE_URL": "HUGGINGFACE_API_BASE",
+        "AZURE_BASE_URL": "AZURE_API_BASE",
+        "DEEPSEEK_BASE_URL": "DEEPSEEK_API_BASE",
+        "SAMBANOVA_BASE_URL": "SAMBANOVA_API_BASE",
+    }
+    for a0, llm in env_mappings.items():
+        val = dotenv.get_dotenv_value(a0)
+        if val and not os.getenv(llm):
+            os.environ[llm] = val
+    for a0_base, llm_base in base_url_mappings.items():
+        val = dotenv.get_dotenv_value(a0_base)
+        if val and not os.getenv(llm_base):
+            os.environ[llm_base] = val
+
+
+def get_api_key(service: str) -> str:
     return (
         dotenv.get_dotenv_value(f"API_KEY_{service.upper()}")
         or dotenv.get_dotenv_value(f"{service.upper()}_API_KEY")
-        or dotenv.get_dotenv_value(
-            f"{service.upper()}_API_TOKEN"
-        )  # Added for CHUTES_API_TOKEN
+        or dotenv.get_dotenv_value(f"{service.upper()}_API_TOKEN")
        or "None"
    )
 
 
-def get_model(type: ModelType, provider: ModelProvider, name: str, **kwargs):
-    fnc_name = f"get_{provider.name.lower()}_{type.name.lower()}"  # function name of model getter
-    model = globals()[fnc_name](name, **kwargs)  # call function by name
-    return model
-
-
 def get_rate_limiter(
     provider: ModelProvider, name: str, requests: int, input: int, output: int
 ) -> RateLimiter:
-    # get or create
     key = f"{provider.name}\\{name}"
     rate_limiters[key] = limiter = rate_limiters.get(key, RateLimiter(seconds=60))
-    # always update
     limiter.limits["requests"] = requests or 0
     limiter.limits["input"] = input or 0
     limiter.limits["output"] = output or 0
@@ -102,332 +132,385 @@ def parse_chunk(chunk: Any):
     return content
 
 
-# Ollama models
-def get_ollama_base_url():
-    return (
-        dotenv.get_dotenv_value("OLLAMA_BASE_URL")
-        or f"http://{runtime.get_local_url()}:11434"
-    )
-
-
-def get_ollama_chat(
-    model_name: str,
-    base_url=None,
-    num_ctx=8192,
-    **kwargs,
-):
-    if not base_url:
-        base_url = get_ollama_base_url()
-    return ChatOllama(
-        model=model_name,
-        base_url=base_url,
-        num_ctx=num_ctx,
-        **kwargs,
-    )
-
-
-def get_ollama_embedding(
-    model_name: str,
-    base_url=None,
-    num_ctx=8192,
-    **kwargs,
-):
-    if not base_url:
-        base_url = get_ollama_base_url()
-    return OllamaEmbeddings(
-        model=model_name, base_url=base_url, num_ctx=num_ctx, **kwargs
-    )
-
-
-# HuggingFace models
-def get_huggingface_chat(
-    model_name: str,
-    api_key=None,
-    **kwargs,
-):
-    # different naming convention here
-    if not api_key:
-        api_key = get_api_key("huggingface") or os.environ["HUGGINGFACEHUB_API_TOKEN"]
-
-    # Initialize the HuggingFaceEndpoint with the specified model and parameters
-    llm = HuggingFaceEndpoint(
-        repo_id=model_name,
-        task="text-generation",
-        do_sample=True,
-        **kwargs,
-    )
-
-    # Initialize the ChatHuggingFace with the configured llm
-    return ChatHuggingFace(llm=llm)
-
-
-def get_huggingface_embedding(model_name: str, **kwargs):
-    return HuggingFaceEmbeddings(model_name=model_name, **kwargs)
-
-
-# LM Studio and other OpenAI compatible interfaces
-def get_lmstudio_base_url():
-    return (
-        dotenv.get_dotenv_value("LM_STUDIO_BASE_URL")
-        or f"http://{runtime.get_local_url()}:1234/v1"
-    )
-
-
-def get_lmstudio_chat(
-    model_name: str,
-    base_url=None,
-    **kwargs,
-):
-    if not base_url:
-        base_url = get_lmstudio_base_url()
-    return ChatOpenAI(model_name=model_name, base_url=base_url, api_key="none", **kwargs)  # type: ignore
-
-
-def get_lmstudio_embedding(
-    model_name: str,
-    base_url=None,
-    **kwargs,
-):
-    if not base_url:
-        base_url = get_lmstudio_base_url()
-    return OpenAIEmbeddings(model=model_name, api_key="none", base_url=base_url, check_embedding_ctx_length=False, **kwargs)  # type: ignore
-
-
-# Anthropic models
-def get_anthropic_chat(
-    model_name: str,
-    api_key=None,
-    base_url=None,
-    **kwargs,
-):
-    if not api_key:
-        api_key = get_api_key("anthropic")
-    if not base_url:
-        base_url = (
-            dotenv.get_dotenv_value("ANTHROPIC_BASE_URL") or "https://api.anthropic.com"
-        )
-    return ChatAnthropic(model_name=model_name, api_key=api_key, base_url=base_url, **kwargs)  # type: ignore
-
-
-# right now anthropic does not have embedding models, but that might change
-def get_anthropic_embedding(
-    model_name: str,
-    api_key=None,
-    **kwargs,
-):
-    if not api_key:
-        api_key = get_api_key("anthropic")
-    return OpenAIEmbeddings(model=model_name, api_key=api_key, **kwargs)  # type: ignore
-
-
-# OpenAI models
-def get_openai_chat(
-    model_name: str,
-    api_key=None,
-    **kwargs,
-):
-    if not api_key:
-        api_key = get_api_key("openai")
-    return ChatOpenAI(model_name=model_name, api_key=api_key, **kwargs)  # type: ignore
-
-
-def get_openai_embedding(model_name: str, api_key=None, **kwargs):
-    if not api_key:
-        api_key = get_api_key("openai")
-    return OpenAIEmbeddings(model=model_name, api_key=api_key, **kwargs)  # type: ignore
-
-
-def get_openai_azure_chat(
-    deployment_name: str,
-    api_key=None,
-    azure_endpoint=None,
-    **kwargs,
-):
-    if not api_key:
-        api_key = get_api_key("openai_azure")
-    if not azure_endpoint:
-        azure_endpoint = dotenv.get_dotenv_value("OPENAI_AZURE_ENDPOINT")
-    return AzureChatOpenAI(deployment_name=deployment_name, api_key=api_key, azure_endpoint=azure_endpoint, **kwargs)  # type: ignore
-
-
-def get_openai_azure_embedding(
-    deployment_name: str,
-    api_key=None,
-    azure_endpoint=None,
-    **kwargs,
-):
-    if not api_key:
-        api_key = get_api_key("openai_azure")
-    if not azure_endpoint:
-        azure_endpoint = dotenv.get_dotenv_value("OPENAI_AZURE_ENDPOINT")
-    return AzureOpenAIEmbeddings(deployment_name=deployment_name, api_key=api_key, azure_endpoint=azure_endpoint, **kwargs)  # type: ignore
-
-
-# Google models
-def get_google_chat(
-    model_name: str,
-    api_key=None,
-    **kwargs,
-):
-    if not api_key:
-        api_key = get_api_key("google")
-    return ChatGoogleGenerativeAI(model=model_name, google_api_key=api_key, safety_settings={HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE}, **kwargs)  # type: ignore
-
-
-def get_google_embedding(
-    model_name: str,
-    api_key=None,
-    **kwargs,
-):
-    if not api_key:
-        api_key = get_api_key("google")
-    return google_embeddings.GoogleGenerativeAIEmbeddings(model=model_name, google_api_key=api_key, **kwargs)  # type: ignore
-
-
-# Mistral models
-def get_mistralai_chat(
-    model_name: str,
-    api_key=None,
-    **kwargs,
-):
-    if not api_key:
-        api_key = get_api_key("mistral")
-    return ChatMistralAI(model=model_name, api_key=api_key, **kwargs)  # type: ignore
-
-
-# Groq models
-def get_groq_chat(
-    model_name: str,
-    api_key=None,
-    **kwargs,
-):
-    if not api_key:
-        api_key = get_api_key("groq")
-    return ChatGroq(model_name=model_name, api_key=api_key, **kwargs)  # type: ignore
-
-
-# DeepSeek models
-def get_deepseek_chat(
-    model_name: str,
-    api_key=None,
-    base_url=None,
-    **kwargs,
-):
-    if not api_key:
-        api_key = get_api_key("deepseek")
-    if not base_url:
-        base_url = (
-            dotenv.get_dotenv_value("DEEPSEEK_BASE_URL") or "https://api.deepseek.com"
-        )
-
-    return ChatOpenAI(api_key=api_key, model=model_name, base_url=base_url, **kwargs)  # type: ignore
-
-
-# OpenRouter models
-def get_openrouter_chat(
-    model_name: str,
-    api_key=None,
-    base_url=None,
-    **kwargs,
-):
-    if not api_key:
-        api_key = get_api_key("openrouter")
-    if not base_url:
-        base_url = (
-            dotenv.get_dotenv_value("OPEN_ROUTER_BASE_URL")
-            or "https://openrouter.ai/api/v1"
-        )
-    return ChatOpenAI(
-        api_key=api_key,  # type: ignore
-        model=model_name,
-        base_url=base_url,
-        stream_usage=True,
-        model_kwargs={
-            "extra_headers": {
-                "HTTP-Referer": "https://agent-zero.ai",
-                "X-Title": "Agent Zero",
-            }
-        },
-        **kwargs,
-    )
-
-
-def get_openrouter_embedding(
-    model_name: str,
-    api_key=None,
-    base_url=None,
-    **kwargs,
-):
-    if not api_key:
-        api_key = get_api_key("openrouter")
-    if not base_url:
-        base_url = (
-            dotenv.get_dotenv_value("OPEN_ROUTER_BASE_URL")
-            or "https://openrouter.ai/api/v1"
-        )
-    return OpenAIEmbeddings(model=model_name, api_key=api_key, base_url=base_url, **kwargs)  # type: ignore
-
-
-# Sambanova models
-def get_sambanova_chat(
-    model_name: str,
-    api_key=None,
-    base_url=None,
-    max_tokens=1024,
-    **kwargs,
-):
-    if not api_key:
-        api_key = get_api_key("sambanova")
-    if not base_url:
-        base_url = (
-            dotenv.get_dotenv_value("SAMBANOVA_BASE_URL")
-            or "https://fast-api.snova.ai/v1"
-        )
-    return ChatOpenAI(api_key=api_key, model=model_name, base_url=base_url, max_tokens=max_tokens, **kwargs)  # type: ignore
-
-
-# right now sambanova does not have embedding models, but that might change
-def get_sambanova_embedding(
-    model_name: str,
-    api_key=None,
-    base_url=None,
-    **kwargs,
-):
-    if not api_key:
-        api_key = get_api_key("sambanova")
-    if not base_url:
-        base_url = (
-            dotenv.get_dotenv_value("SAMBANOVA_BASE_URL")
-            or "https://fast-api.snova.ai/v1"
-        )
-    return OpenAIEmbeddings(model=model_name, api_key=api_key, base_url=base_url, **kwargs)  # type: ignore
-
-
-# Other OpenAI compatible models
-def get_other_chat(
-    model_name: str,
-    api_key=None,
-    base_url=None,
-    **kwargs,
-):
-    return ChatOpenAI(api_key=api_key, model=model_name, base_url=base_url, **kwargs)  # type: ignore
-
-
-def get_other_embedding(model_name: str, api_key=None, base_url=None, **kwargs):
-    return OpenAIEmbeddings(model=model_name, api_key=api_key, base_url=base_url, **kwargs)  # type: ignore
-
-
-# Chutes models
-def get_chutes_chat(
-    model_name: str,
-    api_key=None,
-    base_url=None,
-    **kwargs,
-):
-    if not api_key:
-        api_key = get_api_key("chutes")
-    if not base_url:
-        base_url = (
-            dotenv.get_dotenv_value("CHUTES_BASE_URL") or "https://llm.chutes.ai/v1"
-        )
-    return ChatOpenAI(api_key=api_key, model=model_name, base_url=base_url, **kwargs)  # type: ignore
+def _parse_chunk(chunk: Any) -> ChatChunk:
+    delta = chunk["choices"][0].get("delta", {})
+    message = chunk["choices"][0].get("model_extra", {}).get("message", {})
+    response_delta = (
+        delta.get("content", "")
+        if isinstance(delta, dict)
+        else getattr(delta, "content", "")
+    ) or (
+        message.get("content", "")
+        if isinstance(message, dict)
+        else getattr(message, "content", "")
+    )
+    reasoning_delta = (
+        delta.get("reasoning_content", "")
+        if isinstance(delta, dict)
+        else getattr(delta, "reasoning_content", "")
+    )
+    return ChatChunk(reasoning_delta=reasoning_delta, response_delta=response_delta)
+
+
+class LiteLLMChatWrapper(SimpleChatModel):
+    model_name: str
+    provider: str
+    kwargs: dict = {}
+
+    def __init__(self, model: str, provider: str, **kwargs: Any):
+        model_value = f"{provider}/{model}"
+        super().__init__(model_name=model_value, provider=provider, kwargs=kwargs)  # type: ignore
+
+    @property
+    def _llm_type(self) -> str:
+        return "litellm-chat"
+
+    def _convert_messages(self, messages: List[BaseMessage]) -> List[dict]:
+        result = []
+        # Map LangChain message types to LiteLLM roles
+        role_mapping = {
+            "human": "user",
+            "ai": "assistant",
+            "system": "system",
+            "tool": "tool",
+        }
+        for m in messages:
+            role = role_mapping.get(m.type, m.type)
+            message_dict = {"role": role, "content": m.content}
+
+            # Handle tool calls for AI messages
+            tool_calls = getattr(m, "tool_calls", None)
+            if tool_calls:
+                # Convert LangChain tool calls to LiteLLM format
+                new_tool_calls = []
+                for tool_call in tool_calls:
+                    # Ensure arguments is a JSON string
+                    args = tool_call["args"]
+                    if isinstance(args, dict):
+                        import json
+
+                        args_str = json.dumps(args)
+                    else:
+                        args_str = str(args)
+
+                    new_tool_calls.append(
+                        {
+                            "id": tool_call.get("id", ""),
+                            "type": "function",
+                            "function": {
+                                "name": tool_call["name"],
+                                "arguments": args_str,
+                            },
+                        }
+                    )
+                message_dict["tool_calls"] = new_tool_calls
+
+            # Handle tool call ID for ToolMessage
+            tool_call_id = getattr(m, "tool_call_id", None)
+            if tool_call_id:
+                message_dict["tool_call_id"] = tool_call_id
+
+            result.append(message_dict)
+        return result
+
+    def _call(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> str:
+        msgs = self._convert_messages(messages)
+        resp = completion(
+            model=self.model_name, messages=msgs, stop=stop, **{**self.kwargs, **kwargs}
+        )
+        parsed = _parse_chunk(resp)
+        return parsed["response_delta"]
+
+    def _stream(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> Iterator[ChatGenerationChunk]:
+        msgs = self._convert_messages(messages)
+        for chunk in completion(
+            model=self.model_name,
+            messages=msgs,
+            stream=True,
+            stop=stop,
+            **{**self.kwargs, **kwargs},
+        ):
+            parsed = _parse_chunk(chunk)
+            # Only yield chunks with non-None content
+            if parsed["response_delta"]:
+                yield ChatGenerationChunk(
+                    message=AIMessageChunk(content=parsed["response_delta"])
+                )
+
+    async def _astream(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> AsyncIterator[ChatGenerationChunk]:
+        msgs = self._convert_messages(messages)
+        response = await acompletion(
+            model=self.model_name,
+            messages=msgs,
+            stream=True,
+            stop=stop,
+            **{**self.kwargs, **kwargs},
+        )
+        async for chunk in response:  # type: ignore
+            parsed = _parse_chunk(chunk)
+            # Only yield chunks with non-None content
+            if parsed["response_delta"]:
+                yield ChatGenerationChunk(
+                    message=AIMessageChunk(content=parsed["response_delta"])
+                )
+
+    async def unified_call(
+        self,
+        system_message="",
+        user_message="",
+        messages: List[BaseMessage] = [],
+        response_callback: Callable[[str, str], Awaitable[None]] | None = None,
+        reasoning_callback: Callable[[str, str], Awaitable[None]] | None = None,
+        tokens_callback: Callable[[str, int], Awaitable[None]] | None = None,
+        **kwargs: Any,
+    ) -> Tuple[str, str]:
+        # construct messages
+        if system_message:
+            messages.insert(0, SystemMessage(content=system_message))
+        if user_message:
+            messages.append(HumanMessage(content=user_message))
+
+        # convert to litellm format
+        msgs_conv = self._convert_messages(messages)
+
+        # call model
+        _completion = await acompletion(
+            model=self.model_name,
+            messages=msgs_conv,
+            stream=True,
+            **{**self.kwargs, **kwargs},
+        )
+
+        # results
+        reasoning = ""
+        response = ""
+
+        # iterate over chunks
+        async for chunk in _completion:  # type: ignore
+            parsed = _parse_chunk(chunk)
+            # collect reasoning delta and call callbacks
+            if parsed["reasoning_delta"]:
+                reasoning += parsed["reasoning_delta"]
+                if reasoning_callback:
+                    await reasoning_callback(parsed["reasoning_delta"], reasoning)
+                if tokens_callback:
+                    await tokens_callback(
+                        parsed["reasoning_delta"],
+                        approximate_tokens(parsed["reasoning_delta"]),
+                    )
+            # collect response delta and call callbacks
+            if parsed["response_delta"]:
+                response += parsed["response_delta"]
+                if response_callback:
+                    await response_callback(parsed["response_delta"], response)
+                if tokens_callback:
+                    await tokens_callback(
+                        parsed["response_delta"],
+                        approximate_tokens(parsed["response_delta"]),
+                    )
+
+        # return complete results
+        return response, reasoning
+
+
+class BrowserCompatibleChatWrapper(LiteLLMChatWrapper):
+    """
+    A wrapper for browser agent that can filter/sanitize messages
+    before sending them to the LLM.
+    """
+
+    def _call(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> str:
+        # In the future, message filtering logic can be added here.
+        result = super()._call(messages, stop, run_manager, **kwargs)
+        return result
+
+    async def _astream(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> AsyncIterator[ChatGenerationChunk]:
+        # In the future, message filtering logic can be added here.
+        async for chunk in super()._astream(messages, stop, run_manager, **kwargs):
+            yield chunk
+
+
+class LiteLLMEmbeddingWrapper(Embeddings):
+    model_name: str
+    kwargs: dict = {}
+
+    def __init__(self, model: str, provider: str, **kwargs: Any):
+        self.model_name = f"{provider}/{model}" if provider != "openai" else model
+        self.kwargs = kwargs
+
+    def embed_documents(self, texts: List[str]) -> List[List[float]]:
+        resp = embedding(model=self.model_name, input=texts, **self.kwargs)
+        return [
+            item.get("embedding") if isinstance(item, dict) else item.embedding
+            for item in resp.data
+        ]
+
+    def embed_query(self, text: str) -> List[float]:
+        resp = embedding(model=self.model_name, input=[text], **self.kwargs)
+        item = resp.data[0]
+        return item.get("embedding") if isinstance(item, dict) else item.embedding
+
+
+class LocalSentenceTransformerWrapper(Embeddings):
+    """Local wrapper for sentence-transformers models to avoid HuggingFace API calls"""
+
+    def __init__(self, model_name: str, **kwargs: Any):
+        try:
+            from sentence_transformers import SentenceTransformer
+        except ImportError:
+            raise ImportError(
+                "sentence-transformers library is required for local embeddings. Install with: pip install sentence-transformers"
+            )
+
+        # Remove the "sentence-transformers/" prefix if present
+        if model_name.startswith("sentence-transformers/"):
+            model_name = model_name[len("sentence-transformers/") :]
+
+        self.model = SentenceTransformer(model_name, **kwargs)
+        self.model_name = model_name
+
+    def embed_documents(self, texts: List[str]) -> List[List[float]]:
+        embeddings = self.model.encode(texts, convert_to_tensor=False)
+        return embeddings.tolist() if hasattr(embeddings, "tolist") else embeddings
+
+    def embed_query(self, text: str) -> List[float]:
+        embedding = self.model.encode([text], convert_to_tensor=False)
+        result = (
+            embedding[0].tolist() if hasattr(embedding[0], "tolist") else embedding[0]
+        )
+        return result
+
+
+def _get_litellm_chat(
+    cls: type = LiteLLMChatWrapper,
+    model_name: str = "",
+    provider_name: str = "",
+    **kwargs: Any,
+):
+    provider_name = provider_name.lower()
+
+    configure_litellm_environment()
+    # Use original provider name for API key lookup, fallback to mapped provider name
+    api_key = kwargs.pop("api_key", None) or get_api_key(provider_name)
+
+    # litellm will pick up base_url from env. We just need to control the api_key.
+    base_url = dotenv.get_dotenv_value(f"{provider_name.upper()}_BASE_URL")
+
+    # If a base_url is set, ensure api_key is not passed to litellm
+    if base_url:
+        if "api_key" in kwargs:
+            del kwargs["api_key"]
+    # Only pass API key if no base_url is set and key is not a placeholder
+    elif api_key and api_key not in ("None", "NA"):
+        kwargs["api_key"] = api_key
+
+    # for openrouter add app reference
+    if provider_name == "openrouter":
+        kwargs["extra_headers"] = {
+            "HTTP-Referer": "https://agent-zero.ai",
+            "X-Title": "Agent Zero",
+        }
+
+    return cls(model=model_name, provider=provider_name, **kwargs)
+
+
+def get_litellm_embedding(model_name: str, provider: str, **kwargs: Any):
+    # Check if this is a local sentence-transformers model
+    if provider == "huggingface" and model_name.startswith("sentence-transformers/"):
+        # Use local sentence-transformers instead of LiteLLM for local models
+        return LocalSentenceTransformerWrapper(model_name=model_name, **kwargs)
+
+    configure_litellm_environment()
+    # Use original provider name for API key lookup, fallback to mapped provider name
+    api_key = kwargs.pop("api_key", None) or get_api_key(provider)
+
+    # litellm will pick up base_url from env. We just need to control the api_key.
+    base_url = dotenv.get_dotenv_value(f"{provider.upper()}_BASE_URL")
+
+    # If a base_url is set, ensure api_key is not passed to litellm
+    if base_url:
+        if "api_key" in kwargs:
+            del kwargs["api_key"]
+    # Only pass API key if no base_url is set and key is not a placeholder
+    elif api_key and api_key not in ("None", "NA"):
+        kwargs["api_key"] = api_key
+
+    return LiteLLMEmbeddingWrapper(model=model_name, provider=provider, **kwargs)
+
+
+def get_model(type: ModelType, provider: ModelProvider, name: str, **kwargs: Any):
+    provider_name = provider.name.lower()
+    kwargs = _normalize_chat_kwargs(kwargs)
+    if type == ModelType.CHAT:
+        return _get_litellm_chat(LiteLLMChatWrapper, name, provider_name, **kwargs)
+    elif type == ModelType.EMBEDDING:
+        return get_litellm_embedding(name, provider_name, **kwargs)
+    else:
+        raise ValueError(f"Unsupported model type: {type}")
+
+
+def get_chat_model(
+    provider: ModelProvider, name: str, **kwargs: Any
+) -> LiteLLMChatWrapper:
+    provider_name = provider.name.lower()
+    kwargs = _normalize_chat_kwargs(kwargs)
+    model = _get_litellm_chat(LiteLLMChatWrapper, name, provider_name, **kwargs)
+    return model
+
+
+def get_browser_model(
+    provider: ModelProvider, name: str, **kwargs: Any
+) -> BrowserCompatibleChatWrapper:
+    provider_name = provider.name.lower()
+    kwargs = _normalize_chat_kwargs(kwargs)
+    model = _get_litellm_chat(
+        BrowserCompatibleChatWrapper, name, provider_name, **kwargs
+    )
+    return model
+
+
+def get_embedding_model(
+    provider: ModelProvider, name: str, **kwargs: Any
+) -> LiteLLMEmbeddingWrapper | LocalSentenceTransformerWrapper:
+    provider_name = provider.name.lower()
+    kwargs = _normalize_embedding_kwargs(kwargs)
+    model = get_litellm_embedding(name, provider_name, **kwargs)
+    return model
+
+
+def _normalize_chat_kwargs(kwargs: Any) -> Any:
+    return kwargs
+
+
+def _normalize_embedding_kwargs(kwargs: Any) -> Any:
+    return kwargs
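
Note: with the provider-specific getters removed, every chat provider is routed through LiteLLM's `provider/model` naming, and `unified_call` streams reasoning and response deltas through optional callbacks. A minimal sketch of calling the new wrapper directly; the provider choice and model name are illustrative, not prescribed by the commit:

~~~python
import asyncio
import models

async def main():
    # get_chat_model builds a LiteLLMChatWrapper for "openai/gpt-4o-mini" (illustrative name)
    chat = models.get_chat_model(models.ModelProvider.OPENAI, "gpt-4o-mini")

    async def tokens_callback(delta: str, tokens: int):
        pass  # e.g. feed a rate limiter, as agent.py does

    response, reasoning = await chat.unified_call(
        system_message="You are terse.",
        user_message="Say hi.",
        tokens_callback=tokens_callback,
    )
    print(response)

asyncio.run(main())
~~~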
preload.py CHANGED

@@ -22,7 +22,11 @@ async def preload():
     async def preload_embedding():
         if set["embed_model_provider"] == models.ModelProvider.HUGGINGFACE.name:
             try:
-                emb_mod = models.get_huggingface_embedding(set["embed_model_name"])
+                # Use the new LiteLLM-based model system
+                emb_mod = models.get_embedding_model(
+                    models.ModelProvider.HUGGINGFACE,
+                    set["embed_model_name"]
+                )
                 emb_txt = await emb_mod.aembed_query("test")
                 return emb_txt
             except Exception as e:
prompts/agent0/agent.system.tool.response.md CHANGED

@@ -3,6 +3,7 @@ final answer to user
 ends task processing use only when done or no task active
 put result in text arg
 always use markdown formatting headers bold text lists
+full message is automatically markdown do not wrap ~~~markdown
 use emojis as icons improve readability
 prefer using tables
 focus nice structured output key selling point
prompts/default/agent.system.tool.call_sub.md CHANGED

@@ -1,63 +1,26 @@
 ### call_subordinate
 
 you can use subordinates for subtasks
-subordinates can be specialized roles
-message field: always describe task details goal overview important details for new subordinate
+subordinates can be scientist coder engineer etc
+message field: always describe role, task details goal overview for new subordinate
 delegate specific subtasks not entire task
 reset arg usage:
 "true": spawn new subordinate
-"false": continue current conversation
-prompt_profile defines subordinate specialization
-
-#### if you are superior
-- identify new tasks which your main task's completion depends upon
-- break down your main task into subtasks if possible. If the task can not be split execute it yourself
-- only let saubtasks and new depended upon tasks of your main task be handled by subordinates
-- never forward your entire task to a subordinate to avoid endless delegation loops
-
-#### if you are subordinate:
-- superior is {{agent_name}} minus 1
-- execute the task you were assigned
-- delegate further if asked
-- break down tasks and delegate if necessary
-- do not delegate tasks you can accomplish yourself without refining them
-- only subtasks of your current main task are allowed to be delegated. Never delegate your entire task ro prevent endless loops.
-
-#### Arguments:
-- message (string): always describe task details goal overview important details for new subordinate
-- reset (boolean): true: spawn new subordinate, false: continue current conversation
-- prompt_profile (string): defines specialization, only available prompt profiles below, can omit when reset false
-
-##### Prompt Profiles available
-{{prompt_profiles}}
-
-#### example usage
-~~~json
-{
-    "thoughts": [
-        "This task is challenging and requires a data analyst",
-        "The research_agent profile supports data analysis",
-    ],
-    "headline": "Delegating coding fix to subordinate agent",
-    "tool_name": "call_subordinate",
-    "tool_args": {
-        "message": "...",
-        "reset": "true",
-        "prompt_profile": "research_agent",
-    }
-}
-~~~
+"false": continue existing subordinate
+if superior, orchestrate
+respond to existing subordinates using call_subordinate tool with reset false
 
+example usage
 ~~~json
 {
     "thoughts": [
-        "The response is missing...",
-        "I will ask a subordinate to add...",
+        "The result seems to be ok but...",
+        "I will ask a coder subordinate to fix...",
     ],
     "tool_name": "call_subordinate",
     "tool_args": {
         "message": "...",
-        "reset": "false",
+        "reset": "true"
     }
 }
-~~~
+~~~
python/extensions/reasoning_stream/.gitkeep ADDED
File without changes
python/extensions/reasoning_stream/_10_log_from_stream.py ADDED

@@ -0,0 +1,29 @@
+from python.helpers import persist_chat, tokens
+from python.helpers.extension import Extension
+from agent import LoopData
+import asyncio
+from python.helpers.log import LogItem
+from python.helpers import log
+import math
+
+
+class LogFromStream(Extension):
+
+    async def execute(self, loop_data: LoopData = LoopData(), text: str = "", **kwargs):
+
+        # thought length indicator
+        length = math.ceil(len(text) / 10) * 10
+        heading = f"{self.agent.agent_name}: Reasoning ({length})..."
+
+        # create log message and store it in loop data temporary params
+        if "log_item_generating" not in loop_data.params_temporary:
+            loop_data.params_temporary["log_item_generating"] = (
+                self.agent.context.log.log(
+                    type="agent",
+                    heading=heading,
+                )
+            )
+
+        # update log message
+        log_item = loop_data.params_temporary["log_item_generating"]
+        log_item.update(heading=heading, reasoning=text)
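
Note: the heading's length indicator rounds the streamed character count up to the nearest ten, so the number ticks upward steadily while reasoning streams in. For example:

~~~python
import math

text = "x" * 42
assert math.ceil(len(text) / 10) * 10 == 50  # shown as "Reasoning (50)..."
~~~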
python/extensions/response_stream/_10_log_from_stream.py CHANGED

@@ -4,6 +4,7 @@ from agent import LoopData
 import asyncio
 from python.helpers.log import LogItem
 from python.helpers import log
+import math
 
 
 class LogFromStream(Extension):
@@ -13,20 +14,38 @@ class LogFromStream(Extension):
         loop_data: LoopData = LoopData(),
         text: str = "",
         parsed: dict = {},
-        **kwargs
+        **kwargs,
     ):
 
-        heading = f"{self.agent.agent_name}: Thinking..."
+        heading = f"{self.agent.agent_name}: Generating..."
         if "headline" in parsed:
             heading = f"{self.agent.agent_name}: {parsed['headline']}"
+        elif "thoughts" in parsed:
+            # thought length indicator
+            thoughts = "\n".join(parsed["thoughts"])
+            length = math.ceil(len(thoughts) / 10) * 10
+            heading = f"{self.agent.agent_name}: Thinking ({length})..."
+
+        if "tool_name" in parsed:
+            heading += f" ({parsed['tool_name']})"
 
         # create log message and store it in loop data temporary params
         if "log_item_generating" not in loop_data.params_temporary:
-            loop_data.params_temporary["log_item_generating"] = self.agent.context.log.log(
-                type="agent",
-                heading=heading,
+            loop_data.params_temporary["log_item_generating"] = (
+                self.agent.context.log.log(
+                    type="agent",
+                    heading=heading,
+                )
             )
 
         # update log message
         log_item = loop_data.params_temporary["log_item_generating"]
-        log_item.update(heading=heading, content=text, kvps=parsed)
+
+        # keep reasoning from previous logs in kvps
+        kvps = {}
+        if log_item.kvps is not None and "reasoning" in log_item.kvps:
+            kvps["reasoning"] = log_item.kvps["reasoning"]
+        kvps.update(parsed)
+
+        # update the log item
+        log_item.update(heading=heading, content=text, kvps=kvps)
python/helpers/document_query.py CHANGED

@@ -42,6 +42,7 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
 
 DEFAULT_SEARCH_THRESHOLD = 0.5
 
+
 class DocumentQueryStore:
     """
     FAISS Store for document query results.
@@ -85,7 +86,7 @@ class DocumentQueryStore:
         Normalized URI
         """
         # Convert to lowercase
-        normalized = uri.strip() # uri.lower()
+        normalized = uri.strip()  # uri.lower()
 
         # Parse the URL to get scheme
         parsed = urlparse(normalized)
@@ -368,7 +369,9 @@ class DocumentQueryStore:
 
 class DocumentQueryHelper:
 
-    def __init__(self, agent: Agent, progress_callback: Callable[[str], None] | None = None):
+    def __init__(
+        self, agent: Agent, progress_callback: Callable[[str], None] | None = None
+    ):
         self.agent = agent
         self.store = DocumentQueryStore.get(agent)
         self.progress_callback = progress_callback or (lambda x: None)
@@ -414,30 +417,34 @@ class DocumentQueryHelper:
             content = f"!!! No content found for document: {document_uri} matching queries: {json.dumps(questions)}"
             return False, content
 
-        self.progress_callback(f"Processing {len(questions)} questions in context of {len(selected_chunks)} chunks")
+        self.progress_callback(
+            f"Processing {len(questions)} questions in context of {len(selected_chunks)} chunks"
+        )
 
         questions_str = "\n".join([f" * {question}" for question in questions])
-        content = "\n\n----\n\n".join([chunk.page_content for chunk in selected_chunks.values()])
+        content = "\n\n----\n\n".join(
+            [chunk.page_content for chunk in selected_chunks.values()]
+        )
 
         qa_system_message = self.agent.parse_prompt(
             "fw.document_query.system_prompt.md"
         )
         qa_user_message = f"# Document:\n{content}\n\n# Queries:\n{questions_str}"
 
-        ai_response = await self.agent.call_chat_model(
-            prompt=ChatPromptTemplate.from_messages(
-                [
-                    SystemMessage(content=qa_system_message),
-                    HumanMessage(content=qa_user_message),
-                ]
-            )
+        ai_response, _reasoning = await self.agent.call_chat_model(
+            messages=[
+                SystemMessage(content=qa_system_message),
+                HumanMessage(content=qa_user_message),
+            ]
         )
 
         self.progress_callback(f"Q&A process completed")
 
         return True, str(ai_response)
 
-    async def document_get_content(self, document_uri: str, add_to_db: bool = False) -> str:
+    async def document_get_content(
+        self, document_uri: str, add_to_db: bool = False
+    ) -> str:
         self.progress_callback(f"Fetching document content")
         url = urlparse(document_uri)
         scheme = url.scheme or "file"
@@ -518,7 +525,9 @@ class DocumentQueryHelper:
         )
         if add_to_db:
             self.progress_callback(f"Indexing document")
-            success, ids = await self.store.add_document(document_content, document_uri_norm)
+            success, ids = await self.store.add_document(
+                document_content, document_uri_norm
+            )
             if not success:
                 self.progress_callback(f"Failed to index document")
                 raise ValueError(
python/helpers/history.py CHANGED
@@ -534,10 +534,17 @@ def _merge_outputs(a: MessageContent, b: MessageContent) -> MessageContent:
     if isinstance(a, str) and isinstance(b, str):
         return a + "\n" + b
 
-    if not isinstance(a, list):
-        a = [a]
-    if not isinstance(b, list):
-        b = [b]
+    def make_list(obj: MessageContent) -> list[MessageContent]:
+        if isinstance(obj, list):
+            return obj  # type: ignore
+        if isinstance(obj, dict):
+            return [obj]
+        if isinstance(obj, str):
+            return [{"type": "text", "text": obj}]
+        return [obj]
+
+    a = make_list(a)
+    b = make_list(b)
 
     return cast(MessageContent, a + b)
 
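The new make_list normalizes the MessageContent union (str | dict | list) into a list of typed parts before concatenation, so plain text merges cleanly with structured parts (e.g. images) instead of producing mixed str/dict lists as the old [a] + [b] wrapping did. A standalone sketch of the behavior:

    # simplified copy of the helper above, runnable on its own
    def make_list(obj):
        if isinstance(obj, list):
            return obj
        if isinstance(obj, dict):
            return [obj]
        if isinstance(obj, str):
            return [{"type": "text", "text": obj}]
        return [obj]

    merged = make_list("hello") + make_list({"type": "image", "url": "img://x.png"})
    # -> [{'type': 'text', 'text': 'hello'}, {'type': 'image', 'url': 'img://x.png'}]
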
 
python/helpers/memory.py CHANGED
@@ -117,8 +117,7 @@ class Memory:
         os.makedirs(em_dir, exist_ok=True)
         store = LocalFileStore(em_dir)
 
-        embeddings_model = models.get_model(
-            models.ModelType.EMBEDDING,
+        embeddings_model = models.get_embedding_model(
             model_config.provider,
             model_config.name,
             **model_config.kwargs,
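
Like the chat and browser factories used elsewhere in this PR (see test.py below), the embedding lookup now goes through a dedicated factory taking the same provider/name/kwargs triple. A sketch under that assumption; the provider member and model name here are illustrative, not taken from this diff:

    import models

    # illustrative values; only the factory signature is from this diff
    embeddings_model = models.get_embedding_model(
        models.ModelProvider.OPENAI,
        "text-embedding-3-small",
    )
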
python/tools/browser_agent.py CHANGED
@@ -1,5 +1,4 @@
 import asyncio
-import json
 import time
 from typing import Optional
 from agent import Agent, InterventionException
 
@@ -57,6 +56,8 @@ class State:
                 screen={"width": 1024, "height": 2048},
                 viewport={"width": 1024, "height": 2048},
                 args=["--headless=new"],
+                # Use a unique user data directory to avoid conflicts
+                user_data_dir=str(Path.home() / ".config" / "browseruse" / "profiles" / f"agent_{self.agent.context.id}"),
             )
         )
 
@@ -118,25 +119,28 @@ class State:
             )
             return result
 
-        model = models.get_model(
-            type=models.ModelType.CHAT,
+
+        model = models.get_browser_model(
             provider=self.agent.config.browser_model.provider,
             name=self.agent.config.browser_model.name,
            **self.agent.config.browser_model.kwargs,
         )
 
-        self.use_agent = browser_use.Agent(
-            task=task,
-            browser_session=self.browser_session,
-            llm=model,
-            use_vision=self.agent.config.browser_model.vision,
-            extend_system_message=self.agent.read_prompt(
-                "prompts/browser_agent.system.md"
-            ),
-            controller=controller,
-            enable_memory=False,  # Disable memory to avoid state conflicts
-            # available_file_paths=[],
-        )
+        try:
+            self.use_agent = browser_use.Agent(
+                task=task,
+                browser_session=self.browser_session,
+                llm=model,
+                use_vision=self.agent.config.browser_model.vision,
+                extend_system_message=self.agent.read_prompt(
+                    "prompts/browser_agent.system.md"
+                ),
+                controller=controller,
+                enable_memory=False,  # Disable memory to avoid state conflicts
+                # available_file_paths=[],
+            )
+        except Exception as e:
+            raise Exception(f"Browser agent initialization failed. This might be due to model compatibility issues. Error: {e}") from e
 
         self.iter_no = get_iter_no(self.agent)
 
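The user_data_dir added above gives each agent context its own Chromium profile, so concurrent sessions do not contend for a single profile's lock. A sketch of the path construction, assuming Path is imported at the top of browser_agent.py (the import sits outside the visible hunks):

    from pathlib import Path

    def profile_dir(context_id: str) -> str:
        # one profile per agent context, e.g.
        # ~/.config/browseruse/profiles/agent_abc123
        return str(
            Path.home() / ".config" / "browseruse" / "profiles" / f"agent_{context_id}"
        )
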
requirements.txt CHANGED
@@ -10,15 +10,7 @@ flask-basicauth==0.2.0
 flaredantic==0.1.4
 GitPython==3.1.43
 inputimeout==1.0.4
-langchain-anthropic==0.3.3
-langchain-community==0.3.19
-langchain-google-genai==2.1.2
-langchain-groq==0.2.2
-langchain-huggingface==0.1.2
-langchain-mistralai==0.2.4
-langchain-ollama==0.3.0
-langchain-openai==0.3.11
-langchain-unstructured[all-docs]==0.1.6
+langchain-core==0.3.49
 openai-whisper==20240930
 lxml_html_clean==0.3.1
 markdown==3.7
 
@@ -35,6 +27,8 @@ unstructured[all-docs]==0.16.23
 unstructured-client==0.31.0
 webcolors==24.6.0
 nest-asyncio==1.6.0
+crontab==1.0.1
+litellm==1.72.4
 markdownify==1.1.0
 pymupdf==1.25.3
 pytesseract==0.3.13
test.py ADDED
@@ -0,0 +1,74 @@
+import asyncio
+from os import sep
+from langchain_core.messages import HumanMessage, SystemMessage
+from langchain_core.prompts import ChatPromptTemplate
+import models
+from python.helpers import dotenv
+
+
+async def test():
+
+    dotenv.load_dotenv()
+
+    # model_name = "moonshotai/kimi-dev-72b:free"
+    # model_name = "qwen/qwq-32b"
+    # model_name = "qwen/qwen3-32b"
+    # model_name = "anthropic/claude-3.7-sonnet:thinking"
+    model_name = "openai/gpt-4.1-nano"
+    system = ""
+    message = "hello"
+
+    model = models.get_chat_model(models.ModelProvider.OPENROUTER, model_name)
+
+    async def response_callback(chunk: str, full: str):
+        if chunk == full:
+            print("\n")
+            print("Response:")
+        print(chunk, end="", flush=True)
+
+    async def reasoning_callback(chunk: str, full: str):
+        if chunk == full:
+            print("\n")
+            print("Reasoning:")
+        print(chunk, end="", flush=True)
+
+    response, reasoning = await model.unified_call(
+        system_message=system,
+        user_message=message,
+        response_callback=response_callback,
+        reasoning_callback=reasoning_callback,
+    )
+
+    print("\n")
+    print("Final:")
+    print("Reasoning:", reasoning)
+    print("Response:", response)
+
+
+async def test2():
+
+    dotenv.load_dotenv()
+
+    import initialize
+    config = initialize.initialize_agent()
+
+    model = models.get_browser_model(
+        provider=config.browser_model.provider,
+        name=config.browser_model.name,
+        **config.browser_model.kwargs,
+    )
+
+    response, reasoning = await model.unified_call(
+        system_message="",
+        user_message="hi",
+    )
+
+    print("\n")
+    print("Final:")
+    print("Reasoning:", reasoning)
+    print("Response:", response)
+
+
+if __name__ == "__main__":
+    # asyncio.run(test())
+    asyncio.run(test2())
webui/index.css CHANGED
@@ -1572,6 +1572,11 @@ input:checked + .slider:before {
   display: auto;
 }
 
+.msg-thoughts .kvps-val {
+  max-height: 20em;
+  overflow: auto;
+}
+
 .msg-content {
   margin-bottom: 0;
 }
webui/js/messages.js CHANGED
@@ -537,7 +537,7 @@ function drawKvps(container, kvps, latex) {
   for (let [key, value] of Object.entries(kvps)) {
     const row = table.insertRow();
     row.classList.add("kvps-row");
-    if (key === "thoughts" || key === "reflection")
+    if (key === "thoughts" || key === "reasoning") // TODO: find a better way to determine special class assignment
      row.classList.add("msg-thoughts");
 
     const th = row.insertCell();
@@ -545,6 +545,9 @@ function drawKvps(container, kvps, latex) {
     th.classList.add("kvps-key");
 
     const td = row.insertCell();
+    const tdiv = document.createElement("div");
+    tdiv.classList.add("kvps-val");
+    td.appendChild(tdiv);
 
     if (Array.isArray(value)) {
       for (const item of value) {
@@ -562,7 +565,7 @@ function drawKvps(container, kvps, latex) {
       imgElement.classList.add("kvps-img");
       imgElement.src = value.replace("img://", "/image_get?path=");
       imgElement.alt = "Image Attachment";
-      td.appendChild(imgElement);
+      tdiv.appendChild(imgElement);
 
       // Add click handler and cursor change
      imgElement.style.cursor = "pointer";
@@ -570,15 +573,14 @@ function drawKvps(container, kvps, latex) {
        openImageModal(imgElement.src, 1000);
      });
 
-      td.appendChild(imgElement);
     } else {
       const pre = document.createElement("pre");
-      pre.classList.add("kvps-val");
+      // pre.classList.add("kvps-val");
       // if (row.classList.contains("msg-thoughts")) {
       const span = document.createElement("span");
       span.innerHTML = convertHTML(value);
       pre.appendChild(span);
-      td.appendChild(pre);
+      tdiv.appendChild(pre);
       addCopyButtonToElement(row);
 
       // Add click handler