harvesthealth committed on
Commit
cbe459e
·
verified ·
1 Parent(s): f6eac93

Upload folder using huggingface_hub

Browse files
tinytroupe/openai_utils.py CHANGED
@@ -267,11 +267,24 @@ class OpenAIClient:
267
  # complete message
268
  logger.debug(f" --> Complete messages sent to LLM: {chat_api_params['messages']}")
269
 
270
- result_message = self.client.beta.chat.completions.parse(
271
- **chat_api_params
272
- )
273
-
274
- return result_message
 
 
 
 
 
 
 
 
 
 
 
 
 
275
 
276
  else:
277
  logger.debug(f"Calling LLM model with these parameters: {logged_params}. Not showing 'messages' parameter.")
 
267
  # complete message
268
  logger.debug(f" --> Complete messages sent to LLM: {chat_api_params['messages']}")
269
 
270
+ try:
271
+ result_message = self.client.beta.chat.completions.parse(
272
+ **chat_api_params
273
+ )
274
+ return result_message
275
+ except Exception as e:
276
+ logger.warning(f"Error while parsing LLM response with .parse(): {e}. Falling back to .create().")
277
+ # Fallback to regular create if parse fails (e.g. due to messy JSON with <think> tags)
278
+ # A Pydantic model must not be passed as response_format to create(),
279
+ # but completions.create does support response_format={"type": "json_object"}.
280
+
281
+ # If it was a Pydantic model, we convert it to json_object for the fallback
282
+ if not isinstance(chat_api_params["response_format"], dict):
283
+ chat_api_params["response_format"] = {"type": "json_object"}
284
+
285
+ return self.client.chat.completions.create(
286
+ **chat_api_params
287
+ )
288
 
289
  else:
290
  logger.debug(f"Calling LLM model with these parameters: {logged_params}. Not showing 'messages' parameter.")
tinytroupe/utils/llm.py CHANGED
@@ -850,6 +850,10 @@ def extract_json(text: str) -> dict:
850
  try:
851
  logger.debug(f"Extracting JSON from text: {text}")
852
 
 
 
 
 
853
  # if it already is a dictionary or list, return it
854
  if isinstance(text, dict) or isinstance(text, list):
855
 
 
850
  try:
851
  logger.debug(f"Extracting JSON from text: {text}")
852
 
853
+ # Remove <think>...</think> blocks if they exist
854
+ if isinstance(text, str):
855
+ text = re.sub(r"<think>.*?</think>", "", text, flags=re.DOTALL)
856
+
857
  # if it already is a dictionary or list, return it
858
  if isinstance(text, dict) or isinstance(text, list):
859
 
tinytroupe/utils/validation.py CHANGED
@@ -4,6 +4,7 @@ import unicodedata
4
 
5
  from pydantic import ValidationError, BaseModel
6
  from tinytroupe.utils import logger
 
7
 
8
  ################################################################################
9
  # Validation
@@ -58,7 +59,7 @@ def to_pydantic_or_sanitized_dict(value: dict, model: BaseModel=None) -> dict:
58
  if model is not None and (isinstance(model, type) and issubclass(model, BaseModel)):
59
  # If a model is provided, try to validate the value against the model
60
  try:
61
- res = model.model_validate(sanitize_dict(json.loads(value['content'])))
62
  return res
63
  except ValidationError as e:
64
  logger.warning(f"Validation error: {e}")
 
4
 
5
  from pydantic import ValidationError, BaseModel
6
  from tinytroupe.utils import logger
7
+ from tinytroupe.utils.llm import extract_json
8
 
9
  ################################################################################
10
  # Validation
 
59
  if model is not None and (isinstance(model, type) and issubclass(model, BaseModel)):
60
  # If a model is provided, try to validate the value against the model
61
  try:
62
+ res = model.model_validate(sanitize_dict(extract_json(value['content'])))
63
  return res
64
  except ValidationError as e:
65
  logger.warning(f"Validation error: {e}")
utils/llm.py CHANGED
@@ -721,7 +721,7 @@ class LLMChat:
721
 
722
  def _request_list_of_dict_llm_message(self):
723
  return {"role": "user",
724
- "content": "The `value` field you generate **must** be a list of dictionaries, specified as a JSON structure embedded in a string. For example, `[\{...\}, \{...\}, ...]`. This is critical for later processing."}
725
 
726
  def _coerce_to_list(self, llm_output:str):
727
  """
@@ -850,6 +850,10 @@ def extract_json(text: str) -> dict:
850
  try:
851
  logger.debug(f"Extracting JSON from text: {text}")
852
 
 
 
 
 
853
  # if it already is a dictionary or list, return it
854
  if isinstance(text, dict) or isinstance(text, list):
855
 
 
721
 
722
  def _request_list_of_dict_llm_message(self):
723
  return {"role": "user",
724
+ "content": "The `value` field you generate **must** be a list of dictionaries, specified as a JSON structure embedded in a string. For example, `[\\{...\\}, \\{...\\}, ...]`. This is critical for later processing."}
725
 
726
  def _coerce_to_list(self, llm_output:str):
727
  """
 
850
  try:
851
  logger.debug(f"Extracting JSON from text: {text}")
852
 
853
+ # Remove <think>...</think> blocks if they exist
854
+ if isinstance(text, str):
855
+ text = re.sub(r"<think>.*?</think>", "", text, flags=re.DOTALL)
856
+
857
  # if it already is a dictionary or list, return it
858
  if isinstance(text, dict) or isinstance(text, list):
859
 
utils/validation.py CHANGED
@@ -4,6 +4,7 @@ import unicodedata
4
 
5
  from pydantic import ValidationError, BaseModel
6
  from tinytroupe.utils import logger
 
7
 
8
  ################################################################################
9
  # Validation
@@ -58,7 +59,7 @@ def to_pydantic_or_sanitized_dict(value: dict, model: BaseModel=None) -> dict:
58
  if model is not None and (isinstance(model, type) and issubclass(model, BaseModel)):
59
  # If a model is provided, try to validate the value against the model
60
  try:
61
- res = model.model_validate(sanitize_dict(json.loads(value['content'])))
62
  return res
63
  except ValidationError as e:
64
  logger.warning(f"Validation error: {e}")
 
4
 
5
  from pydantic import ValidationError, BaseModel
6
  from tinytroupe.utils import logger
7
+ from tinytroupe.utils.llm import extract_json
8
 
9
  ################################################################################
10
  # Validation
 
59
  if model is not None and (isinstance(model, type) and issubclass(model, BaseModel)):
60
  # If a model is provided, try to validate the value against the model
61
  try:
62
+ res = model.model_validate(sanitize_dict(extract_json(value['content'])))
63
  return res
64
  except ValidationError as e:
65
  logger.warning(f"Validation error: {e}")