Spaces:
Sleeping
Sleeping
| import tempfile | |
| from yt_dlp import YoutubeDL | |
| import os | |
| from langchain_openai import ChatOpenAI | |
| from langchain_community.tools import DuckDuckGoSearchRun | |
| from openai import OpenAI | |
| import cv2 | |
| import base64 | |
| import tempfile | |
| import requests | |
| from typing import List | |
| from langchain_openai import ChatOpenAI | |
| from langchain_core.messages import HumanMessage | |
| from langchain_core.tools import Tool, StructuredTool | |
| from langchain_experimental.utilities import PythonREPL | |
| from datasets import load_dataset | |
| from huggingface_hub import snapshot_download | |
| import base64 | |
| from pathlib import Path | |
| import pandas as pd | |
| from typing import TypedDict, Annotated | |
| from langgraph.graph.message import add_messages | |
| from langchain_core.messages import AnyMessage, HumanMessage, AIMessage, SystemMessage | |
| from langgraph.prebuilt import ToolNode | |
| from langgraph.graph import START, StateGraph | |
| from langgraph.prebuilt import tools_condition | |
| import json | |
# SECURITY: the OpenAI key must come from the environment, never from source code.
# A key was previously hard-coded on this line and has been published with the
# file, so treat it as compromised and revoke/rotate it. The check runs first so
# a missing key fails fast, before the slow dataset download below.
if not os.environ.get("OPENAI_API_KEY"):
    raise RuntimeError(
        "OPENAI_API_KEY is not set; export it in the environment before running."
    )

# Local snapshot of the GAIA dataset repo; attachment files are resolved under it.
data_dir = snapshot_download(repo_id="gaia-benchmark/GAIA", repo_type="dataset")
dataset = load_dataset(data_dir, "2023_level1", split="test")

# NOTE(review): PythonREPL executes whatever code the model sends, in-process and
# without a sandbox. Only run this agent in an isolated environment.
python_repl = PythonREPL()
python_tool = Tool(
    name="python_tool",
    description="A Python shell. Use this to execute python commands. Input should be a valid python command. If you want to see the output of a value, you should print it out with `print(...)`. Used for question about math or related compuation problem.",
    func=python_repl.run,
)

# Both clients read OPENAI_API_KEY from the environment checked above.
llm = ChatOpenAI(model="gpt-4o", temperature=0)
client = OpenAI()
search_tool = DuckDuckGoSearchRun()
def image_to_base64(image_path: str) -> str:
    """Load a PNG/JPEG image from disk and return its bytes as base64 text.

    Args:
        image_path: Path to a ``.png``, ``.jpg`` or ``.jpeg`` file (the suffix
            check is case-insensitive).

    Returns:
        The file contents, base64-encoded and decoded to a UTF-8 ``str``.

    Raises:
        FileNotFoundError: If ``image_path`` does not exist (checked first).
        ValueError: If the file suffix is not one of the supported image types.
    """
    accepted_suffixes = {".png", ".jpg", ".jpeg"}
    image_file = Path(image_path)
    if not image_file.exists():
        raise FileNotFoundError(f"File not found: {image_path}")
    if image_file.suffix.lower() not in accepted_suffixes:
        raise ValueError("Only .png, .jpg, and .jpeg files are supported")
    raw_bytes = image_file.read_bytes()
    return base64.b64encode(raw_bytes).decode("utf-8")
def mp3_to_base64(mp3_path: str) -> str:
    """Load an MP3 file from disk and return its bytes as base64 text.

    Args:
        mp3_path: Path to a file with an ``.mp3`` suffix (case-insensitive).

    Returns:
        The file contents, base64-encoded and decoded to a UTF-8 ``str``.

    Raises:
        FileNotFoundError: If ``mp3_path`` does not exist (checked first).
        ValueError: If the suffix is not ``.mp3``.
    """
    audio_file = Path(mp3_path)
    if not audio_file.exists():
        raise FileNotFoundError(f"File not found: {mp3_path}")
    if audio_file.suffix.lower() != ".mp3":
        raise ValueError("Only .mp3 files are supported")
    return base64.b64encode(audio_file.read_bytes()).decode("utf-8")
