import base64
from itertools import islice
from typing import Any, Dict, Optional, TypeVar, Union
from urllib.parse import parse_qs, urlparse

from fastapi import UploadFile
from langchain_core.documents import Document
from langchain_core.messages import BaseMessage, ToolMessage, trim_messages
from langchain_core.runnables import Runnable, RunnableLambda
from langgraph.prebuilt import ToolNode
from youtube_comment_downloader import YoutubeCommentDownloader
from youtube_transcript_api import YouTubeTranscriptApi

from src.utils.logger import logger

State = TypeVar("State", bound=Dict[str, Any])

# Upload content types forwarded to the model as inline base64 images.
_IMAGE_CONTENT_TYPES = ("image/jpeg", "image/png")


def fake_token_counter(messages: Union[list[BaseMessage], BaseMessage]) -> int:
    """Approximate a token count by counting whitespace-separated words.

    Used as the ``token_counter`` for ``trim_messages``; this is a cheap
    heuristic, not a real tokenizer.

    Args:
        messages: A single message or a list of messages.

    Returns:
        The total number of whitespace-separated words in the content.
    """
    if isinstance(messages, list):
        return sum(len(str(message.content).split()) for message in messages)
    return len(str(messages.content).split())


def convert_list_context_source_to_str(contexts: list[Document]) -> str:
    """Render retrieved documents as one prompt-ready string.

    Each document is emitted as ``Document index {i}:`` followed by its
    ``page_content`` and a dashed separator line.

    Args:
        contexts: Documents to render, in order.

    Returns:
        The concatenated text ("" for an empty list).
    """
    parts: list[str] = []
    for i, context in enumerate(contexts):
        parts.append(f"Document index {i}:\nContent: {context.page_content}\n")
        parts.append("----------------------------------------------\n\n")
    # join instead of repeated += to avoid quadratic string building.
    return "".join(parts)


def trim_messages_function(messages: list[BaseMessage], max_tokens: int = 100000):
    """Keep the most recent messages that fit within ``max_tokens``.

    Tokens are approximated with ``fake_token_counter`` (word count). Lists
    with at most one message are returned unchanged.

    Args:
        messages: Conversation history, oldest first.
        max_tokens: Budget measured in whitespace-separated words.

    Returns:
        The trimmed list of messages.
    """
    if len(messages) <= 1:
        return messages
    return trim_messages(
        messages,
        strategy="last",
        token_counter=fake_token_counter,
        max_tokens=max_tokens,
        start_on="human",
        # end_on is intentionally left unset: no constraint on the type of
        # the final message in the trimmed history.
        include_system=False,
        allow_partial=False,
    )


def create_tool_node_with_fallback(tools: list) -> Runnable:
    """Wrap a ``ToolNode`` so tool exceptions become error ``ToolMessage``s.

    On failure, the exception is stored under the ``"error"`` key of the
    state and ``handle_tool_error`` is invoked as the fallback.

    Args:
        tools: Tools to expose through the node.

    Returns:
        A runnable (``ToolNode`` with fallbacks) usable as a graph node.
    """
    return ToolNode(tools).with_fallbacks(
        [RunnableLambda(handle_tool_error)], exception_key="error"
    )


def handle_tool_error(state: State) -> dict:
    """Turn a tool exception into one error ``ToolMessage`` per tool call.

    Args:
        state: Graph state; ``state["error"]`` holds the exception and
            ``state["messages"][-1]`` is the message whose ``tool_calls``
            failed.

    Returns:
        A state update with a ``"messages"`` list of ``ToolMessage``s, each
        answering one of the pending tool calls.
    """
    error = state.get("error")
    last_message = state["messages"][-1]
    return {
        "messages": [
            ToolMessage(
                content=f"Error: {repr(error)}\n please fix your mistakes.",
                tool_call_id=tc["id"],
            )
            for tc in last_message.tool_calls
        ]
    }


