"""Utility helpers: message trimming, tool-error fallbacks, multimodal
attachment preprocessing, and YouTube transcript/comment extraction."""

import base64
import os
import re
from typing import Any, Dict, TypeVar, Union

import requests
from fastapi import UploadFile
from langchain_core.documents import Document
from langchain_core.messages import BaseMessage, ToolMessage, trim_messages
from langchain_core.runnables import RunnableLambda
from langgraph.prebuilt import ToolNode
from youtube_comment_downloader import YoutubeCommentDownloader
from youtube_transcript_api import YouTubeTranscriptApi

from src.utils.logger import logger

# Generic alias for a LangGraph-style state mapping.
State = TypeVar("State", bound=Dict[str, Any])


def fake_token_counter(messages: Union[list[BaseMessage], BaseMessage]) -> int:
    """Approximate a token count by counting whitespace-separated words.

    Accepts either a single message or a list of messages; the "token" count
    is really a word count, which is good enough for trimming heuristics.
    """
    if isinstance(messages, list):
        return sum(len(str(message.content).split()) for message in messages)
    return len(str(messages.content).split())


def convert_list_context_source_to_str(contexts: list[Document]) -> str:
    """Format retrieved documents into one indexed context string."""
    separator = "----------------------------------------------\n\n"
    # join() instead of += in a loop: avoids quadratic string building.
    return "".join(
        f"Document index {i}:\nContent: {context.page_content}\n{separator}"
        for i, context in enumerate(contexts)
    )


def trim_messages_function(
    messages: list[BaseMessage], max_tokens: int = 100000
) -> list[BaseMessage]:
    """Trim the history to at most ``max_tokens`` approximate tokens.

    Keeps the most recent messages (strategy="last"), starting on a human
    message and excluding the system message. Zero or one message is
    returned unchanged.
    """
    if len(messages) <= 1:
        return messages
    return trim_messages(
        messages,
        strategy="last",
        token_counter=fake_token_counter,
        max_tokens=max_tokens,
        start_on="human",
        include_system=False,
        allow_partial=False,
    )


def create_tool_node_with_fallback(tools: list) -> dict:
    """Build a ToolNode whose tool errors are reported back to the model
    (via ``handle_tool_error``) instead of crashing the graph."""
    return ToolNode(tools).with_fallbacks(
        [RunnableLambda(handle_tool_error)], exception_key="error"
    )


def handle_tool_error(state: State) -> dict:
    """Convert a raised tool error into ToolMessages so the LLM can retry.

    One ToolMessage is emitted per pending tool call on the last message,
    each echoing the error and asking the model to fix its mistake.
    """
    error = state.get("error")
    last_message = state["messages"][-1]
    return {
        "messages": [
            ToolMessage(
                content=f"Error: {repr(error)}\n please fix your mistakes.",
                tool_call_id=tc["id"],
            )
            for tc in last_message.tool_calls
        ]
    }


async def preprocess_messages(
    query: str, attachs: list[UploadFile]
) -> dict[str, Any]:
    """Build a multimodal user message from a text query and attachments.

    JPEG/PNG images become base64 data-URL image parts; PDFs become base64
    file parts with citations enabled. Other content types are ignored.
    """
    message: dict[str, Any] = {"role": "user", "content": []}
    if query:
        message["content"].append({"type": "text", "text": query})
    for attach in attachs or []:
        if attach.content_type in ("image/jpeg", "image/png"):
            content = await attach.read()
            encoded = base64.b64encode(content).decode("utf-8")
            message["content"].append(
                {
                    "type": "image_url",
                    "image_url": {
                        # Use the attachment's real MIME type; the previous
                        # version hardcoded image/jpeg even for PNG uploads.
                        "url": f"data:{attach.content_type};base64,{encoded}",
                    },
                }
            )
        elif attach.content_type == "application/pdf":
            content = await attach.read()
            encoded = base64.b64encode(content).decode("utf-8")
            message["content"].append(
                {
                    "type": "file",
                    "source_type": "base64",
                    "mime_type": "application/pdf",
                    "data": encoded,
                    "citations": {"enabled": True},
                }
            )
    return message


def extract_video_id_regex(url: str) -> Union[str, None]:
    """Extract the 11-character YouTube video ID from a URL.

    Returns:
        The video ID as a string if found, otherwise None.
    """
    pattern = r"(?:v=|\/)([0-9A-Za-z_-]{11})(?:\?|&|$)"
    match = re.search(pattern, url)
    return match.group(1) if match else None


def extract_transcript(video_link: str) -> str:
    """Fetch a YouTube video's transcript via the Supadata API.

    Raises:
        ValueError: if SUPADATA_API_KEY is unset or the response carries
            no transcript content.
        requests.exceptions.RequestException: on any HTTP-level failure.
    """
    try:
        video_id = extract_video_id_regex(video_link)
        api_key = os.getenv("SUPADATA_API_KEY")
        if not api_key:
            raise ValueError("SUPADATA_API_KEY environment variable is not set")

        url = "https://api.supadata.ai/v1/youtube/transcript"
        headers = {"x-api-key": api_key}
        params = {"videoId": video_id}
        # timeout added: requests has no default and would hang forever
        # on a stalled connection.
        response = requests.get(url, headers=headers, params=params, timeout=30)
        response.raise_for_status()  # Raise exception for non-200 status codes

        data = response.json()
        logger.info(f"Data: {data}")
        if not data.get("content"):
            raise ValueError("No transcript content found in the API response")

        text = "".join(
            item["text"] + " " for item in data["content"] if "text" in item
        )
        logger.info(f"Transcript: {text}")
        return text
    except requests.exceptions.RequestException as e:
        logger.error(f"API request failed: {str(e)}")
        raise
    except ValueError as e:
        logger.error(str(e))
        raise
    except Exception as e:
        logger.error(f"Failed to extract transcript: {str(e)}")
        raise


def extract_comment(video_link: str) -> str:
    """Download and concatenate all comments of a YouTube video.

    Best-effort by design: any failure is logged and an empty string is
    returned instead of raising.
    """
    try:
        downloader = YoutubeCommentDownloader()
        comments = downloader.get_comments_from_url(video_link)
        comments_str = "".join(comment["text"] + " " for comment in comments)
        logger.info(f"Comments: {comments_str}")
        return comments_str
    except Exception as e:
        logger.error(f"Failed to extract comments: {str(e)}")
        return ""