def read_xlsx_to_df(xlsx_path: str, sheet_name=0) -> pd.DataFrame:
    """Load one sheet of an Excel ``.xlsx`` workbook into a DataFrame.

    Args:
        xlsx_path: Path to the workbook; the suffix must be ``.xlsx``
            (case-insensitive).
        sheet_name: Forwarded unchanged to ``pandas.read_excel``; defaults to
            the first sheet (index 0).

    Returns:
        Whatever ``pandas.read_excel`` returns for ``sheet_name``.

    Raises:
        FileNotFoundError: If ``xlsx_path`` does not exist (checked first).
        ValueError: If the suffix is not ``.xlsx``.
    """
    workbook = Path(xlsx_path)
    if not workbook.exists():
        raise FileNotFoundError(f"File not found: {xlsx_path}")
    if workbook.suffix.lower() == ".xlsx":
        return pd.read_excel(workbook, sheet_name=sheet_name)
    raise ValueError("Only .xlsx files are supported")
def download_video(video_url: str) -> str:
    """Download a video with yt-dlp into a fresh temporary directory.

    The file is written as ``<tmpdir>/video.<ext>``. The format selector prefers
    a single mp4 stream and otherwise merges best video + best audio into mp4.

    Args:
        video_url: YouTube (or any yt-dlp supported site) URL.

    Returns:
        Local path of the downloaded file. The caller owns both the file and
        its temporary parent directory and is responsible for removing them.

    Raises:
        Any exception raised by yt-dlp during download (presumably
        ``yt_dlp.utils.DownloadError`` for most failures). The temporary
        directory is removed before the exception propagates.
    """
    import shutil

    tmp_dir = tempfile.mkdtemp()
    output_path = os.path.join(tmp_dir, "video.%(ext)s")
    ydl_opts = {
        "format": "mp4/bestvideo+bestaudio/best",
        "outtmpl": output_path,
        "merge_output_format": "mp4",
        "quiet": True,
        "no_warnings": True,
    }
    try:
        with YoutubeDL(ydl_opts) as ydl:
            info = ydl.extract_info(video_url, download=True)
            return ydl.prepare_filename(info)
    except Exception:
        # Don't leave an orphaned temp dir (possibly with partial files) behind
        # when the download fails; the caller never learns its path.
        shutil.rmtree(tmp_dir, ignore_errors=True)
        raise
