|
|
import io |
|
|
import yaml |
|
|
import inspect |
|
|
import asyncio |
|
|
import base64 |
|
|
from abc import ABC, abstractmethod |
|
|
from pydantic import Field |
|
|
from typing import Union, Optional, Type, Callable, List, Any, Dict |
|
|
|
|
|
from ..core.parser import Parser |
|
|
from .model_configs import LLMConfig |
|
|
from ..core.module_utils import ( |
|
|
parse_json_from_text, |
|
|
get_type_name, |
|
|
parse_xml_from_text, |
|
|
parse_data_from_text |
|
|
) |
|
|
|
|
|
|
|
|
# Parsing modes accepted by `LLMOutputParser.parse` (validated there) and dispatched in
# `LLMOutputParser.get_content_data`.
PARSER_VALID_MODE = ["str", "json", "xml", "title", "custom"]
|
|
|
|
|
class LLMOutputParser(Parser):
    """A basic parser for LLM-generated content.

    This parser stores the raw text generated by an LLM in the `.content` attribute
    and provides methods to extract structured data from this text using different
    parsing strategies.

    Attributes:
        content: The raw text generated by the LLM.
    """

    # Raw LLM output. It is filled in by `parse()` and defaults to None otherwise.
    # `get_attrs()` skips it, so it never counts as a structured output field.
    content: Optional[str] = Field(default=None, exclude=True, description="the text generated by LLM")

    @classmethod
    def get_attrs(cls, return_type: bool = False) -> List[Union[str, tuple]]:
        """Returns the attributes of the LLMOutputParser class.

        Excludes ["class_name", "content"] by default.

        Args:
            return_type: Whether to return the type of the attributes along with their names.

        Returns:
            If `return_type` is True, returns a list of tuples where each tuple contains
            the attribute name and its type name (as produced by `get_type_name`).
            Otherwise, returns a list of attribute names. Order follows `cls.model_fields`.
        """
        exclude_attrs = ["class_name", "content"]
        attrs = []
        for field, field_info in cls.model_fields.items():
            if field in exclude_attrs:
                continue
            if return_type:
                attrs.append((field, get_type_name(field_info.annotation)))
            else:
                attrs.append(field)
        return attrs

    @classmethod
    def get_attr_descriptions(cls) -> dict:
        """Returns the attributes and their descriptions.

        Returns:
            A dictionary mapping attribute names (those returned by `get_attrs()`) to the
            `description` of their pydantic field, or the string "None" when no
            description was set.
        """
        attrs = cls.get_attrs()
        results = {}
        for field_name, field_info in cls.model_fields.items():
            if field_name not in attrs:
                continue
            results[field_name] = field_info.description if field_info.description is not None else "None"
        return results

    @classmethod
    def get_content_data(cls, content: str, parse_mode: str = "json", parse_func: Optional[Callable] = None, **kwargs) -> dict:
        """Parses LLM-generated content into a dictionary.

        This method takes content from an LLM response and converts it to a structured
        dictionary based on the specified parsing mode.

        Args:
            content: The content to parse.
            parse_mode: The mode to parse the content. Must be one of:
                - 'str': Assigns the raw text content to all attributes of the parser.
                - 'json': Extracts and parses JSON objects from LLM output. It will return a dictionary parsed from the first valid JSON string.
                - 'xml': Parses content using XML tags. It will return a dictionary parsed from the XML tags.
                - 'title': Parses content with Markdown-style headings.
                - 'custom': Uses custom parsing logic. Requires providing `parse_func` parameter as a custom parsing function.
            parse_func: The function to parse the content, only valid when parse_mode is 'custom'.
                Only the keyword arguments from `kwargs` whose names appear in its
                signature are passed to it, together with `content`.
            **kwargs (Any): Additional arguments passed to the parsing function.

        Returns:
            The parsed content as a dictionary. Returns `{}` without parsing when the
            parser declares no attributes.

        Raises:
            ValueError: If parse_mode is invalid, if parse_func is not provided (or lacks a
                `content` parameter, or returns a non-dict) when parse_mode is 'custom',
                or if the selected built-in parser fails.
        """
        attrs = cls.get_attrs()
        if len(attrs) <= 0:
            return {}

        if parse_mode == "str":
            parse_func = cls._parse_str_content
        elif parse_mode == "json":
            parse_func = cls._parse_json_content
        elif parse_mode == "xml":
            parse_func = cls._parse_xml_content
        elif parse_mode == "title":
            parse_func = cls._parse_title_content
        elif parse_mode == "custom":
            if parse_func is None:
                raise ValueError("`parse_func` must be provided when `parse_mode` is 'custom'.")

            signature = inspect.signature(parse_func)
            if "content" not in signature.parameters:
                raise ValueError("`parse_func` must have an input argument `content`.")

            # Forward only the kwargs the custom function explicitly declares.
            func_args = {"content": content}
            for param_name in signature.parameters:
                if param_name != "content" and param_name in kwargs:
                    func_args[param_name] = kwargs[param_name]
            data = parse_func(**func_args)
            if not isinstance(data, dict):
                raise ValueError(f"The output of `parse_func` must be a dictionary, but found {type(data)}.")
            return data
        else:
            raise ValueError(f"Invalid value '{parse_mode}' detected for `parse_mode`. Available choices: {PARSER_VALID_MODE}")

        return parse_func(content=content, **kwargs)

    @classmethod
    def _parse_str_content(cls, content: str, **kwargs) -> dict:
        """Parses content by setting all attributes to the raw content.

        Args:
            content: The content to parse.
            **kwargs: Additional arguments (not used).

        Returns:
            A dictionary mapping all attributes to the raw content.
        """
        return {attr: content for attr in cls.get_attrs()}

    @classmethod
    def _parse_json_content(cls, content: str, **kwargs) -> dict:
        """Parses content by extracting and parsing a JSON object.

        If the content contains multiple JSON objects, only the first one will be used.
        A top-level JSON list is accepted only when the parser has exactly one
        attribute, in which case the list is assigned to that attribute.

        Args:
            content: The content containing a JSON object.
            **kwargs: Additional arguments (not used).

        Returns:
            The parsed JSON as a dictionary.

        Raises:
            ValueError: If the content doesn't contain a JSON string, if the extracted
                string cannot be parsed, if it parses to neither a dict nor a list, or if
                it is a list but the parser has more than one attribute.
        """
        extracted_json_list = parse_json_from_text(content)
        if len(extracted_json_list) <= 0:
            raise ValueError(f"The following generated content does not contain JSON string!\n{content}")

        json_str = extracted_json_list[0]
        # Only the load itself is guarded. The shape errors below must reach the caller
        # unchanged instead of being rewritten into a generic "not valid JSON" message.
        # NOTE(review): yaml.safe_load is used instead of json.loads, presumably to
        # tolerate near-JSON output from LLMs; confirm that intent.
        try:
            data = yaml.safe_load(json_str)
        except Exception as e:
            raise ValueError(f"The generated content is not a valid JSON string:\n{json_str}") from e

        if isinstance(data, dict):
            return data
        if isinstance(data, list):
            attrs = cls.get_attrs()
            if len(attrs) == 1:
                return {attrs[0]: data}
            raise ValueError("The generated content is a list of JSON strings, but the attribute name for the list is not specified. You should instruct the LLM to specify the attribute name for the list.")
        raise ValueError(f"The generated content is not a valid JSON string:\n{json_str}")

    @classmethod
    def _parse_xml_content(cls, content: str, **kwargs) -> dict:
        """Parses content by extracting values from XML tags.

        Each attribute of the parser is expected to be enclosed in XML tags
        with the attribute name as the tag name. Only the first occurrence of each
        tag is used, and its text is converted to the attribute's declared type.

        Args:
            content: The content containing XML tags.
            **kwargs: Additional arguments (not used).

        Returns:
            A dictionary mapping attributes to their extracted values.

        Raises:
            ValueError: If the content is missing expected XML tags or if the
                extracted values can't be converted to the expected types.
        """
        attrs_with_types: List[tuple] = cls.get_attrs(return_type=True)
        data = {}
        for attr, attr_type in attrs_with_types:
            attr_raw_value_list = parse_xml_from_text(text=content, label=attr)
            if len(attr_raw_value_list) <= 0:
                raise ValueError(f"The following generated content does not contain xml label <{attr}>xxx</{attr}>!\n{content}")
            attr_raw_value = attr_raw_value_list[0]
            try:
                data[attr] = parse_data_from_text(text=attr_raw_value, datatype=attr_type)
            except Exception as e:
                raise ValueError(f"Cannot parse text: {attr_raw_value} into {attr_type} data!") from e
        return data

    @classmethod
    def _parse_title_content(cls, content: str, title_format: str = "## {title}", **kwargs) -> dict:
        """Parses content with markdown-style titles.

        Extracts sections from content that are divided by titles following
        the specified format described in `title_format`. The default format is "## {title}".
        For example:
        ```
        ## title1
        content1
        ## title2
        content2
        ```
        This content will be parsed into:
        ```
        {
            "title1": "content1",
            "title2": "content2"
        }
        ```
        Title lines are matched case-insensitively against the start of the stripped line.
        Each title is consumed at most once, and lines before the first title are ignored.

        Args:
            content: The content with title-divided sections.
            title_format: The format of the titles, default is "## {title}". Must contain
                the `{title}` placeholder.
            **kwargs: Additional arguments (not used).

        Returns:
            A dictionary mapping attribute names to their section contents.
        """
        attrs: List[str] = cls.get_attrs()
        if not attrs:
            return {}

        # Map each rendered heading back to its attribute. The result keys are then
        # attribute names whatever decoration `title_format` adds.
        title_to_attr = {title_format.format(title=attr): attr for attr in attrs}
        # Try longer titles first. Otherwise "## answer" would claim the heading
        # "## answer_detail" via the prefix match below.
        pending_titles = sorted(title_to_attr, key=len, reverse=True)

        data = {}
        current_attr: Optional[str] = None
        current_lines: Optional[list] = None
        for line in content.split("\n"):
            normalized = line.strip().lower()
            matched = next((title for title in pending_titles if normalized.startswith(title.lower())), None)
            if matched is not None:
                if current_attr is not None and current_lines is not None:
                    data[current_attr] = "\n".join(current_lines)
                current_lines = []
                current_attr = title_to_attr[matched]
                pending_titles.remove(matched)
            elif current_lines is not None:
                current_lines.append(line)

        if current_attr is not None and current_lines is not None:
            data[current_attr] = "\n".join(current_lines)
        return data

    @classmethod
    def parse(cls, content: str, parse_mode: str = "json", parse_func: Optional[Callable] = None, **kwargs) -> "LLMOutputParser":
        """Parses LLM-generated text into a structured parser instance.

        This is the main method for creating parser instances from LLM output.

        Args:
            content: The text generated by the LLM.
            parse_mode: The mode to parse the content, must be one of:
                - 'str': Assigns the raw text content to all attributes of the parser.
                - 'json': Extracts and parses JSON objects from LLM output. Uses the first valid JSON string to create an instance of LLMOutputParser.
                - 'xml': Parses content using XML tags. Uses the XML tags to create an instance of LLMOutputParser.
                - 'title': Parses content with Markdown-style headings. Uses the Markdown-style headings to create an instance of LLMOutputParser. The default title format is "## {title}", you can change it by providing `title_format` parameter, which should be a string that contains `{title}` placeholder.
                - 'custom': Uses custom parsing logic. Requires providing `parse_func` parameter as a custom parsing function. The `parse_func` must have a parameter named `content` and return a dictionary where the keys are the attribute names and the values are the parsed data.
            parse_func: The function to parse the content, only valid when `parse_mode` is 'custom'.
            **kwargs (Any): Additional arguments passed to parsing functions and to `from_dict`, such as:
                - `title_format` for `parse_mode="title"`.

        Returns:
            An instance of LLMOutputParser containing the parsed data, with `content`
            set to the raw input text.

        Raises:
            ValueError: If parse_mode is invalid, if content is not a string, or if parsing fails.
        """
        if parse_mode not in PARSER_VALID_MODE:
            raise ValueError(f"'{parse_mode}' is an invalid value for `parse_mode`. Available choices: {PARSER_VALID_MODE}.")
        if not isinstance(content, str):
            raise ValueError(f"The input to {cls.__name__}.parse should be a str, but found {type(content)}.")
        data = cls.get_content_data(content=content, parse_mode=parse_mode, parse_func=parse_func, **kwargs)
        data.update({"content": content})
        return cls.from_dict(data, **kwargs)

    def __str__(self) -> str:
        """Returns a string representation of the parser (the raw LLM content)."""
        return self.to_str()

    def to_str(self, **kwargs) -> str:
        """Converts the parser to a string by returning the raw `content`."""
        return self.content

    def get_structured_data(self) -> dict:
        """Extracts structured data from the parser.

        Returns:
            A dictionary containing only the defined attributes and their values,
            excluding metadata like class_name and the raw `content`.
        """
        attrs = type(self).get_attrs()
        data = self.to_dict(ignore=["class_name"])
        return {key: value for key, value in data.items() if key in attrs}
