frdel commited on
Commit
ea7bb7f
·
1 Parent(s): a7a3196

llms think tags handling

Browse files
Files changed (2) hide show
  1. models.py +225 -55
  2. tests/chunk_parser_test.py +23 -0
models.py CHANGED
@@ -53,7 +53,8 @@ def turn_off_logging():
53
  # init
54
  load_dotenv()
55
  turn_off_logging()
56
- litellm.modify_params = True # helps fix anthropic tool calls by browser-use
 
57
 
58
  class ModelType(Enum):
59
  CHAT = "Chat"
@@ -82,14 +83,116 @@ class ModelConfig:
82
 
83
  class ChatChunk(TypedDict):
84
  """Simplified response chunk for chat models."""
85
-
86
  response_delta: str
87
  reasoning_delta: str
88
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
 
90
  rate_limiters: dict[str, RateLimiter] = {}
91
  api_keys_round_robin: dict[str, int] = {}
92
 
 
93
  def get_api_key(service: str) -> str:
94
  # get api key for the service
95
  key = (
@@ -116,7 +219,14 @@ def get_rate_limiter(
116
  limiter.limits["output"] = output or 0
117
  return limiter
118
 
119
- async def apply_rate_limiter(model_config: ModelConfig|None, input_text: str, rate_limiter_callback: Callable[[str, str, int, int], Awaitable[bool]] | None = None):
 
 
 
 
 
 
 
120
  if not model_config:
121
  return
122
  limiter = get_rate_limiter(
@@ -131,25 +241,41 @@ async def apply_rate_limiter(model_config: ModelConfig|None, input_text: str, ra
131
  await limiter.wait(rate_limiter_callback)
132
  return limiter
133
 
134
- def apply_rate_limiter_sync(model_config: ModelConfig|None, input_text: str, rate_limiter_callback: Callable[[str, str, int, int], Awaitable[bool]] | None = None):
 
 
 
 
 
 
 
135
  if not model_config:
136
  return
137
  import asyncio, nest_asyncio
 
138
  nest_asyncio.apply()
139
- return asyncio.run(apply_rate_limiter(model_config, input_text, rate_limiter_callback))
 
 
140
 
141
 
142
  class LiteLLMChatWrapper(SimpleChatModel):
143
  model_name: str
144
  provider: str
145
  kwargs: dict = {}
146
-
147
  class Config:
148
  arbitrary_types_allowed = True
149
  extra = "allow" # Allow extra attributes
150
  validate_assignment = False # Don't validate on assignment
151
 
152
- def __init__(self, model: str, provider: str, model_config: Optional[ModelConfig] = None, **kwargs: Any):
 
 
 
 
 
 
153
  model_value = f"{provider}/{model}"
154
  super().__init__(model_name=model_value, provider=provider, kwargs=kwargs) # type: ignore
155
  # Set A0 model config as instance attribute after parent init
@@ -158,7 +284,7 @@ class LiteLLMChatWrapper(SimpleChatModel):
158
  @property
159
  def _llm_type(self) -> str:
160
  return "litellm-chat"
161
-
162
  def _convert_messages(self, messages: List[BaseMessage]) -> List[dict]:
163
  result = []
164
  # Map LangChain message types to LiteLLM roles
@@ -215,12 +341,12 @@ class LiteLLMChatWrapper(SimpleChatModel):
215
  **kwargs: Any,
216
  ) -> str:
217
  import asyncio
218
-
219
  msgs = self._convert_messages(messages)
220
-
221
  # Apply rate limiting if configured
222
  apply_rate_limiter_sync(self.a0_model_conf, str(msgs))
223
-
224
  # Call the model
225
  resp = completion(
226
  model=self.model_name, messages=msgs, stop=stop, **{**self.kwargs, **kwargs}
@@ -228,7 +354,8 @@ class LiteLLMChatWrapper(SimpleChatModel):
228
 
229
  # Parse output
230
  parsed = _parse_chunk(resp)
231
- return parsed["response_delta"]
 
232
 
233
  def _stream(
234
  self,
@@ -238,12 +365,14 @@ class LiteLLMChatWrapper(SimpleChatModel):
238
  **kwargs: Any,
239
  ) -> Iterator[ChatGenerationChunk]:
240
  import asyncio
241
-
242
  msgs = self._convert_messages(messages)
243
-
244
  # Apply rate limiting if configured
245
  apply_rate_limiter_sync(self.a0_model_conf, str(msgs))
246
-
 
 
247
  for chunk in completion(
248
  model=self.model_name,
249
  messages=msgs,
@@ -251,11 +380,14 @@ class LiteLLMChatWrapper(SimpleChatModel):
251
  stop=stop,
252
  **{**self.kwargs, **kwargs},
253
  ):
254
- parsed = _parse_chunk(chunk)
 
 
 
255
  # Only yield chunks with non-None content
256
- if parsed["response_delta"]:
257
  yield ChatGenerationChunk(
258
- message=AIMessageChunk(content=parsed["response_delta"])
259
  )
260
 
261
  async def _astream(
@@ -266,11 +398,12 @@ class LiteLLMChatWrapper(SimpleChatModel):
266
  **kwargs: Any,
267
  ) -> AsyncIterator[ChatGenerationChunk]:
268
  msgs = self._convert_messages(messages)
269
-
270
  # Apply rate limiting if configured
271
  await apply_rate_limiter(self.a0_model_conf, str(msgs))
272
-
273
-
 
274
  response = await acompletion(
275
  model=self.model_name,
276
  messages=msgs,
@@ -279,11 +412,14 @@ class LiteLLMChatWrapper(SimpleChatModel):
279
  **{**self.kwargs, **kwargs},
280
  )
281
  async for chunk in response: # type: ignore
282
- parsed = _parse_chunk(chunk)
 
 
 
283
  # Only yield chunks with non-None content
284
- if parsed["response_delta"]:
285
  yield ChatGenerationChunk(
286
- message=AIMessageChunk(content=parsed["response_delta"])
287
  )
288
 
289
  async def unified_call(
@@ -294,7 +430,9 @@ class LiteLLMChatWrapper(SimpleChatModel):
294
  response_callback: Callable[[str, str], Awaitable[None]] | None = None,
295
  reasoning_callback: Callable[[str, str], Awaitable[None]] | None = None,
296
  tokens_callback: Callable[[str, int], Awaitable[None]] | None = None,
297
- rate_limiter_callback: Callable[[str, str, int, int], Awaitable[bool]] | None = None,
 
 
298
  **kwargs: Any,
299
  ) -> Tuple[str, str]:
300
 
@@ -312,7 +450,9 @@ class LiteLLMChatWrapper(SimpleChatModel):
312
  msgs_conv = self._convert_messages(messages)
313
 
314
  # Apply rate limiting if configured
315
- limiter = await apply_rate_limiter(self.a0_model_conf, str(msgs_conv), rate_limiter_callback)
 
 
316
 
317
  # call model
318
  _completion = await acompletion(
@@ -323,41 +463,41 @@ class LiteLLMChatWrapper(SimpleChatModel):
323
  )
324
 
325
  # results
326
- reasoning = ""
327
- response = ""
328
 
329
  # iterate over chunks
330
  async for chunk in _completion: # type: ignore
 
331
  parsed = _parse_chunk(chunk)
 
 
332
  # collect reasoning delta and call callbacks
333
- if parsed["reasoning_delta"]:
334
- reasoning += parsed["reasoning_delta"]
335
  if reasoning_callback:
336
- await reasoning_callback(parsed["reasoning_delta"], reasoning)
337
  if tokens_callback:
338
  await tokens_callback(
339
- parsed["reasoning_delta"],
340
- approximate_tokens(parsed["reasoning_delta"]),
341
  )
342
  # Add output tokens to rate limiter if configured
343
  if limiter:
344
- limiter.add(output=approximate_tokens(parsed["reasoning_delta"]))
345
  # collect response delta and call callbacks
346
- if parsed["response_delta"]:
347
- response += parsed["response_delta"]
348
  if response_callback:
349
- await response_callback(parsed["response_delta"], response)
350
  if tokens_callback:
351
  await tokens_callback(
352
- parsed["response_delta"],
353
- approximate_tokens(parsed["response_delta"]),
354
  )
355
  # Add output tokens to rate limiter if configured
356
  if limiter:
357
- limiter.add(output=approximate_tokens(parsed["response_delta"]))
358
 
359
  # return complete results
360
- return response, reasoning
361
 
362
 
363
  class BrowserCompatibleChatWrapper(LiteLLMChatWrapper):
@@ -400,15 +540,21 @@ class LiteLLMEmbeddingWrapper(Embeddings):
400
  kwargs: dict = {}
401
  a0_model_conf: Optional[ModelConfig] = None
402
 
403
- def __init__(self, model: str, provider: str, model_config: Optional[ModelConfig] = None, **kwargs: Any):
 
 
 
 
 
 
404
  self.model_name = f"{provider}/{model}" if provider != "openai" else model
405
  self.kwargs = kwargs
406
  self.a0_model_conf = model_config
407
-
408
  def embed_documents(self, texts: List[str]) -> List[List[float]]:
409
  # Apply rate limiting if configured
410
  apply_rate_limiter_sync(self.a0_model_conf, " ".join(texts))
411
-
412
  resp = embedding(model=self.model_name, input=texts, **self.kwargs)
413
  return [
414
  item.get("embedding") if isinstance(item, dict) else item.embedding # type: ignore
@@ -418,7 +564,7 @@ class LiteLLMEmbeddingWrapper(Embeddings):
418
  def embed_query(self, text: str) -> List[float]:
419
  # Apply rate limiting if configured
420
  apply_rate_limiter_sync(self.a0_model_conf, text)
421
-
422
  resp = embedding(model=self.model_name, input=[text], **self.kwargs)
423
  item = resp.data[0] # type: ignore
424
  return item.get("embedding") if isinstance(item, dict) else item.embedding # type: ignore
@@ -427,7 +573,13 @@ class LiteLLMEmbeddingWrapper(Embeddings):
427
  class LocalSentenceTransformerWrapper(Embeddings):
428
  """Local wrapper for sentence-transformers models to avoid HuggingFace API calls"""
429
 
430
- def __init__(self, provider: str, model: str, model_config: Optional[ModelConfig] = None, **kwargs: Any):
 
 
 
 
 
 
431
  # Clean common user-input mistakes
432
  model = model.strip().strip('"').strip("'")
433
 
@@ -449,18 +601,18 @@ class LocalSentenceTransformerWrapper(Embeddings):
449
  self.model = SentenceTransformer(model, **st_kwargs)
450
  self.model_name = model
451
  self.a0_model_conf = model_config
452
-
453
  def embed_documents(self, texts: List[str]) -> List[List[float]]:
454
  # Apply rate limiting if configured
455
  apply_rate_limiter_sync(self.a0_model_conf, " ".join(texts))
456
-
457
  embeddings = self.model.encode(texts, convert_to_tensor=False) # type: ignore
458
  return embeddings.tolist() if hasattr(embeddings, "tolist") else embeddings # type: ignore
459
 
460
  def embed_query(self, text: str) -> List[float]:
461
  # Apply rate limiting if configured
462
  apply_rate_limiter_sync(self.a0_model_conf, text)
463
-
464
  embedding = self.model.encode([text], convert_to_tensor=False) # type: ignore
465
  result = (
466
  embedding[0].tolist() if hasattr(embedding[0], "tolist") else embedding[0]
@@ -485,10 +637,17 @@ def _get_litellm_chat(
485
  provider_name, model_name, kwargs = _adjust_call_args(
486
  provider_name, model_name, kwargs
487
  )
488
- return cls(provider=provider_name, model=model_name, model_config=model_config, **kwargs)
 
 
489
 
490
 
491
- def _get_litellm_embedding(model_name: str, provider_name: str, model_config: Optional[ModelConfig] = None, **kwargs: Any):
 
 
 
 
 
492
  # Check if this is a local sentence-transformers model
493
  if provider_name == "huggingface" and model_name.startswith(
494
  "sentence-transformers/"
@@ -498,7 +657,10 @@ def _get_litellm_embedding(model_name: str, provider_name: str, model_config: Op
498
  provider_name, model_name, kwargs
499
  )
500
  return LocalSentenceTransformerWrapper(
501
- provider=provider_name, model=model_name, model_config=model_config, **kwargs
 
 
 
502
  )
503
 
504
  # use api key from kwargs or env
@@ -511,7 +673,9 @@ def _get_litellm_embedding(model_name: str, provider_name: str, model_config: Op
511
  provider_name, model_name, kwargs = _adjust_call_args(
512
  provider_name, model_name, kwargs
513
  )
514
- return LiteLLMEmbeddingWrapper(model=model_name, provider=provider_name, model_config=model_config, **kwargs)
 
 
515
 
516
 
517
  def _parse_chunk(chunk: Any) -> ChatChunk:
@@ -533,9 +697,11 @@ def _parse_chunk(chunk: Any) -> ChatChunk:
533
  if isinstance(delta, dict)
534
  else getattr(delta, "reasoning_content", "")
535
  )
 
536
  return ChatChunk(reasoning_delta=reasoning_delta, response_delta=response_delta)
537
 
538
 
 
539
  def _adjust_call_args(provider_name: str, model_name: str, kwargs: dict):
540
  # for openrouter add app reference
541
  if provider_name == "openrouter":
@@ -599,10 +765,14 @@ def _merge_provider_defaults(
599
  return provider_name, kwargs
600
 
601
 
602
- def get_chat_model(provider: str, name: str, model_config: Optional[ModelConfig] = None, **kwargs: Any) -> LiteLLMChatWrapper:
 
 
603
  orig = provider.lower()
604
  provider_name, kwargs = _merge_provider_defaults("chat", orig, kwargs)
605
- return _get_litellm_chat(LiteLLMChatWrapper, name, provider_name, model_config, **kwargs)
 
 
606
 
607
 
608
  def get_browser_model(
 
53
  # init
54
  load_dotenv()
55
  turn_off_logging()
56
+ litellm.modify_params = True # helps fix anthropic tool calls by browser-use
57
+
58
 
59
  class ModelType(Enum):
60
  CHAT = "Chat"
 
83
 
84
  class ChatChunk(TypedDict):
85
  """Simplified response chunk for chat models."""
 
86
  response_delta: str
87
  reasoning_delta: str
88
 
89
class ChatGenerationResult:
    """Chat generation result object.

    Accumulates streamed ``ChatChunk`` deltas while keeping reasoning and
    response text separated. Reasoning can arrive two ways:

    - natively, via a non-empty ``reasoning_delta`` on any chunk, or
    - embedded in the response stream as thinking tags
      (``<think>...</think>``, ``<reasoning>...</reasoning>``), which are
      parsed out incrementally even when a tag is split across chunks.
    """

    def __init__(self, chunk: ChatChunk | None = None):
        self.reasoning = ""  # accumulated reasoning text
        self.response = ""  # accumulated response text
        self.thinking = False  # are we currently inside a thinking tag?
        self.thinking_tag = ""  # closing tag we are waiting for while thinking
        self.unprocessed = ""  # possible partial tag buffered for the next chunk
        self.native_reasoning = False  # model reported reasoning via reasoning_delta
        self.thinking_pairs = [("<think>", "</think>"), ("<reasoning>", "</reasoning>")]
        if chunk:
            self.add_chunk(chunk)

    def add_chunk(self, chunk: ChatChunk) -> ChatChunk:
        """Process one incoming chunk, update totals, and return the processed deltas."""
        # any native reasoning delta permanently switches off tag parsing
        if chunk["reasoning_delta"]:
            self.native_reasoning = True

        # if native reasoning detection works, there's no need to worry about thinking tags
        if self.native_reasoning:
            processed_chunk = ChatChunk(response_delta=chunk["response_delta"], reasoning_delta=chunk["reasoning_delta"])
        else:
            # if the model outputs thinking tags, we need to parse them manually as reasoning
            processed_chunk = self._process_thinking_chunk(chunk)

        self.reasoning += processed_chunk["reasoning_delta"]
        self.response += processed_chunk["response_delta"]

        return processed_chunk

    def _process_thinking_chunk(self, chunk: ChatChunk) -> ChatChunk:
        """Prepend any buffered partial tag text before parsing the new delta."""
        response_delta = self.unprocessed + chunk["response_delta"]
        self.unprocessed = ""
        return self._process_thinking_tags(response_delta, chunk["reasoning_delta"])

    def _process_thinking_tags(self, response: str, reasoning: str) -> ChatChunk:
        """Route tag-delimited text into reasoning; buffer text that may be a partial tag."""
        if self.thinking:
            # inside a thinking block: look for its closing tag
            close_pos = response.find(self.thinking_tag)
            if close_pos != -1:
                reasoning += response[:close_pos]
                response = response[close_pos + len(self.thinking_tag):]
                self.thinking = False
                self.thinking_tag = ""
            else:
                if self._is_partial_closing_tag(response):
                    # might be a closing tag split across chunks - reprocess next time
                    self.unprocessed = response
                    response = ""
                else:
                    # still thinking: everything is reasoning
                    reasoning += response
                    response = ""
        else:
            # NOTE(review): tags are only recognized at the start of the pending
            # text, so a tag appearing mid-sentence is treated as plain response
            for opening_tag, closing_tag in self.thinking_pairs:
                if response.startswith(opening_tag):
                    response = response[len(opening_tag):]
                    self.thinking = True
                    self.thinking_tag = closing_tag

                    # the whole block may already be contained in this chunk
                    close_pos = response.find(closing_tag)
                    if close_pos != -1:
                        reasoning += response[:close_pos]
                        response = response[close_pos + len(closing_tag):]
                        self.thinking = False
                        self.thinking_tag = ""
                    else:
                        if self._is_partial_closing_tag(response):
                            self.unprocessed = response
                            response = ""
                        else:
                            reasoning += response
                            response = ""
                    break
                elif len(response) < len(opening_tag) and self._is_partial_opening_tag(response, opening_tag):
                    # could be the beginning of an opening tag - wait for more text
                    self.unprocessed = response
                    response = ""
                    break

        return ChatChunk(response_delta=response, reasoning_delta=reasoning)

    def _is_partial_opening_tag(self, text: str, opening_tag: str) -> bool:
        """True if *text* is exactly a proper prefix of *opening_tag* (e.g. "<th")."""
        for i in range(1, len(opening_tag)):
            if text == opening_tag[:i]:
                return True
        return False

    def _is_partial_closing_tag(self, text: str) -> bool:
        """True if *text* ends with a proper prefix of the awaited closing tag."""
        if not self.thinking_tag or not text:
            return False
        max_check = min(len(text), len(self.thinking_tag) - 1)
        for i in range(1, max_check + 1):
            if text.endswith(self.thinking_tag[:i]):
                return True
        return False

    def output(self) -> ChatChunk:
        """Return the final accumulated deltas, flushing any leftover buffer.

        A leftover partial tag goes to reasoning when only reasoning was
        produced (stream ended mid closing tag), otherwise to the response.
        """
        response = self.response
        reasoning = self.reasoning
        if self.unprocessed:
            if reasoning and not response:
                reasoning += self.unprocessed
            else:
                response += self.unprocessed
        return ChatChunk(response_delta=response, reasoning_delta=reasoning)

191
 
192
  rate_limiters: dict[str, RateLimiter] = {}
193
  api_keys_round_robin: dict[str, int] = {}
194
 
195
+
196
  def get_api_key(service: str) -> str:
197
  # get api key for the service
198
  key = (
 
219
  limiter.limits["output"] = output or 0
220
  return limiter
221
 
222
+
223
+ async def apply_rate_limiter(
224
+ model_config: ModelConfig | None,
225
+ input_text: str,
226
+ rate_limiter_callback: (
227
+ Callable[[str, str, int, int], Awaitable[bool]] | None
228
+ ) = None,
229
+ ):
230
  if not model_config:
231
  return
232
  limiter = get_rate_limiter(
 
241
  await limiter.wait(rate_limiter_callback)
242
  return limiter
243
 
244
+
245
def apply_rate_limiter_sync(
    model_config: ModelConfig | None,
    input_text: str,
    rate_limiter_callback: (
        Callable[[str, str, int, int], Awaitable[bool]] | None
    ) = None,
):
    """Synchronous wrapper around ``apply_rate_limiter``.

    Returns None when no model_config is given; otherwise runs the async
    rate limiter to completion and returns its result (the limiter).
    """
    if not model_config:
        return
    # nest_asyncio lets asyncio.run() work even when an event loop is
    # already running (sync call sites inside an async application)
    import asyncio, nest_asyncio

    nest_asyncio.apply()
    return asyncio.run(
        apply_rate_limiter(model_config, input_text, rate_limiter_callback)
    )
260
 
261
 
262
  class LiteLLMChatWrapper(SimpleChatModel):
263
  model_name: str
264
  provider: str
265
  kwargs: dict = {}
266
+
267
  class Config:
268
  arbitrary_types_allowed = True
269
  extra = "allow" # Allow extra attributes
270
  validate_assignment = False # Don't validate on assignment
271
 
272
+ def __init__(
273
+ self,
274
+ model: str,
275
+ provider: str,
276
+ model_config: Optional[ModelConfig] = None,
277
+ **kwargs: Any,
278
+ ):
279
  model_value = f"{provider}/{model}"
280
  super().__init__(model_name=model_value, provider=provider, kwargs=kwargs) # type: ignore
281
  # Set A0 model config as instance attribute after parent init
 
284
  @property
285
  def _llm_type(self) -> str:
286
  return "litellm-chat"
287
+
288
  def _convert_messages(self, messages: List[BaseMessage]) -> List[dict]:
289
  result = []
290
  # Map LangChain message types to LiteLLM roles
 
341
  **kwargs: Any,
342
  ) -> str:
343
  import asyncio
344
+
345
  msgs = self._convert_messages(messages)
346
+
347
  # Apply rate limiting if configured
348
  apply_rate_limiter_sync(self.a0_model_conf, str(msgs))
349
+
350
  # Call the model
351
  resp = completion(
352
  model=self.model_name, messages=msgs, stop=stop, **{**self.kwargs, **kwargs}
 
354
 
355
  # Parse output
356
  parsed = _parse_chunk(resp)
357
+ output = ChatGenerationResult(parsed).output()
358
+ return output["response_delta"]
359
 
360
  def _stream(
361
  self,
 
365
  **kwargs: Any,
366
  ) -> Iterator[ChatGenerationChunk]:
367
  import asyncio
368
+
369
  msgs = self._convert_messages(messages)
370
+
371
  # Apply rate limiting if configured
372
  apply_rate_limiter_sync(self.a0_model_conf, str(msgs))
373
+
374
+ result = ChatGenerationResult()
375
+
376
  for chunk in completion(
377
  model=self.model_name,
378
  messages=msgs,
 
380
  stop=stop,
381
  **{**self.kwargs, **kwargs},
382
  ):
383
+ # parse chunk
384
+ parsed = _parse_chunk(chunk) # chunk parsing
385
+ output = result.add_chunk(parsed) # chunk processing
386
+
387
  # Only yield chunks with non-None content
388
+ if output["response_delta"]:
389
  yield ChatGenerationChunk(
390
+ message=AIMessageChunk(content=output["response_delta"])
391
  )
392
 
393
  async def _astream(
 
398
  **kwargs: Any,
399
  ) -> AsyncIterator[ChatGenerationChunk]:
400
  msgs = self._convert_messages(messages)
401
+
402
  # Apply rate limiting if configured
403
  await apply_rate_limiter(self.a0_model_conf, str(msgs))
404
+
405
+ result = ChatGenerationResult()
406
+
407
  response = await acompletion(
408
  model=self.model_name,
409
  messages=msgs,
 
412
  **{**self.kwargs, **kwargs},
413
  )
414
  async for chunk in response: # type: ignore
415
+ # parse chunk
416
+ parsed = _parse_chunk(chunk) # chunk parsing
417
+ output = result.add_chunk(parsed) # chunk processing
418
+
419
  # Only yield chunks with non-None content
420
+ if output["response_delta"]:
421
  yield ChatGenerationChunk(
422
+ message=AIMessageChunk(content=output["response_delta"])
423
  )
424
 
425
  async def unified_call(
 
430
  response_callback: Callable[[str, str], Awaitable[None]] | None = None,
431
  reasoning_callback: Callable[[str, str], Awaitable[None]] | None = None,
432
  tokens_callback: Callable[[str, int], Awaitable[None]] | None = None,
433
+ rate_limiter_callback: (
434
+ Callable[[str, str, int, int], Awaitable[bool]] | None
435
+ ) = None,
436
  **kwargs: Any,
437
  ) -> Tuple[str, str]:
438
 
 
450
  msgs_conv = self._convert_messages(messages)
451
 
452
  # Apply rate limiting if configured
453
+ limiter = await apply_rate_limiter(
454
+ self.a0_model_conf, str(msgs_conv), rate_limiter_callback
455
+ )
456
 
457
  # call model
458
  _completion = await acompletion(
 
463
  )
464
 
465
  # results
466
+ result = ChatGenerationResult()
 
467
 
468
  # iterate over chunks
469
  async for chunk in _completion: # type: ignore
470
+ # parse chunk
471
  parsed = _parse_chunk(chunk)
472
+ output = result.add_chunk(parsed)
473
+
474
  # collect reasoning delta and call callbacks
475
+ if output["reasoning_delta"]:
 
476
  if reasoning_callback:
477
+ await reasoning_callback(output["reasoning_delta"], result.reasoning)
478
  if tokens_callback:
479
  await tokens_callback(
480
+ output["reasoning_delta"],
481
+ approximate_tokens(output["reasoning_delta"]),
482
  )
483
  # Add output tokens to rate limiter if configured
484
  if limiter:
485
+ limiter.add(output=approximate_tokens(output["reasoning_delta"]))
486
  # collect response delta and call callbacks
487
+ if output["response_delta"]:
 
488
  if response_callback:
489
+ await response_callback(output["response_delta"], result.response)
490
  if tokens_callback:
491
  await tokens_callback(
492
+ output["response_delta"],
493
+ approximate_tokens(output["response_delta"]),
494
  )
495
  # Add output tokens to rate limiter if configured
496
  if limiter:
497
+ limiter.add(output=approximate_tokens(output["response_delta"]))
498
 
499
  # return complete results
500
+ return result.response, result.reasoning
501
 
502
 
503
  class BrowserCompatibleChatWrapper(LiteLLMChatWrapper):
 
540
  kwargs: dict = {}
541
  a0_model_conf: Optional[ModelConfig] = None
542
 
543
+ def __init__(
544
+ self,
545
+ model: str,
546
+ provider: str,
547
+ model_config: Optional[ModelConfig] = None,
548
+ **kwargs: Any,
549
+ ):
550
  self.model_name = f"{provider}/{model}" if provider != "openai" else model
551
  self.kwargs = kwargs
552
  self.a0_model_conf = model_config
553
+
554
  def embed_documents(self, texts: List[str]) -> List[List[float]]:
555
  # Apply rate limiting if configured
556
  apply_rate_limiter_sync(self.a0_model_conf, " ".join(texts))
557
+
558
  resp = embedding(model=self.model_name, input=texts, **self.kwargs)
559
  return [
560
  item.get("embedding") if isinstance(item, dict) else item.embedding # type: ignore
 
564
  def embed_query(self, text: str) -> List[float]:
565
  # Apply rate limiting if configured
566
  apply_rate_limiter_sync(self.a0_model_conf, text)
567
+
568
  resp = embedding(model=self.model_name, input=[text], **self.kwargs)
569
  item = resp.data[0] # type: ignore
570
  return item.get("embedding") if isinstance(item, dict) else item.embedding # type: ignore
 
573
  class LocalSentenceTransformerWrapper(Embeddings):
574
  """Local wrapper for sentence-transformers models to avoid HuggingFace API calls"""
575
 
576
+ def __init__(
577
+ self,
578
+ provider: str,
579
+ model: str,
580
+ model_config: Optional[ModelConfig] = None,
581
+ **kwargs: Any,
582
+ ):
583
  # Clean common user-input mistakes
584
  model = model.strip().strip('"').strip("'")
585
 
 
601
  self.model = SentenceTransformer(model, **st_kwargs)
602
  self.model_name = model
603
  self.a0_model_conf = model_config
604
+
605
  def embed_documents(self, texts: List[str]) -> List[List[float]]:
606
  # Apply rate limiting if configured
607
  apply_rate_limiter_sync(self.a0_model_conf, " ".join(texts))
608
+
609
  embeddings = self.model.encode(texts, convert_to_tensor=False) # type: ignore
610
  return embeddings.tolist() if hasattr(embeddings, "tolist") else embeddings # type: ignore
611
 
612
  def embed_query(self, text: str) -> List[float]:
613
  # Apply rate limiting if configured
614
  apply_rate_limiter_sync(self.a0_model_conf, text)
615
+
616
  embedding = self.model.encode([text], convert_to_tensor=False) # type: ignore
617
  result = (
618
  embedding[0].tolist() if hasattr(embedding[0], "tolist") else embedding[0]
 
637
  provider_name, model_name, kwargs = _adjust_call_args(
638
  provider_name, model_name, kwargs
639
  )
640
+ return cls(
641
+ provider=provider_name, model=model_name, model_config=model_config, **kwargs
642
+ )
643
 
644
 
645
+ def _get_litellm_embedding(
646
+ model_name: str,
647
+ provider_name: str,
648
+ model_config: Optional[ModelConfig] = None,
649
+ **kwargs: Any,
650
+ ):
651
  # Check if this is a local sentence-transformers model
652
  if provider_name == "huggingface" and model_name.startswith(
653
  "sentence-transformers/"
 
657
  provider_name, model_name, kwargs
658
  )
659
  return LocalSentenceTransformerWrapper(
660
+ provider=provider_name,
661
+ model=model_name,
662
+ model_config=model_config,
663
+ **kwargs,
664
  )
665
 
666
  # use api key from kwargs or env
 
673
  provider_name, model_name, kwargs = _adjust_call_args(
674
  provider_name, model_name, kwargs
675
  )
676
+ return LiteLLMEmbeddingWrapper(
677
+ model=model_name, provider=provider_name, model_config=model_config, **kwargs
678
+ )
679
 
680
 
681
  def _parse_chunk(chunk: Any) -> ChatChunk:
 
697
  if isinstance(delta, dict)
698
  else getattr(delta, "reasoning_content", "")
699
  )
700
+
701
  return ChatChunk(reasoning_delta=reasoning_delta, response_delta=response_delta)
702
 
703
 
704
+
705
  def _adjust_call_args(provider_name: str, model_name: str, kwargs: dict):
706
  # for openrouter add app reference
707
  if provider_name == "openrouter":
 
765
  return provider_name, kwargs
766
 
767
 
768
def get_chat_model(
    provider: str, name: str, model_config: Optional[ModelConfig] = None, **kwargs: Any
) -> LiteLLMChatWrapper:
    """Build a ``LiteLLMChatWrapper`` for the given provider/model pair.

    Provider-level defaults for chat models are merged into the call kwargs
    before the wrapper is constructed.
    """
    provider_key, merged_kwargs = _merge_provider_defaults(
        "chat", provider.lower(), kwargs
    )
    return _get_litellm_chat(
        LiteLLMChatWrapper, name, provider_key, model_config, **merged_kwargs
    )
776
 
777
 
778
  def get_browser_model(
tests/chunk_parser_test.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
import sys

# make the project root importable before pulling in `models`
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import models

# a complete thinking block followed by a response
ex1 = "<think>reasoning goes here</think>response goes here"
# a stream that ends in the middle of a closing tag
ex2 = "<think>reasoning goes here</thi"


def test_example(example: str = ex1):
    """Feed *example* one character at a time and print each processed chunk.

    The default argument keeps this function callable by pytest's collector
    (a required positional parameter would raise "fixture 'example' not found").
    """
    res = models.ChatGenerationResult()
    for i, char in enumerate(example):
        chunk = res.add_chunk({"response_delta": char, "reasoning_delta": ""})
        print(i, ":", chunk)

    print("output", res.output())


def test_think_tags_split_across_chunks():
    """Thinking tags split across single-char chunks are parsed as reasoning."""
    res = models.ChatGenerationResult()
    for char in ex1:
        res.add_chunk({"response_delta": char, "reasoning_delta": ""})
    out = res.output()
    assert out["reasoning_delta"] == "reasoning goes here"
    assert out["response_delta"] == "response goes here"


if __name__ == "__main__":
    # test_example(ex1)
    test_example(ex2)