|
|
import enum |
|
|
|
|
|
from langchain_core.messages import BaseMessage |
|
|
from pydantic import BaseModel, field_validator, model_validator |
|
|
from typing_extensions import TypedDict |
|
|
|
|
|
from langflow.base.data.utils import IMG_FILE_TYPES, TEXT_FILE_TYPES |
|
|
from langflow.utils.constants import MESSAGE_SENDER_AI, MESSAGE_SENDER_NAME_AI |
|
|
|
|
|
|
|
|
class File(TypedDict, total=True):
    """Dictionary describing one file entry, as used by ``ChatOutputResponse.files``.

    Every key is required (``total=True``, which is also the TypedDict default).
    """

    # Location of the file.
    path: str
    # Display name of the file.
    name: str
    # File type identifier (validate_files may infer it from the path's extension).
    type: str
|
|
|
|
|
|
|
|
class ChatOutputResponse(BaseModel):
    """Chat output response schema.

    ``files`` entries are completed by :meth:`validate_files` before pydantic
    validates them. For AI senders, newlines in a text ``message`` are
    normalized by :meth:`validate_message` after validation.
    """

    message: str | list[str | dict]
    sender: str | None = MESSAGE_SENDER_AI
    sender_name: str | None = MESSAGE_SENDER_NAME_AI
    session_id: str | None = None
    stream_url: str | None = None
    component_id: str | None = None
    files: list[File] = []
    # Required: no default is declared, so constructors must always pass it.
    type: str

    @field_validator("files", mode="before")
    @classmethod
    def validate_files(cls, files):
        """Fill in missing ``name`` / ``type`` keys on each file dict.

        Dicts that already contain all of ``path``, ``name`` and ``type`` are
        left untouched. For the others:

        * ``path`` must be present and truthy.
        * ``name`` defaults to the last ``/``-separated segment of ``path``.
        * ``type`` defaults to the path's extension when it is a known text or
          image type. Otherwise it is the first known type (in
          ``TEXT_FILE_TYPES`` then ``IMG_FILE_TYPES`` order) that occurs
          anywhere in ``path``.

        The dicts are updated in place and the same ``files`` object is returned.

        Raises:
            ValueError: If an entry is not a dict, has no ``path``, or its type
                cannot be determined.
        """
        if not files:
            return files

        # Iterate an ordered, de-duplicated list for the substring fallback.
        # A set of str has hash-randomized iteration order, which would make
        # the chosen type vary between runs for ambiguous paths.
        ordered_types = list(dict.fromkeys([*TEXT_FILE_TYPES, *IMG_FILE_TYPES]))
        known_types = set(ordered_types)

        for file in files:
            if not isinstance(file, dict):
                msg = "Files must be a list of dictionaries."
                raise ValueError(msg)

            if {"path", "name", "type"} <= file.keys():
                continue

            # Some keys are missing: derive them from the file path.
            path = file.get("path")
            if not path:
                msg = "File path is required."
                raise ValueError(msg)

            if not file.get("name"):
                file["name"] = path.split("/")[-1]

            type_ = file.get("type")
            if not type_:
                extension = path.split(".")[-1]
                if extension and extension in known_types:
                    type_ = extension
                else:
                    type_ = next((candidate for candidate in ordered_types if candidate in path), None)
                if not type_:
                    msg = "File type is required."
                    raise ValueError(msg)
            file["type"] = type_

        return files

    @classmethod
    def from_message(
        cls,
        message: BaseMessage,
        sender: str | None = MESSAGE_SENDER_AI,
        sender_name: str | None = MESSAGE_SENDER_NAME_AI,
        **kwargs,
    ):
        """Build a chat output response from a LangChain message.

        Args:
            message: Message whose ``content`` becomes the response message.
            sender: Sender identifier.
            sender_name: Display name of the sender.
            **kwargs: Additional model fields forwarded to the constructor.
                ``type`` has no default on this model, so it must be supplied
                here or pydantic rejects the instance.
        """
        return cls(message=message.content, sender=sender, sender_name=sender_name, **kwargs)

    @model_validator(mode="after")
    def validate_message(self):
        """Normalize newlines of AI-authored text so they render as Markdown paragraphs.

        Each double newline is first collapsed to a single newline, then every
        newline is doubled. An isolated single or double newline therefore
        always ends up as a double newline.

        Messages from non-AI senders are returned unchanged, and so are
        list-valued messages, where string replacement does not apply.
        """
        if self.sender != MESSAGE_SENDER_AI:
            return self

        # ``message`` may be a list of parts; only plain text is normalized.
        if not isinstance(self.message, str):
            return self

        # Collapse first so that an existing double newline is not quadrupled.
        collapsed = self.message.replace("\n\n", "\n")
        self.message = collapsed.replace("\n", "\n\n")
        return self
|
|
|
|
|
|
|
|
class DataOutputResponse(BaseModel):
    """Data output response schema.

    Holds a list whose entries are either dicts or ``None``.
    """

    # Entries are validated by pydantic as ``dict`` or ``None``.
    data: list[dict | None]
|
|
|
|
|
|
|
|
class ContainsEnumMeta(enum.EnumMeta):
    """Enum metaclass whose ``in`` operator also accepts raw member values."""

    def __contains__(cls, item) -> bool:
        """Return True when ``cls(item)`` resolves to a member, False on ValueError."""
        try:
            cls(item)
            return True
        except ValueError:
            return False
|
|
|