# ABAO77's picture
# Upload 67 files
# d62d2dd verified
from langchain_core.documents import Document
from typing import Union, Dict, Any
from langchain_core.messages import BaseMessage, trim_messages
from langchain_core.runnables import RunnableLambda
from langgraph.prebuilt import ToolNode
from langchain_core.messages import ToolMessage
import base64
from fastapi import UploadFile
from typing import TypeVar
from youtube_transcript_api import YouTubeTranscriptApi
from youtube_comment_downloader import YoutubeCommentDownloader
from src.utils.logger import logger
import requests
import os
# Type variable for graph-state mappings: any dict with str keys. Used to
# annotate handle_tool_error's ``state`` argument, which is read through
# state.get("error") and state["messages"].
State = TypeVar("State", bound=Dict[str, Any])
def fake_token_counter(messages: Union[list[BaseMessage], BaseMessage]) -> int:
    """Approximate a token count by counting whitespace-separated words.

    Cheap stand-in for a real tokenizer. Each message's ``content`` is
    stringified and split on whitespace.

    Args:
        messages: A single message, or a list of messages whose word counts
            are summed.

    Returns:
        The total number of words across the given message(s).
    """
    # Normalise the single-message case into a one-element batch.
    batch = messages if isinstance(messages, list) else [messages]
    total = 0
    for msg in batch:
        total += len(str(msg.content).split())
    return total
def convert_list_context_source_to_str(contexts: list[Document]) -> str:
    """Render retrieved documents as one numbered, human-readable string.

    Each document is emitted as a ``Document index <i>:`` line, followed by a
    ``Content: <page_content>`` line, a dashed separator line and a blank line.

    Args:
        contexts: Documents to render, numbered from 0 in the given order.

    Returns:
        The concatenated blocks, or ``""`` when ``contexts`` is empty.
    """
    separator = "----------------------------------------------\n\n"
    # Join once instead of repeated ``+=``, which may copy the growing
    # string on every iteration (quadratic for many documents).
    return "".join(
        f"Document index {i}:\nContent: {context.page_content}\n{separator}"
        for i, context in enumerate(contexts)
    )
def trim_messages_function(messages: list[BaseMessage], max_tokens: int = 100000):
    """Keep the most recent messages that fit within a word-count budget.

    Args:
        messages: Conversation history, oldest first.
        max_tokens: Budget measured with ``fake_token_counter`` (a word
            count, not real tokenizer tokens).

    Returns:
        ``messages`` unchanged when it holds at most one message; otherwise
        the result of ``trim_messages`` keeping the tail of the history
        (strategy "last"), starting on a human message, without treating a
        system message specially and without splitting any message.
    """
    # Nothing to trim for an empty history or a single message.
    if len(messages) < 2:
        return messages

    # end_on is deliberately not passed here, so the kept window is not
    # constrained to finish on any particular message type.
    trimmed = trim_messages(
        messages,
        max_tokens=max_tokens,
        token_counter=fake_token_counter,
        strategy="last",
        start_on="human",
        include_system=False,
        allow_partial=False,
    )
    return trimmed
def create_tool_node_with_fallback(tools: list):
    """Build a ``ToolNode`` whose failures become error ``ToolMessage``s.

    Args:
        tools: The tools to expose through the ``ToolNode``.

    Returns:
        The ``ToolNode`` wrapped with ``with_fallbacks``. This is a runnable,
        not a ``dict`` as the previous annotation claimed. If the node raises,
        the fallback ``handle_tool_error`` is invoked with the original input
        state plus the raised exception stored under the ``"error"`` key.
    """
    return ToolNode(tools).with_fallbacks(
        [RunnableLambda(handle_tool_error)], exception_key="error"
    )
def handle_tool_error(state: State) -> dict:
    """Turn a tool-execution failure into ToolMessages the model can read.

    Args:
        state: Graph state. Reads ``state["error"]`` (may be absent) and the
            last entry of ``state["messages"]``, which is expected to expose
            ``tool_calls`` (presumably the AI message that requested the
            tools; confirm against the graph).

    Returns:
        ``{"messages": [...]}`` with one ``ToolMessage`` per tool call on that
        last message. Each reply carries the repr of the error plus a request
        to fix the mistake, and is linked to its call via ``tool_call_id``.
    """
    failure = state.get("error")
    last_message = state["messages"][-1]
    feedback = f"Error: {repr(failure)}\n please fix your mistakes."

    replies = []
    for call in last_message.tool_calls:
        replies.append(ToolMessage(content=feedback, tool_call_id=call["id"]))
    return {"messages": replies}
# Image uploads accepted as inline ``image_url`` parts.
_IMAGE_MIME_TYPES = ("image/jpeg", "image/png")


async def _read_base64(upload: UploadFile) -> str:
    """Read an upload's whole body and return it base64-encoded as text."""
    return base64.b64encode(await upload.read()).decode("utf-8")


