import base64
import os
from typing import Any, Dict, TypeVar, Union

import requests
from fastapi import UploadFile
from langchain_core.documents import Document
from langchain_core.messages import BaseMessage, ToolMessage, trim_messages
from langchain_core.runnables import Runnable, RunnableLambda
from langgraph.prebuilt import ToolNode
from youtube_comment_downloader import YoutubeCommentDownloader

from src.utils.logger import logger

State = TypeVar("State", bound=Dict[str, Any])


def fake_token_counter(messages: Union[list[BaseMessage], BaseMessage]) -> int:
    """Approximate token counting by whitespace-splitting message content."""
    if isinstance(messages, list):
        return sum(len(str(message.content).split()) for message in messages)
    return len(str(messages.content).split())


def convert_list_context_source_to_str(contexts: list[Document]) -> str:
    """Render retrieved documents as one numbered, delimiter-separated string."""
    formatted_str = ""
    for i, context in enumerate(contexts):
        formatted_str += f"Document index {i}:\nContent: {context.page_content}\n"
        formatted_str += "----------------------------------------------\n\n"
    return formatted_str


def trim_messages_function(
    messages: list[BaseMessage], max_tokens: int = 100000
) -> list[BaseMessage]:
    """Trim chat history to the most recent messages that fit within max_tokens."""
    if len(messages) <= 1:
        return messages
    return trim_messages(
        messages,
        strategy="last",
        token_counter=fake_token_counter,
        max_tokens=max_tokens,
        start_on="human",
        # end_on="ai",
        include_system=False,
        allow_partial=False,
    )


def handle_tool_error(state: State) -> dict:
    """Convert a raised tool error into ToolMessages so the agent can recover."""
    error = state.get("error")
    # The last message in the state carries the tool calls that failed.
    last_message = state["messages"][-1]
    return {
        "messages": [
            ToolMessage(
                content=f"Error: {repr(error)}\nPlease fix your mistakes.",
                tool_call_id=tc["id"],
            )
            for tc in last_message.tool_calls
        ]
    }


def create_tool_node_with_fallback(tools: list) -> Runnable:
    """Wrap tools in a ToolNode whose errors are reported back to the model."""
    return ToolNode(tools).with_fallbacks(
        [RunnableLambda(handle_tool_error)], exception_key="error"
    )


async def preprocess_messages(query: str, attachs: list[UploadFile]) -> dict[str, Any]:
    """Build a multimodal user message from a text query and file attachments."""
    message: dict[str, Any] = {
        "role": "user",
        "content": [],
    }
    if query:
        message["content"].append(
            {
                "type": "text",
                "text": query,
            }
        )
    for attach in attachs or []:
        if attach.content_type in ("image/jpeg", "image/png"):
            content = await attach.read()
            encoded_string = base64.b64encode(content).decode("utf-8")
            message["content"].append(
                {
                    "type": "image_url",
                    "image_url": {
                        # Use the attachment's real MIME type so PNG uploads
                        # are not mislabeled as JPEG.
                        "url": f"data:{attach.content_type};base64,{encoded_string}",
                    },
                }
            )
        elif attach.content_type == "application/pdf":
            content = await attach.read()
            encoded_string = base64.b64encode(content).decode("utf-8")
            message["content"].append(
                {
                    "type": "file",
                    "source_type": "base64",
                    "mime_type": "application/pdf",
                    "data": encoded_string,
                    "citations": {"enabled": True},
                }
            )
    return message


# Alternative implementation using youtube_transcript_api:
# from youtube_transcript_api import YouTubeTranscriptApi
#
# def extract_transcript(video_link: str):
#     ytt_api = YouTubeTranscriptApi()
#     # extract video id from video link
#     video_id = video_link.split("v=")[1]
#     transcript = ytt_api.fetch(video_id)
#     transcript_str = ""
#     for trans in transcript:
#         transcript_str += trans.text + " "
#     logger.info(f"Transcript: {transcript_str}")
#     return transcript_str


def extract_transcript(video_link: str) -> str:
    """Fetch a YouTube video's transcript via the Supadata API."""
    try:
        # Extract the video id from the watch URL, dropping any trailing
        # query parameters such as "&t=42s".
        video_id = video_link.split("v=")[1].split("&")[0]

        # Call the Supadata transcript endpoint.
        url = "https://api.supadata.ai/v1/youtube/transcript"
        headers = {"x-api-key": os.getenv("SUPADATA_API_KEY")}
        params = {"videoId": video_id}

        response = requests.get(url, headers=headers, params=params)
        response.raise_for_status()  # Raise for non-2xx status codes.

        data = response.json()
        text = ""
        for item in data["content"]:
            if "text" in item:
                text += item["text"] + " "

        logger.info(f"Transcript: {text}")
        return text
    except Exception as e:
        logger.error(f"Failed to extract transcript: {str(e)}")
        raise Exception(f"Failed to extract transcript: {str(e)}") from e


def extract_comment(video_link: str) -> str:
    """Download and concatenate the comments of a YouTube video."""
    ytd_api = YoutubeCommentDownloader()
    # get_comments_from_url yields comment dicts lazily.
    comments = ytd_api.get_comments_from_url(video_link)
    comments_str = ""
    for comment in comments:
        comments_str += comment["text"] + " "
    logger.info(f"Comments: {comments_str}")
    return comments_str
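

# A minimal usage sketch (an assumed entry point, not part of the original
# module): exercises the pure helpers above without network access. The
# document contents and chat history are made-up examples, and the
# max_tokens=10 budget is arbitrary, chosen only to force trimming.
if __name__ == "__main__":
    from langchain_core.messages import AIMessage, HumanMessage, SystemMessage

    docs = [
        Document(page_content="LangGraph builds stateful agent graphs."),
        Document(page_content="FastAPI handles the upload endpoints."),
    ]
    print(convert_list_context_source_to_str(docs))

    history = [
        SystemMessage(content="You are a helpful assistant."),
        HumanMessage(content="Summarize the attached PDF for me, please."),
        AIMessage(content="Sure, here is a short summary of the document."),
        HumanMessage(content="Thanks, now list the key points."),
    ]
    # With a 10-token budget and start_on="human", only the most recent
    # human message survives trimming.
    trimmed = trim_messages_function(history, max_tokens=10)
    print([m.content for m in trimmed])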