Spaces:
Sleeping
Sleeping
| from langchain_core.documents import Document | |
| from typing import Union, Dict, Any | |
| from langchain_core.messages import BaseMessage, trim_messages | |
| from langchain_core.runnables import RunnableLambda | |
| from langgraph.prebuilt import ToolNode | |
| from langchain_core.messages import ToolMessage | |
| import base64 | |
| from fastapi import UploadFile | |
| from typing import TypeVar | |
| from youtube_transcript_api import YouTubeTranscriptApi | |
| from youtube_comment_downloader import YoutubeCommentDownloader | |
| from src.utils.logger import logger | |
| import requests | |
| import os | |
| State = TypeVar("State", bound=Dict[str, Any]) | |
def fake_token_counter(messages: Union[list[BaseMessage], BaseMessage]) -> int:
    """Approximate a token count by whitespace-splitting message content.

    Args:
        messages: A single message or a list of messages.

    Returns:
        Total number of whitespace-separated words across all content.
    """
    if not isinstance(messages, list):
        messages = [messages]
    total = 0
    for msg in messages:
        total += len(str(msg.content).split())
    return total
def convert_list_context_source_to_str(contexts: list[Document]) -> str:
    """Format retrieved documents into one numbered context string.

    Args:
        contexts: Documents to serialize, in order.

    Returns:
        A string with each document's index and page content followed by a
        dashed divider line; the empty string when ``contexts`` is empty.
    """
    # str.join avoids quadratic += concatenation for large context lists.
    return "".join(
        f"Document index {i}:\nContent: {context.page_content}\n"
        "----------------------------------------------\n\n"
        for i, context in enumerate(contexts)
    )
def trim_messages_function(messages: list[BaseMessage], max_tokens: int = 100000):
    """Trim conversation history to fit within a token budget.

    Keeps the most recent messages (starting on a human turn), dropping the
    system message and never splitting a message partially.

    Args:
        messages: Full conversation history.
        max_tokens: Budget measured by ``fake_token_counter``.

    Returns:
        The (possibly shortened) message list; single-message histories are
        returned untouched.
    """
    if len(messages) <= 1:
        return messages
    return trim_messages(
        messages,
        strategy="last",
        token_counter=fake_token_counter,
        max_tokens=max_tokens,
        start_on="human",
        # end_on="ai",
        include_system=False,
        allow_partial=False,
    )
def create_tool_node_with_fallback(tools: list) -> dict:
    """Wrap tools in a ToolNode whose errors are routed to handle_tool_error."""
    error_handler = RunnableLambda(handle_tool_error)
    node = ToolNode(tools)
    return node.with_fallbacks([error_handler], exception_key="error")
def handle_tool_error(state: State) -> dict:
    """Build ToolMessage error replies for every pending tool call.

    Args:
        state: Graph state containing an ``error`` and the message history;
            the last message is expected to carry ``tool_calls``.

    Returns:
        A state update with one error ToolMessage per pending tool call.
    """
    error = state.get("error")
    last_message = state["messages"][-1]
    replies = []
    for tc in last_message.tool_calls:
        replies.append(
            ToolMessage(
                content=f"Error: {repr(error)}\n please fix your mistakes.",
                tool_call_id=tc["id"],
            )
        )
    return {"messages": replies}
async def preprocess_messages(query: str, attachs: list[UploadFile]) -> dict:
    """Build a multimodal user-message dict from a text query and uploads.

    Args:
        query: The user's text prompt; skipped when empty.
        attachs: Uploaded files; JPEG/PNG images become base64 data-URL image
            parts, PDFs become base64 file parts. Other types are ignored.

    Returns:
        A dict with ``role: "user"`` and a ``content`` list of message parts.
    """
    message: dict[str, Any] = {
        "role": "user",
        "content": [],
    }
    if query:
        message["content"].append(
            {
                "type": "text",
                "text": query,
            }
        )
    for attach in attachs or []:
        if attach.content_type in ("image/jpeg", "image/png"):
            content = await attach.read()
            encoded_string = base64.b64encode(content).decode("utf-8")
            message["content"].append(
                {
                    "type": "image_url",
                    "image_url": {
                        # Use the real content type so PNG uploads are not
                        # mislabeled as JPEG in the data URL.
                        "url": f"data:{attach.content_type};base64,{encoded_string}",
                    },
                }
            )
        elif attach.content_type == "application/pdf":
            content = await attach.read()
            encoded_string = base64.b64encode(content).decode("utf-8")
            message["content"].append(
                {
                    "type": "file",
                    "source_type": "base64",
                    "mime_type": "application/pdf",
                    "data": f"{encoded_string}",
                    "citations": {"enabled": True},
                }
            )
    return message