|
|
|
|
|
|
|
|
def _is_multimodal_content(content: Any) -> bool: |
|
|
"""Check if content contains multimodal objects (TextChunk, ImageChunk, etc.).""" |
|
|
try: |
|
|
from ..rag.schema import TextChunk, ImageChunk |
|
|
|
|
|
|
|
|
if isinstance(content, list): |
|
|
return any(isinstance(item, (TextChunk, ImageChunk)) for item in content) |
|
|
elif isinstance(content, (TextChunk, ImageChunk)): |
|
|
return True |
|
|
|
|
|
return False |
|
|
except ImportError: |
|
|
return False |
|
|
|
|
|
|
|
|
def _process_multimodal_content(content: List[Any], model_type: str = "openai") -> List[Dict[str, Any]]: |
|
|
"""Convert multimodal content (TextChunk, ImageChunk) to model-specific message format.""" |
|
|
try: |
|
|
from ..rag.schema import TextChunk, ImageChunk |
|
|
except ImportError: |
|
|
raise ImportError("Cannot import TextChunk/ImageChunk from rag.schema for multimodal processing") |
|
|
|
|
|
processed_content = [] |
|
|
|
|
|
for item in content: |
|
|
if isinstance(item, TextChunk): |
|
|
processed_content.append({ |
|
|
"type": "text", |
|
|
"text": item.text |
|
|
}) |
|
|
elif isinstance(item, ImageChunk): |
|
|
if model_type.lower() in ["openai", "openrouter", "litellm"]: |
|
|
|
|
|
image_data = _get_image_data_url(item) |
|
|
processed_content.append({ |
|
|
"type": "image_url", |
|
|
"image_url": {"url": image_data} |
|
|
}) |
|
|
else: |
|
|
|
|
|
processed_content.append({ |
|
|
"type": "image", |
|
|
"image_path": item.image_path, |
|
|
"image_mimetype": item.image_mimetype |
|
|
}) |
|
|
else: |
|
|
|
|
|
if isinstance(item, str): |
|
|
processed_content.append({ |
|
|
"type": "text", |
|
|
"text": item |
|
|
}) |
|
|
else: |
|
|
|
|
|
processed_content.append({ |
|
|
"type": "text", |
|
|
"text": str(item) |
|
|
}) |
|
|
|
|
|
return processed_content |
|
|
|
|
|
|
|
|
def _get_image_data_url(image_chunk) -> str: |
|
|
"""Convert ImageChunk to data URL format for model consumption.""" |
|
|
|
|
|
try: |
|
|
|
|
|
image = image_chunk.get_image() |
|
|
if image is None: |
|
|
raise ValueError(f"Could not load image from path: {image_chunk.image_path}") |
|
|
|
|
|
|
|
|
buffer = io.BytesIO() |
|
|
|
|
|
format_name = "PNG" |
|
|
if image_chunk.image_mimetype: |
|
|
format_name = image_chunk.image_mimetype.split('/')[-1].upper() |
|
|
if format_name not in ['PNG', 'JPEG', 'JPG', 'GIF', 'WEBP']: |
|
|
format_name = "PNG" |
|
|
|
|
|
image.save(buffer, format=format_name) |
|
|
image_data = base64.b64encode(buffer.getvalue()).decode('utf-8') |
|
|
|
|
|
|
|
|
mime_type = image_chunk.image_mimetype or f"image/{format_name.lower()}" |
|
|
return f"data:{mime_type};base64,{image_data}" |
|
|
|
|
|
except Exception as e: |
|
|
raise RuntimeError(f"Failed to convert image to data URL: {str(e)}") |
|
|
|
|
|
|
|
|
class BaseLLM(ABC):
    """Abstract base class for Large Language Model implementations.

    This class defines the interface that all LLM implementations must follow,
    providing methods for generating text, formatting messages, and parsing output.

    Attributes:
        config: Configuration for the LLM.
        kwargs: Additional keyword arguments provided during initialization.
    """

    def __init__(self, config: LLMConfig, **kwargs):
        """Initializes the LLM with configuration and calls `init_model()`.

        Args:
            config: Configuration object for the LLM.
            **kwargs (Any): Additional keyword arguments, stored on `self.kwargs`.
        """
        self.config = config
        self.kwargs = kwargs
        self.init_model()

    @abstractmethod
    def init_model(self):
        """Initializes the underlying model.

        This method should be implemented by subclasses to set up the actual LLM.
        """
        pass

    def __deepcopy__(self, memo) -> "BaseLLM":
        """Handles deep copying of the LLM instance.

        Returns the same instance when deepcopy is called, as LLM instances
        often cannot be meaningfully deep-copied.

        Args:
            memo (Dict[int, Any]): Memo dictionary used by the deepcopy process.

        Returns:
            The same LLM instance.
        """
        # Register self in the memo so other references within the same deepcopy
        # also resolve to this shared instance.
        memo[id(self)] = self
        return self

    @abstractmethod
    def formulate_messages(self, prompts: List[str], system_messages: Optional[List[str]] = None) -> List[List[dict]]:
        """Converts input prompts into the chat format compatible with different LLMs.

        Args:
            prompts: A list of user prompts that need to be converted.
            system_messages: An optional list of system messages that provide instructions or context to the model.

        Returns:
            A list of message lists, where each inner list contains messages in the chat format required by LLMs.
        """
        pass

    @abstractmethod
    def single_generate(self, messages: List[dict], **kwargs) -> str:
        """Generates LLM output for a single set of messages.

        Args:
            messages: The input messages to the LLM in chat format.
            **kwargs (Any): Additional keyword arguments for generation settings.

        Returns:
            The generated output text from the LLM.
        """
        pass

    @abstractmethod
    def batch_generate(self, batch_messages: List[List[dict]], **kwargs) -> List[str]:
        """Generates outputs for a batch of message sets.

        Args:
            batch_messages: A list of message lists, where each inner list contains messages for a single generation.
            **kwargs (Any): Additional keyword arguments for generation settings.

        Returns:
            A list of generated outputs from the LLM, one for each input message set.
        """
        pass

    async def single_generate_async(self, messages: List[dict], **kwargs) -> str:
        """Asynchronously generates LLM output for a single set of messages.

        This default implementation runs the synchronous `single_generate` in the
        event loop's default executor. Subclasses should override this for a true
        async implementation if supported.

        Args:
            messages: The input messages to the LLM in chat format.
            **kwargs (Any): Additional keyword arguments for generation settings.

        Returns:
            The generated output text from the LLM.
        """
        loop = asyncio.get_running_loop()
        # run_in_executor forwards only positional arguments, so bind the keyword
        # arguments in a closure. Passing **kwargs directly raises TypeError.
        return await loop.run_in_executor(None, lambda: self.single_generate(messages, **kwargs))

    async def batch_generate_async(self, batch_messages: List[List[dict]], **kwargs) -> List[str]:
        """Asynchronously generates outputs for a batch of message sets.

        This default implementation runs each generation as a separate async task
        and gathers the results in input order. Subclasses should override this for
        more efficient async batching if supported.

        Args:
            batch_messages: A list of message lists, where each inner list contains messages for a single generation.
            **kwargs (Any): Additional keyword arguments for generation settings.

        Returns:
            A list of generated outputs from the LLM, one for each input message set.
        """
        tasks = [self.single_generate_async(messages, **kwargs) for messages in batch_messages]
        return await asyncio.gather(*tasks)

    def parse_generated_text(self, text: str, parser: Optional[Type[LLMOutputParser]]=None, parse_mode: Optional[str] = "json", parse_func: Optional[Callable] = None, **kwargs) -> LLMOutputParser:
        """Parses generated text into a structured output using a parser.

        Args:
            text: The text generated by the LLM.
            parser: An LLMOutputParser class to use for parsing. If None, the default LLMOutputParser is used.
            parse_mode: The mode to use for parsing, must be the `parse_mode` supported by the `parser`.
            parse_func: Custom parsing function, used when `parse_mode` is 'custom'.
            **kwargs (Any): Accepted for interface compatibility. NOTE(review): these are
                currently NOT forwarded to `parser.parse`. `generate` passes its
                generation settings through here, so forwarding them blindly would
                reach the parser. As a consequence, parser options such as
                `title_format` cannot be set via this method.

        Returns:
            An LLMOutputParser instance containing the parsed data.
        """
        if not parser:
            parser = LLMOutputParser
        return parser.parse(text, parse_mode=parse_mode, parse_func=parse_func)

    def parse_generated_texts(self, texts: List[str], parser: Optional[Type[LLMOutputParser]]=None, parse_mode: Optional[str] = "json", parse_func: Optional[Callable] = None, **kwargs) -> List[LLMOutputParser]:
        """Parses multiple generated texts into structured outputs.

        Args:
            texts: A list of texts generated by the LLM.
            parser: An LLMOutputParser class to use for parsing.
            parse_mode: The mode to use for parsing, must be the `parse_mode` supported by the `parser`.
            parse_func: Custom parsing function, used when `parse_mode` is 'custom'.
            **kwargs (Any): Passed through to `parse_generated_text`.

        Returns:
            A list of LLMOutputParser instances containing the parsed data, one per text.
        """
        return [
            self.parse_generated_text(text=text, parser=parser, parse_mode=parse_mode, parse_func=parse_func, **kwargs)
            for text in texts
        ]

    def _prepare_messages(
        self,
        prompt: Optional[Union[str, List[str]]] = None,
        system_message: Optional[Union[str, List[str]]] = None,
        messages: Optional[Union[List[dict],List[List[dict]]]] = None,
    ) -> tuple[List[List[dict]], bool]:
        """Prepares and validates input messages for generation.

        This internal method handles the various input formats (prompt strings, system messages,
        or pre-formatted message dictionaries) and converts them to a consistent format for generation.
        When `messages` is given, `system_message` is ignored and each message's content
        goes through multimodal processing.

        Args:
            prompt: Input prompt(s) to the LLM.
            system_message: System message(s) for the LLM.
            messages: Chat message(s) for the LLM, already in the required format.

        Returns:
            A tuple containing:
                - prepared_messages: List of message lists ready for generation
                - is_single_generate: Boolean indicating if this is a single generation (vs. batch)

        Raises:
            ValueError: If neither prompt nor messages is provided, if both are provided, if
                `prompt` has an unsupported type, or if a list of prompts is paired with
                system messages that are not a list of the same length.
            TypeError: If a single prompt is paired with a non-string system message.
        """
        if not (prompt or messages):
            raise ValueError("Either 'prompt' or 'messages' must be provided.")
        if prompt and messages:
            raise ValueError("Both 'prompt' and 'messages' are provided. Please provide only one of them.")

        single_generate = False
        if messages is not None:
            if not messages:
                return [], False
            # A flat list of dicts is one conversation; wrap it as a batch of one.
            if isinstance(messages[0], dict):
                single_generate = True
                messages = [messages]
            processed_messages = self._process_messages_for_multimodal(messages)
            return processed_messages, single_generate

        if isinstance(prompt, str):
            single_generate = True
            prompt = [prompt]
            if system_message:
                if not isinstance(system_message, str):
                    raise TypeError(f"'system_message' should be a string when passing a single prompt, but found {type(system_message)}.")
                system_message = [system_message]
        elif isinstance(prompt, list) and all(isinstance(p, str) for p in prompt):
            single_generate = False
            if not prompt:
                return [], False
            if system_message:
                if not isinstance(system_message, list) or len(prompt) != len(system_message):
                    raise ValueError(f"'system_message' should be a list of string when passing multiple prompts and the number of prompts ({len(prompt)}) must match the number of system messages ({len(system_message)}).")
        else:
            raise ValueError(f"'prompt' must be a str or List[str], but found {type(prompt)}.")

        prepared_messages = self.formulate_messages(prompts=prompt, system_messages=system_message)
        return prepared_messages, single_generate

    def _process_messages_for_multimodal(self, messages: List[List[dict]]) -> List[List[dict]]:
        """Process messages to handle multimodal content (TextChunk, ImageChunk).

        Each message dict is shallow-copied. When its "content" is multimodal (see
        `_is_multimodal_content`), the content is replaced by the list of parts
        produced by `_process_multimodal_content`. Other messages are passed through
        unchanged. The input lists are not mutated.
        """
        # Backend format selected from `config.llm_type` (lower-cased); unknown types
        # fall back to the OpenAI-style payload.
        llm_type_to_model_type = {
            "openaillm": "openai",
            "openai": "openai",
            "litellm": "litellm",
            "openrouter": "openrouter",
        }

        processed_messages = []
        for message_list in messages:
            processed_message_list = []
            for message in message_list:
                processed_message = message.copy()
                content = message.get("content")

                if _is_multimodal_content(content):
                    llm_type = getattr(self.config, 'llm_type', 'openai')
                    model_type = llm_type_to_model_type.get(llm_type.lower(), "openai")

                    from ..core.logging import logger
                    logger.debug(f"Processing multimodal content: llm_type={llm_type}, model_type={model_type}")

                    # A single chunk is wrapped so the converter always receives a list.
                    items = content if isinstance(content, list) else [content]
                    processed_message["content"] = _process_multimodal_content(items, model_type)

                processed_message_list.append(processed_message)
            processed_messages.append(processed_message_list)

        return processed_messages

    def generate(
        self,
        prompt: Optional[Union[str, List[str]]] = None,
        system_message: Optional[Union[str, List[str]]] = None,
        messages: Optional[Union[List[dict],List[List[dict]]]] = None,
        parser: Optional[Type[LLMOutputParser]] = None,
        parse_mode: Optional[str] = "json",
        parse_func: Optional[Callable] = None,
        **kwargs
    ) -> Union[LLMOutputParser, List[LLMOutputParser]]:
        """Generates LLM output(s) and parses the result(s).

        This is the main method for generating text with the LLM. It handles both
        single and batch generation, and automatically parses the outputs.

        Args:
            prompt: Input prompt(s) to the LLM.
            system_message: System message(s) for the LLM.
            messages: Chat message(s) for the LLM, already in the required format (either `prompt` or `messages` must be provided).
            parser: Parser class to use for processing the output.
            parse_mode: The mode to use for parsing, must be the `parse_mode` supported by the `parser`.
            parse_func: Custom parsing function, used when `parse_mode` is 'custom'.
            **kwargs (Any): Additional generation configuration parameters.

        Returns:
            For single generation: An LLMOutputParser instance.
            For batch generation: A list of LLMOutputParser instances.
            An empty list when there is nothing to generate (e.g. an empty prompt list).

        Raises:
            ValueError: If the input format is invalid.

        Note:
            Either prompt or messages must be provided. If both or neither is provided,
            an error will be raised.
        """
        prepared_messages, single_generate = self._prepare_messages(prompt, system_message, messages)
        if not prepared_messages:
            return []

        generated_texts = self.batch_generate(batch_messages=prepared_messages, **kwargs)
        parsed_outputs = self.parse_generated_texts(texts=generated_texts, parser=parser, parse_mode=parse_mode, parse_func=parse_func, **kwargs)
        return parsed_outputs[0] if single_generate else parsed_outputs

    async def async_generate(
        self,
        prompt: Optional[Union[str, List[str]]] = None,
        system_message: Optional[Union[str, List[str]]] = None,
        messages: Optional[Union[List[dict],List[List[dict]]]] = None,
        parser: Optional[Type[LLMOutputParser]] = None,
        parse_mode: Optional[str] = "json",
        parse_func: Optional[Callable] = None,
        **kwargs
    ) -> Union[LLMOutputParser, List[LLMOutputParser]]:
        """Asynchronously generates LLM output(s) and parses the result(s).

        This is the async version of the generate method. It works identically but
        performs the generation with `batch_generate_async`. Parsing is synchronous.
        """
        prepared_messages, single_generate = self._prepare_messages(prompt, system_message, messages)
        if not prepared_messages:
            return []

        generated_texts = await self.batch_generate_async(batch_messages=prepared_messages, **kwargs)
        parsed_outputs = self.parse_generated_texts(texts=generated_texts, parser=parser, parse_mode=parse_mode, parse_func=parse_func, **kwargs)
        return parsed_outputs[0] if single_generate else parsed_outputs
|
|
|
|
|
|
|
|
# Public API of this module. `LLMConfig` is re-exported from `.model_configs`.
__all__ = ["LLMConfig", "BaseLLM", "LLMOutputParser"]
|
|
|
|
|
|