# utils.py — GAIA benchmark agent utilities (author: ThienLe, commit dc783ad)
import base64
import json
import os
import shutil
import tempfile
from pathlib import Path
from typing import Annotated, List, TypedDict

import cv2
import pandas as pd
import requests
from datasets import load_dataset
from huggingface_hub import snapshot_download
from langchain_community.tools import DuckDuckGoSearchRun
from langchain_core.messages import AIMessage, AnyMessage, HumanMessage, SystemMessage
from langchain_core.tools import StructuredTool, Tool
from langchain_experimental.utilities import PythonREPL
from langchain_openai import ChatOpenAI
from langgraph.graph import START, StateGraph
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode, tools_condition
from openai import OpenAI
from yt_dlp import YoutubeDL
# --- Module-level setup: dataset download, tools, and OpenAI clients ---

# Download the GAIA benchmark dataset snapshot and load the level-1 test split.
data_dir = snapshot_download(repo_id="gaia-benchmark/GAIA", repo_type="dataset")
dataset = load_dataset(data_dir, "2023_level1", split="test")

# Python REPL exposed as a LangChain tool for math / computation questions.
python_repl = PythonREPL()
python_tool = Tool(
    name="python_tool",
    description="A Python shell. Use this to execute python commands. Input should be a valid python command. If you want to see the output of a value, you should print it out with `print(...)`. Used for question about math or related computation problem.",
    func=python_repl.run
)

# SECURITY FIX: a live OpenAI API key was hard-coded here (a leaked secret —
# it must be revoked). The key is now supplied via the OPENAI_API_KEY
# environment variable, which both ChatOpenAI and OpenAI() read automatically.
llm = ChatOpenAI(model="gpt-4o", temperature=0)
client = OpenAI()
search_tool = DuckDuckGoSearchRun()
def image_to_base64(image_path: str) -> str:
    """
    Load a .png/.jpg/.jpeg image from disk and return its raw bytes
    encoded as a base64 string.
    """
    img_file = Path(image_path)
    if not img_file.exists():
        raise FileNotFoundError(f"File not found: {image_path}")
    allowed_suffixes = {".png", ".jpg", ".jpeg"}
    if img_file.suffix.lower() not in allowed_suffixes:
        raise ValueError("Only .png, .jpg, and .jpeg files are supported")
    raw_bytes = img_file.read_bytes()
    return base64.b64encode(raw_bytes).decode("utf-8")
def mp3_to_base64(mp3_path: str) -> str:
    """
    Load an .mp3 file from disk and return its raw bytes encoded as a
    base64 string.
    """
    audio_file = Path(mp3_path)
    if not audio_file.exists():
        raise FileNotFoundError(f"File not found: {mp3_path}")
    if audio_file.suffix.lower() != ".mp3":
        raise ValueError("Only .mp3 files are supported")
    return base64.b64encode(audio_file.read_bytes()).decode("utf-8")
def read_xlsx_to_df(xlsx_path: str, sheet_name=0) -> pd.DataFrame:
    """
    Load an .xlsx workbook into a pandas DataFrame.

    `sheet_name` is forwarded to `pd.read_excel` (default: first sheet).
    """
    workbook = Path(xlsx_path)
    if not workbook.exists():
        raise FileNotFoundError(f"File not found: {xlsx_path}")
    if workbook.suffix.lower() != ".xlsx":
        raise ValueError("Only .xlsx files are supported")
    return pd.read_excel(workbook, sheet_name=sheet_name)
def download_video(video_url: str) -> str:
    """
    Download a video with yt-dlp (YouTube or any supported site) into a
    fresh temporary directory and return the local file path.
    """
    work_dir = tempfile.mkdtemp()
    options = {
        "format": "mp4/bestvideo+bestaudio/best",
        "outtmpl": f"{work_dir}/video.%(ext)s",
        "merge_output_format": "mp4",
        # Suppress yt-dlp's console output.
        "quiet": True,
        "no_warnings": True,
    }
    with YoutubeDL(options) as downloader:
        meta = downloader.extract_info(video_url, download=True)
        local_path = downloader.prepare_filename(meta)
    return local_path