def extract_frames(
    video_path: str,
    fps: int = 1,
    max_frames: int = 32,
) -> List[str]:
    """Sample frames from a video file and return them as base64 JPEG strings.

    One frame is kept every ``video_fps // fps`` source frames (at least every
    frame). Sampling stops after ``max_frames`` frames or at end of stream.

    Args:
        video_path: Path to a local video file readable by OpenCV.
        fps: Target number of sampled frames per second of video; must be > 0.
        max_frames: Upper bound on the number of frames returned.

    Returns:
        Base64-encoded JPEG images in playback order. The list is empty if the
        video cannot be opened or decoded.

    Raises:
        ValueError: If ``fps`` is not positive.
    """
    import math

    if fps <= 0:
        raise ValueError(f"fps must be positive, got {fps}")

    cap = cv2.VideoCapture(video_path)
    frames_b64: List[str] = []
    try:
        video_fps = cap.get(cv2.CAP_PROP_FPS)
        # CAP_PROP_FPS can be 0 (or non-finite) when the container doesn't report
        # a frame rate; int() would crash on NaN/inf, so fall back to every frame.
        if math.isfinite(video_fps) and video_fps > 0:
            frame_interval = max(int(video_fps // fps), 1)
        else:
            frame_interval = 1

        frame_idx = 0
        while cap.isOpened() and len(frames_b64) < max_frames:
            # grab() advances without decoding; only sampled frames are decoded
            # via retrieve(), which is what read() would do for every frame.
            if not cap.grab():
                break
            if frame_idx % frame_interval == 0:
                decoded, frame = cap.retrieve()
                if decoded:
                    encoded, buffer = cv2.imencode(".jpg", frame)
                    if encoded:
                        frames_b64.append(base64.b64encode(buffer).decode("utf-8"))
            frame_idx += 1
    finally:
        # Release the capture even if decoding/encoding raises.
        cap.release()
    return frames_b64
| # --------------------------- | |
| # LLM Reasoning | |
| # --------------------------- | |
def _build_video_prompt(frames: List[str], question: str) -> list:
    """Build multimodal content: an instruction, one image part per frame, then the question."""
    content = [
        {
            "type": "text",
            "text": (
                "You are given a sequence of video frames sampled over time.\n"
                "Answer the user's question based on the visual content."
            ),
        }
    ]
    content.extend(
        {
            "type": "image_url",
            "image_url": {"url": f"data:image/jpeg;base64,{frame}"},
        }
        for frame in frames
    )
    content.append({"type": "text", "text": f"Question: {question}"})
    return content


def answer_question_from_video(
    video_url: str,
    question: str,
    model: str = "gpt-4o",
    frame_fps: int = 1,
    max_frames: int = 32,
) -> str:
    """Answer a question about a video by showing sampled frames to a vision LLM.

    Args:
        video_url (str): URL of the video (anything ``download_video`` accepts).
        question (str): Question to be answered.
        model (str): OpenAI chat model name used to generate the answer.
        frame_fps (int): Frames sampled per second of source video.
        max_frames (int): Maximum number of frames sent to the model.

    Returns:
        answer (str): The model's response content.

    Raises:
        RuntimeError: If no frames could be extracted from the downloaded video.
    """
    # 1. Download (outside the try: if this fails there is nothing to clean up here).
    video_path = download_video(video_url)
    print("video_path:", video_path)
    try:
        # 2. Preprocess video -> frames.
        frames = extract_frames(video_path, fps=frame_fps, max_frames=max_frames)
        if not frames:
            raise RuntimeError("No frames extracted from video")

        # 3. Call a fresh, tool-free model. The module-level `llm` is rebound
        #    with bind_tools later in this file, so it is not reused here.
        vision_llm = ChatOpenAI(model=model, temperature=0)
        response = vision_llm.invoke(
            [HumanMessage(content=_build_video_prompt(frames, question))]
        )
        return response.content
    finally:
        if os.path.exists(video_path):
            os.remove(video_path)
        # download_video writes into its own fresh temp dir; remove it too.
        # os.rmdir only removes empty directories, so this can never delete
        # anything else. Failure (not empty, already gone) is deliberately ignored.
        try:
            os.rmdir(os.path.dirname(video_path))
        except OSError:
            pass
# Expose the video question-answering helper to the agent as a structured tool.
# Its argument schema is derived from answer_question_from_video's signature.
video_answer_tool = StructuredTool.from_function(
    func=answer_question_from_video,
    name="video_answer",
    description="used to answer questions based on given video input",
)
class State(TypedDict):
    """Graph state passed between the "assistant" and "tools" nodes."""

    # Conversation history; the add_messages reducer merges each node's returned
    # "messages" into this list instead of overwriting it.
    messages: Annotated[list[AnyMessage], add_messages]
    # Optional attachment name for the task; assistant() joins it under
    # data_dir/2023/validation and picks a loader by file extension.
    file_path: str | None
    # Task question text, used as the first text part of the initial HumanMessage.
    question: str
# Tools the agent may call; the same list is handed to ToolNode when the graph is built.
tools = [search_tool, python_tool, video_answer_tool]
# Rebinds the module-level `llm` to a tool-calling runnable. assistant() invokes
# this bound version, while answer_question_from_video builds its own unbound model.
llm = llm.bind_tools(tools=tools)
# System prompt sent ahead of the conversation on every assistant turn; it
# defines the "FINAL ANSWER:" output format.
_SYSTEM_PROMPT = """
You are a general AI assistant. I will ask you a question. Use provided tools to complete your task if nesscessary (Maximum 5 tool calls step), when you got an answer do not use any tool and give answer in following format
Report your thoughts, and finish your answer with the following template:
FINAL ANSWER: [YOUR FINAL ANSWER].
YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings.
If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise.
If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise.
If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string.
"""

# Once the history exceeds this many messages the assistant stops calling the
# LLM and gives up, bounding runaway tool-call loops.
_MAX_MESSAGES = 20

# Image suffixes sent inline as data URLs, with their correct MIME types.
_IMAGE_MIME_TYPES = {
    ".png": "image/png",
    ".jpg": "image/jpeg",
    ".jpeg": "image/jpeg",
}


def _attachment_content(file_path: str) -> list:
    """Convert an attached task file into extra content parts for the first HumanMessage."""
    # Path.suffix looks only at the final component, so dots elsewhere in the
    # path (e.g. a ".cache" directory) don't break extension detection.
    suffix = Path(file_path).suffix.lower()

    if suffix in _IMAGE_MIME_TYPES:
        encoded = image_to_base64(file_path)
        data_url = f"data:{_IMAGE_MIME_TYPES[suffix]};base64,{encoded}"
        return [{"type": "image_url", "image_url": {"url": data_url}}]

    if suffix == ".mp3":
        # Audio is transcribed with Whisper and passed to the model as text.
        with open(file_path, "rb") as audio_file:
            transcription = client.audio.transcriptions.create(
                model="whisper-1",
                file=audio_file,
            )
        return [{"type": "text", "text": f"Audio transcription: {transcription.text}"}]

    if suffix == ".xlsx":
        df = read_xlsx_to_df(file_path)
        table_json = df.to_json(orient="records", force_ascii=False, indent=2)
        return [{"type": "text", "text": table_json}]

    # Anything else is treated as text; undecodable bytes are replaced rather
    # than aborting the whole run with a UnicodeDecodeError.
    with open(file_path, "r", encoding="utf-8", errors="replace") as f:
        return [{"type": "text", "text": f.read()}]


def assistant(state: State):
    """Run one agent turn against the tool-bound LLM.

    On the first turn (empty history) it builds a HumanMessage from
    ``state["question"]`` plus any attachment named by ``state["file_path"]``.
    It then calls the model with the system prompt. If the history has grown
    past ``_MAX_MESSAGES``, it returns a canned "I don't know" answer instead.

    Args:
        state: Current graph state.

    Returns:
        ``{"messages": [...]}`` for the add_messages reducer: the newly built
        HumanMessage (first turn only) followed by the model's reply.
    """
    print(state)

    new_messages = []
    if not state["messages"]:
        content = [{"type": "text", "text": state["question"]}]
        if state["file_path"]:
            # os.path.join returns file_path unchanged when it is already absolute.
            # NOTE(review): the module loads the "test" split but attachments are
            # resolved under 2023/validation. Confirm which split callers use.
            file_path = os.path.join(data_dir, "2023", "validation", state["file_path"])
            content.extend(_attachment_content(file_path))
        new_messages.append(HumanMessage(content=content))

    history = list(state["messages"]) + new_messages
    if len(history) <= _MAX_MESSAGES:
        response = llm.invoke([SystemMessage(content=_SYSTEM_PROMPT)] + history)
    else:
        response = AIMessage(content="FINAL ANSWER: I don't know")

    # Return the question message through the reducer instead of mutating
    # state["messages"] in place, so it is reliably persisted in graph state.
    return {"messages": new_messages + [response]}
# Two-node agent loop: the assistant either answers or requests tools; tool
# results are fed back to the assistant.
builder = StateGraph(State)
for node_name, node_action in (
    ("assistant", assistant),
    ("tools", ToolNode(tools)),
):
    builder.add_node(node_name, node_action)

builder.add_edge(START, "assistant")
# tools_condition routes to "tools" when the last AI message contains tool
# calls, and to the graph end otherwise.
builder.add_conditional_edges("assistant", tools_condition)
builder.add_edge("tools", "assistant")

agent = builder.compile()