async def _read_base64(upload: UploadFile) -> str:
    """Read an uploaded file fully and return its contents base64-encoded."""
    content = await upload.read()
    return base64.b64encode(content).decode("utf-8")


async def preprocess_messages(query: str, attachs: list[UploadFile]) -> dict[str, Any]:
    """Build a multimodal user message from a text query and uploaded files.

    JPEG/PNG uploads become ``image_url`` parts with a base64 data URL that
    carries the upload's real MIME type. PDF uploads become base64 ``file``
    parts with citations enabled. Other types are skipped with a warning.

    Args:
        query: User text; omitted from the content when empty.
        attachs: Uploaded files; may be empty or None.

    Returns:
        A ``{"role": "user", "content": [...]}`` message dict.
    """
    message: dict[str, Any] = {"role": "user", "content": []}
    if query:
        message["content"].append({"type": "text", "text": query})

    for attach in attachs or []:
        content_type = attach.content_type
        if content_type in _IMAGE_CONTENT_TYPES:
            encoded = await _read_base64(attach)
            message["content"].append(
                {
                    "type": "image_url",
                    # Use the actual upload type; PNGs were previously
                    # mislabelled as image/jpeg.
                    "image_url": {"url": f"data:{content_type};base64,{encoded}"},
                }
            )
        elif content_type == "application/pdf":
            encoded = await _read_base64(attach)
            message["content"].append(
                {
                    "type": "file",
                    "source_type": "base64",
                    "mime_type": "application/pdf",
                    "data": encoded,
                    "citations": {"enabled": True},
                }
            )
        else:
            logger.warning(f"Skipping unsupported attachment type: {content_type}")
    return message


def _extract_video_id(video_link: str) -> str:
    """Extract the YouTube video id from a watch, youtu.be, shorts, embed or live URL.

    Raises:
        ValueError: If no video id can be found in ``video_link``.
    """
    link = video_link.strip()
    if "://" not in link:
        # urlparse only fills netloc when a scheme (or leading //) is present.
        link = "https://" + link
    parsed = urlparse(link)
    host = parsed.netloc.lower()
    segments = [segment for segment in parsed.path.split("/") if segment]

    video_id = ""
    if host.endswith("youtu.be"):
        video_id = segments[0] if segments else ""
    else:
        video_id = parse_qs(parsed.query).get("v", [""])[0]
        if not video_id and len(segments) >= 2 and segments[0] in ("shorts", "embed", "live"):
            video_id = segments[1]

    if not video_id:
        raise ValueError(f"Could not extract a YouTube video id from: {video_link!r}")
    return video_id


def extract_transcript(video_link: str) -> str:
    """Fetch a YouTube video's transcript as one space-separated string.

    Args:
        video_link: A YouTube video URL.

    Returns:
        The transcript snippet texts, each followed by a single space.

    Raises:
        ValueError: If the link does not contain a recognizable video id.
    """
    video_id = _extract_video_id(video_link)
    transcript = YouTubeTranscriptApi().fetch(video_id)
    transcript_str = "".join(f"{snippet.text} " for snippet in transcript)
    logger.info(f"Transcript: {transcript_str}")
    return transcript_str


def extract_comment(video_link: str, max_comments: Optional[int] = None) -> str:
    """Download a YouTube video's comments as one space-separated string.

    Args:
        video_link: A YouTube video URL.
        max_comments: Optional cap on the number of comments consumed from
            the downloader's stream; ``None`` (default) reads them all.

    Returns:
        The comment texts, each followed by a single space.
    """
    comments = YoutubeCommentDownloader().get_comments_from_url(video_link)
    if max_comments is not None:
        # Stop pulling from the stream early instead of downloading everything.
        comments = islice(comments, max_comments)
    comments_str = "".join(f"{comment['text']} " for comment in comments)
    logger.info(f"Comments: {comments_str}")
    return comments_str