def extract_frames(
    video_path: str,
    fps: int = 1,
    max_frames: int = 32,
) -> List[str]:
    """
    Sample frames from a video at roughly `fps` frames per second,
    stopping after `max_frames` frames.

    Returns the frames as base64-encoded JPEG strings.
    """
    capture = cv2.VideoCapture(video_path)
    source_fps = capture.get(cv2.CAP_PROP_FPS)
    # Take every Nth frame; guard against a zero/low reported FPS.
    step = max(int(source_fps // fps), 1)
    encoded_frames: List[str] = []
    position = 0
    while capture.isOpened() and len(encoded_frames) < max_frames:
        ok, image = capture.read()
        if not ok:
            break
        if position % step == 0:
            _, jpeg_buf = cv2.imencode(".jpg", image)
            encoded_frames.append(base64.b64encode(jpeg_buf).decode("utf-8"))
        position += 1
    capture.release()
    return encoded_frames
# ---------------------------
# LLM Reasoning
# ---------------------------
def answer_question_from_video(
    video_url: str,
    question: str,
    model: str = "gpt-4o",
    frame_fps: int = 1,
    max_frames: int = 32,
) -> str:
    """
    Answer a question about a video by sampling frames and sending them
    to a multimodal LLM.

    Args:
        video_url (str): URL of the video.
        question (str): Question to be answered.
        model (str): OpenAI model name used to generate the answer.
        frame_fps (int): Frame sampling rate for the source video.
        max_frames (int): Maximum number of frames used to generate the answer.

    Returns:
        str: The generated answer.

    Raises:
        RuntimeError: If no frames could be extracted from the video.
    """
    # 1. Download the video into a fresh temp directory.
    video_path = download_video(video_url)
    print("video_path:", video_path)
    try:
        # 2. Preprocess video → frames
        frames = extract_frames(
            video_path,
            fps=frame_fps,
            max_frames=max_frames,
        )
        if not frames:
            raise RuntimeError("No frames extracted from video")
        # 3. Build one multimodal prompt: instructions, frames, then the question.
        content = [
            {
                "type": "text",
                "text": (
                    "You are given a sequence of video frames sampled over time.\n"
                    "Answer the user's question based on the visual content."
                ),
            }
        ]
        for frame in frames:
            content.append(
                {
                    "type": "image_url",
                    "image_url": {
                        "url": f"data:image/jpeg;base64,{frame}"
                    },
                }
            )
        content.append(
            {
                "type": "text",
                "text": f"Question: {question}",
            }
        )
        # 4. Call the LLM deterministically (temperature=0).
        llm = ChatOpenAI(
            model=model,
            temperature=0,
        )
        response = llm.invoke([HumanMessage(content=content)])
        return response.content
    finally:
        # FIX: remove the entire temp directory created by download_video,
        # not just the single video file — previously the directory (and any
        # merge artifacts) leaked on every call.
        shutil.rmtree(os.path.dirname(video_path), ignore_errors=True)
# Expose the video question-answering function as a LangChain structured tool.
video_answer_tool = StructuredTool.from_function(
    func=answer_question_from_video,
    name="video_answer",
    description="used to answer questions based on given video input",
)
class State(TypedDict):
    # Agent state flowing through the LangGraph graph.
    # Chat history; the add_messages reducer merges messages returned by
    # graph nodes into this list instead of overwriting it.
    messages: Annotated[list[AnyMessage], add_messages]
    # Attachment filename from the GAIA task, or None when the question
    # has no file; resolved relative to the downloaded dataset directory.
    file_path: str | None
    # The raw question text the agent must answer.
    question: str
# Tools the agent may call; bind them to the LLM so it can emit tool calls
# (executed by the ToolNode in the graph below).
tools = [search_tool, python_tool, video_answer_tool]
llm = llm.bind_tools(tools=tools)
def assistant(state: State):
    """
    Graph node: on the first turn, build a (possibly multimodal) user
    message from state["question"] plus any attached file, then ask the
    tool-bound LLM; on later turns just continue the conversation.

    Returns a partial state update {"messages": [...]} that the
    add_messages reducer merges into the conversation history.
    """
    print(state)
    new_messages = []
    if len(state["messages"]) == 0:
        content = [
            {
                "type": "text",
                "text": (
                    state["question"]
                ),
            }
        ]
        if state["file_path"]:
            # NOTE(review): files are resolved under "validation" although the
            # dataset split loaded above is "test" — confirm this is intended.
            file_path = os.path.join(data_dir, "2023", "validation", state["file_path"])
            # BUGFIX: take the LAST dot-separated part (lowercased) so names
            # containing extra dots (e.g. "a.b.xlsx") resolve correctly;
            # split(".")[1] picked the first part after the first dot.
            extension = state["file_path"].split(".")[-1].lower()
            if extension in ["jpg", "png"]:
                base64_string = image_to_base64(file_path)
                content.append(
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/jpeg;base64,{base64_string}"
                        }
                    }
                )
            elif extension in ["mp3"]:
                # Audio is transcribed with Whisper first; the chat endpoint
                # used here does not accept raw audio input.
                with open(file_path, "rb") as audio_file:
                    transcription = client.audio.transcriptions.create(
                        model="whisper-1",
                        file=audio_file
                    )
                # BUGFIX: use .text — interpolating the response object
                # embedded its repr, not the transcript.
                content.append(
                    {
                        "type": "text",
                        "text": f"Audio transcription: {transcription.text}"
                    }
                )
            elif extension in ["xlsx"]:
                # Spreadsheets are inlined as JSON records for the LLM.
                df = read_xlsx_to_df(file_path)
                context = df.to_json(
                    orient="records",
                    force_ascii=False,
                    indent=2
                )
                content.append(
                    {
                        "type": "text",
                        "text": context
                    }
                )
            else:
                # Fallback: treat any other attachment as plain UTF-8 text.
                with open(file_path, "r", encoding="utf-8") as f:
                    context = f.read()
                content.append(
                    {
                        "type": "text",
                        "text": context
                    }
                )
        human_message = HumanMessage(content=content)
        state["messages"].append(human_message)
        # BUGFIX: also return the human message so the add_messages reducer
        # persists it — in-place mutation of the input state is not
        # checkpointed by LangGraph.
        new_messages.append(human_message)
    system_message = SystemMessage(
        content="""
        You are a general AI assistant. I will ask you a question. Use provided tools to complete your task if necessary (Maximum 5 tool calls step), when you got an answer do not use any tool and give answer in following format
        Report your thoughts, and finish your answer with the following template:
        FINAL ANSWER: [YOUR FINAL ANSWER].
        YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings.
        If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise.
        If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise.
        If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string.
        """
    )
    # Hard cap on conversation length to stop unbounded tool-call loops.
    if len(state["messages"]) <= 20:
        response = llm.invoke([system_message] + state["messages"])
    else:
        response = AIMessage(content="FINAL ANSWER: I don't know")
    new_messages.append(response)
    return {
        "messages": new_messages
    }
# Wire the agent graph: START → assistant, with an assistant ↔ tools loop.
builder = StateGraph(State)
builder.add_node("assistant", assistant)
builder.add_node("tools", ToolNode(tools))
builder.add_edge(START, "assistant")
# Route to "tools" when the LLM requested a tool call, otherwise finish.
builder.add_conditional_edges("assistant", tools_condition)
builder.add_edge("tools", "assistant")
agent = builder.compile()