| import re | |
def extract_video_id_regex(url):
    """Extract the 11-character YouTube video ID from *url*.

    Matches an ID following ``v=`` or a ``/``, terminated by ``?``, ``&``,
    or the end of the string.

    Returns:
        The video ID as a string if found, otherwise None.
    """
    id_pattern = re.compile(r"(?:v=|\/)([0-9A-Za-z_-]{11})(?:\?|&|$)")
    found = id_pattern.search(url)
    if found is None:
        return None
    return found.group(1)
| # def extract_transcript(video_link: str): | |
| # ytt_api = YouTubeTranscriptApi() | |
| # # extract video id from video link | |
| # video_id = extract_video_id_regex(video_link) | |
| # logger.info(f"Video ID: {video_id}") | |
| # transcript = ytt_api.fetch(video_id) | |
| # transcript_str = "" | |
| # for trans in transcript: | |
| # transcript_str += trans.text + " " | |
| # logger.info(f"Transcript: {transcript_str}") | |
| # return transcript_str | |
def extract_transcript(video_link: str) -> str:
    """Fetch the transcript of a YouTube video via the Supadata API.

    Args:
        video_link: Full YouTube video URL.

    Returns:
        The transcript text with a trailing space after each segment.

    Raises:
        ValueError: If SUPADATA_API_KEY is unset or the response has no
            transcript content.
        requests.exceptions.RequestException: On HTTP/network failure.
    """
    try:
        # extract video id from video link
        video_id = extract_video_id_regex(video_link)
        api_key = os.getenv("SUPADATA_API_KEY")
        if not api_key:
            raise ValueError("SUPADATA_API_KEY environment variable is not set")
        # Call Supadata API; timeout keeps a stalled connection from hanging
        # the caller indefinitely.
        url = "https://api.supadata.ai/v1/youtube/transcript"
        headers = {"x-api-key": api_key}
        params = {"videoId": video_id}
        response = requests.get(url, headers=headers, params=params, timeout=30)
        response.raise_for_status()  # Raise exception for non-200 status codes
        data = response.json()
        logger.info(f"Data: {data}")
        if not data.get("content"):
            raise ValueError("No transcript content found in the API response")
        # join avoids quadratic += concatenation; the per-segment trailing
        # space matches the previous behavior exactly.
        text = "".join(
            item["text"] + " " for item in data["content"] if "text" in item
        )
        logger.info(f"Transcript: {text}")
        return text
    except requests.exceptions.RequestException as e:
        logger.error(f"API request failed: {str(e)}")
        raise
    except ValueError as e:
        logger.error(str(e))
        raise
    except Exception as e:
        logger.error(f"Failed to extract transcript: {str(e)}")
        raise
def extract_comment(video_link: str) -> str:
    """Download all comments for a YouTube video and concatenate their text.

    Args:
        video_link: Full YouTube video URL.

    Returns:
        All comment texts with a trailing space after each, or "" on any
        failure (best-effort: errors are logged, never raised).
    """
    try:
        downloader = YoutubeCommentDownloader()
        comments = downloader.get_comments_from_url(video_link)
        # join avoids quadratic += concatenation; the per-comment trailing
        # space matches the previous behavior exactly.
        comments_str = "".join(comment["text"] + " " for comment in comments)
        logger.info(f"Comments: {comments_str}")
        return comments_str
    except Exception as e:
        logger.error(f"Failed to extract comments: {str(e)}")
        return ""