async def preprocess_messages(query: str, attachs: list[UploadFile]):
    """Build a single multimodal user message from a text query and uploads.

    Args:
        query: Free-text prompt, added as a ``text`` part when non-empty.
        attachs: Uploaded files, possibly empty or ``None``.
            - JPEG/PNG images become base64 ``image_url`` data-URL parts
              labelled with their real MIME type.
            - PDFs become base64 ``file`` parts with citations enabled.
            - Uploads of any other content type are skipped.

    Returns:
        ``{"role": "user", "content": [...parts]}``.
    """
    parts: list[dict] = []
    message: dict[str, Any] = {"role": "user", "content": parts}

    if query:
        parts.append({"type": "text", "text": query})

    for attach in attachs or []:
        mime_type = attach.content_type
        if mime_type in _IMAGE_MIME_TYPES:
            encoded = await _read_base64(attach)
            parts.append(
                {
                    "type": "image_url",
                    # Use the upload's own type: PNGs were previously
                    # mislabelled as image/jpeg in the data URL.
                    "image_url": {"url": f"data:{mime_type};base64,{encoded}"},
                }
            )
        elif mime_type == "application/pdf":
            encoded = await _read_base64(attach)
            parts.append(
                {
                    "type": "file",
                    "source_type": "base64",
                    "mime_type": "application/pdf",
                    "data": encoded,
                    "citations": {"enabled": True},
                }
            )

    return message
import re
# 11-character YouTube video ID after "v=" or a "/" path separator,
# terminated by "?", "&", "#" or the end of the string.
_YOUTUBE_VIDEO_ID_PATTERN = re.compile(r"(?:v=|\/)([0-9A-Za-z_-]{11})(?:[?&#]|$)")


def extract_video_id_regex(url):
    """Extract the YouTube video ID from a URL using a regular expression.

    Recognises an 11-character ID that follows ``v=`` (``watch?v=ID``) or a
    ``/`` (``youtu.be/ID``, ``/shorts/ID``, ``/embed/ID``). The ID must be
    immediately followed by ``?``, ``&``, ``#`` or the end of the string.

    Args:
        url: The link to parse.

    Returns:
        The video ID as a string if found, otherwise None. This includes the
        case where ``url`` is ``None`` or empty.
    """
    if not url:
        return None
    match = _YOUTUBE_VIDEO_ID_PATTERN.search(url)
    return match.group(1) if match else None
# def extract_transcript(video_link: str):
# ytt_api = YouTubeTranscriptApi()
# # extract video id from video link
# video_id = extract_video_id_regex(video_link)
# logger.info(f"Video ID: {video_id}")
# transcript = ytt_api.fetch(video_id)
# transcript_str = ""
# for trans in transcript:
# transcript_str += trans.text + " "
# logger.info(f"Transcript: {transcript_str}")
# return transcript_str
def extract_transcript(video_link: str, timeout: float = 30.0):
    """Fetch the transcript of a YouTube video through the Supadata API.

    Args:
        video_link: A YouTube URL; the video ID is pulled out with
            ``extract_video_id_regex``.
        timeout: Seconds to wait on the Supadata request. Without a timeout,
            requests can wait indefinitely on an unresponsive server.

    Returns:
        The ``text`` of every transcript segment that has one, concatenated
        with a single space after each segment.

    Raises:
        ValueError: If any of the following holds:
            - no video ID can be extracted from ``video_link``;
            - ``SUPADATA_API_KEY`` is not set;
            - the response has no ``content``.
        requests.exceptions.RequestException: If the request fails, times
            out, or the API answers with a 4xx/5xx status.
    """
    try:
        video_id = extract_video_id_regex(video_link)
        if not video_id:
            # requests drops query params whose value is None, so without this
            # check the API would be called with no videoId at all.
            raise ValueError(
                f"Could not extract a YouTube video ID from link: {video_link}"
            )

        api_key = os.getenv("SUPADATA_API_KEY")
        if not api_key:
            raise ValueError("SUPADATA_API_KEY environment variable is not set")

        response = requests.get(
            "https://api.supadata.ai/v1/youtube/transcript",
            headers={"x-api-key": api_key},
            params={"videoId": video_id},
            timeout=timeout,
        )
        response.raise_for_status()  # Raise for 4xx/5xx responses
        data = response.json()
        # The raw payload contains the whole transcript; keep it out of INFO logs.
        logger.debug(f"Data: {data}")

        if not data.get("content"):
            raise ValueError("No transcript content found in the API response")

        text = "".join(
            f"{item['text']} " for item in data["content"] if "text" in item
        )
        logger.info(f"Transcript for video {video_id}: {len(text)} characters")
        logger.debug(f"Transcript: {text}")
        return text
    except requests.exceptions.RequestException as e:
        logger.error(f"API request failed: {str(e)}")
        raise
    except ValueError as e:
        logger.error(str(e))
        raise
    except Exception as e:
        logger.error(f"Failed to extract transcript: {str(e)}")
        raise
def extract_comment(video_link: str, max_comments: Union[int, None] = None):
    """Download a YouTube video's comments and join them into one string.

    Best-effort: any failure is logged and an empty string is returned.

    Args:
        video_link: URL of the YouTube video.
        max_comments: Optional non-negative cap on how many comments to
            consume. ``None`` (the default) reads every comment the
            downloader yields, which may be slow for videos with many
            comments.

    Returns:
        Each comment's ``text`` followed by a single space, concatenated in
        the order the downloader yields them; ``""`` on any error.
    """
    from itertools import islice

    try:
        downloader = YoutubeCommentDownloader()
        comments = downloader.get_comments_from_url(video_link)
        if max_comments is not None:
            # Stop consuming the downloader's iterator once the cap is reached.
            comments = islice(comments, max_comments)
        comments_str = "".join(f"{comment['text']} " for comment in comments)
        logger.info(f"Extracted {len(comments_str)} characters of comments")
        logger.debug(f"Comments: {comments_str}")
        return comments_str
    except Exception as e:
        logger.error(f"Failed to extract comments: {str(e)}")
        return ""