D3MI4N committed on
Commit
ab62c9e
·
1 Parent(s): 145b86d

include excel reader and audio tools

Browse files
Files changed (4) hide show
  1. .gitignore +4 -0
  2. fetch_gaia_audio.py +76 -0
  3. langgraph3.py +143 -0
  4. requirements.txt +2 -0
.gitignore CHANGED
@@ -26,3 +26,7 @@ config.yaml
26
  # 6) Any Docker or Kubernetes local files
27
  docker-compose.override.yml
28
  *.log
 
 
 
 
 
26
  # 6) Any Docker or Kubernetes local files
27
  docker-compose.override.yml
28
  *.log
29
+
30
+ # 7) Test files
31
+ test_sales.xlsx
32
+ test.wav
fetch_gaia_audio.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # fetch_gaia_audio.py
2
+
3
+ import os
4
+ import re
5
+ import requests
6
+
7
+ DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
8
+ OUT_PATH = "/mnt/data/test.wav"
9
+
10
def main():
    """Fetch GAIA questions and download the first audio attachment found."""
    # 1) Pull the question list from the scoring API.
    response = requests.get(f"{DEFAULT_API_URL}/questions", timeout=15)
    response.raise_for_status()
    questions = response.json()

    # 2) Prefer explicit attachment-style fields on each question.
    for question in questions:
        for key in ("attachments", "attachment", "audio"):
            candidates = question.get(key)
            if not candidates:
                continue
            if isinstance(candidates, str):
                candidates = [candidates]
            for candidate in candidates:
                if is_media_url(candidate):
                    return download_audio(candidate)

    # 3) Otherwise scan the free-text question bodies for media links.
    media_re = re.compile(r"(https?://\S+\.(?:mp3|wav))", re.IGNORECASE)
    for question in questions:
        found = media_re.search(question.get("question", ""))
        if found:
            return download_audio(found.group(1))

    print("⚠️ No .mp3/.wav URL found in GAIA payload; skipping download.")
    return
39
+
40
def is_media_url(url: str) -> bool:
    """Return True if *url* is an http(s) link to an .mp3/.wav file.

    A trailing query string or fragment (e.g. signed download URLs like
    ``...file.wav?token=abc``) is tolerated: the extension check applies
    to the path portion only. Plain URLs ending in .mp3/.wav behave as
    before.
    """
    # [^?#]* keeps the extension test on the path; an optional ?query or
    # #fragment may follow (the previous anchored-$ pattern rejected those).
    return bool(
        re.match(r"^https?://[^?#]*\.(?:mp3|wav)(?:[?#].*)?$", url, re.IGNORECASE)
    )
42
+
43
def download_audio(url: str):
    """Download *url* and save it to OUT_PATH as WAV when possible.

    .mp3 payloads are converted with pydub if it is installed; otherwise
    the raw MP3 bytes are saved next to OUT_PATH with a .mp3 suffix.
    .wav (or any other) payloads are written through unchanged.
    """
    print(f"Downloading audio from {url}")
    r = requests.get(url, timeout=30)
    r.raise_for_status()

    # Strip any ?query / #fragment so the extension test sees the path only
    # (signed download URLs would otherwise yield a bogus extension).
    path_part = url.split("#", 1)[0].split("?", 1)[0]
    ext = os.path.splitext(path_part)[1].lower()
    content = r.content

    # Make sure the destination directory exists before writing.
    os.makedirs(os.path.dirname(OUT_PATH) or ".", exist_ok=True)

    if ext == ".mp3":
        # try to convert to wav if pydub installed
        try:
            from pydub import AudioSegment  # optional dependency
        except ImportError:
            # fallback: write raw mp3 bytes
            out_mp3 = OUT_PATH.replace(".wav", ".mp3")
            with open(out_mp3, "wb") as f:
                f.write(content)
            print(f"⚠ pydub not installed; saved MP3 to {out_mp3}")
            return

        mp3_path = "/mnt/data/tmp.mp3"
        with open(mp3_path, "wb") as f:
            f.write(content)
        try:
            audio = AudioSegment.from_mp3(mp3_path)
            audio.export(OUT_PATH, format="wav")
        finally:
            # Don't leave the temporary MP3 behind (the original leaked it).
            os.remove(mp3_path)
        print(f"βœ” Saved WAV to {OUT_PATH}")
        return

    # if it's .wav or any other, write directly
    with open(OUT_PATH, "wb") as f:
        f.write(content)
    print(f"βœ” Saved WAV to {OUT_PATH}")
74
+
75
if __name__ == "__main__":
    # Script entry point: fetch GAIA questions and grab the first audio link.
    main()
langgraph3.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from dotenv import load_dotenv
3
+ import pandas as pd
4
+ import whisper
5
+
6
+ from langchain_openai import ChatOpenAI
7
+ from langchain_core.messages import SystemMessage, HumanMessage, AIMessage, AnyMessage
8
+ from langchain_core.tools import tool
9
+ from langchain_community.tools.tavily_search import TavilySearchResults
10
+ from langchain_community.document_loaders import WikipediaLoader
11
+
12
+ from langgraph.graph import StateGraph, MessagesState, START, END
13
+ from langgraph.prebuilt import ToolNode, tools_condition
14
+
15
+ load_dotenv()
16
+
17
# ─────────────────────────────────────────────
# System prompt with placeholder for Excel summary
# ─────────────────────────────────────────────
# {excel_summary} is the only placeholder; it is filled in once at import
# time when SYSTEM is built. Everything else is fixed instruction text.
SYSTEM_TEMPLATE = """
You are a razor‑sharp QA agent that answers in **one bare line**.
- Use tools if factual lookup, audio, or Excel data is needed.
- Excel data summary is available below.
- Numbers only for counts.
- Comma‑separated lists (alphabetize if asked).
- Codes (IOC, country, etc.) bare.
- Never apologize or explain.
Begin.

Excel summary:
{excel_summary}
""".strip()
33
+
34
# ─────────────────────────────────────────────
# TOOLS
# ─────────────────────────────────────────────

@tool
def web_search(query: str) -> dict:
    """Search Tavily for a query and return up to 3 results."""
    hits = TavilySearchResults(max_results=3).run(query)
    # One snippet per hit, joined into a single newline-separated string.
    snippets = [hit["content"] for hit in hits]
    return {"web_results": "\n".join(snippets)}
43
+
44
@tool
def wiki_search(query: str) -> dict:
    """Search Wikipedia for a query and return up to 2 pages."""
    loader = WikipediaLoader(query=query, load_max_docs=2)
    documents = loader.load()
    # Blank line between pages so the LLM can tell them apart.
    joined = "\n\n".join(doc.page_content for doc in documents)
    return {"wiki_results": joined}
49
+
50
@tool
def transcribe_audio(path: str) -> dict:
    """Given a local audio file path, return its transcript."""
    # Loads the small "base" Whisper model on every call — simple and
    # stateless, at the cost of reloading the weights each time.
    asr_model = whisper.load_model("base")
    transcription = asr_model.transcribe(path)
    return {"transcript": transcription["text"]}
56
+
57
@tool
def read_excel(path: str, sheet_name: str = None, sample_rows: int = 5) -> dict:
    """
    Read Excel file and return a text summary:
    - Columns
    - Sample rows (up to sample_rows)
    - Basic data types and row count
    """
    # sheet_name=None falls back to the first sheet (index 0).
    df = pd.read_excel(path, sheet_name=sheet_name or 0)
    if isinstance(df, dict):
        # Defensive: a {sheet_name: frame} dict — take the first frame.
        df = next(iter(df.values()))
    sample = df.head(sample_rows)
    summary_lines = [
        # str() guards against non-string column labels (e.g. integer
        # headers), which would make ', '.join raise TypeError.
        f"Columns: {', '.join(map(str, df.columns))}",
        "Data types: " + ", ".join(f"{col}: {dtype}" for col, dtype in df.dtypes.items()),
        "Sample data:\n" + sample.to_csv(index=False),
        f"Total rows: {len(df)}"
    ]
    return {"excel_summary": "\n".join(summary_lines)}
76
+
77
# All tools exposed to the LLM (and to the ToolNode below).
TOOLS = [web_search, wiki_search, transcribe_audio, read_excel]

# ─────────────────────────────────────────────
# Load Excel summary ONCE before building system prompt
# ─────────────────────────────────────────────
# NOTE(review): this runs at import time and will raise if test_sales.xlsx
# is missing or openpyxl is not installed — confirm the file ships with
# this module.
EXCEL_PATH = "test_sales.xlsx"
excel_summary = read_excel.invoke({"path": EXCEL_PATH})["excel_summary"]

# Build system message with injected Excel summary
SYSTEM = SystemMessage(content=SYSTEM_TEMPLATE.format(excel_summary=excel_summary))

# ─────────────────────────────────────────────
# LLM + GRAPH SETUP
# ─────────────────────────────────────────────

# temperature=0.0 for deterministic, bare-line answers.
llm = ChatOpenAI(model="gpt-4o-mini", temperature=0.0)
llm_with_tools = llm.bind_tools(TOOLS)

builder = StateGraph(MessagesState)
96
+
97
def assistant(state: dict) -> dict:
    """Single LLM step: prepend the system prompt (if absent) and call the model."""
    msgs = state.get("messages", [])
    # Ensure system prompt is present at the start
    if not msgs or not isinstance(msgs[0], SystemMessage):
        msgs = [SYSTEM] + msgs

    # Let LLM + tools framework handle tool invocation dynamically
    out: AnyMessage = llm_with_tools.invoke(msgs)

    # NOTE(review): zero-filling usage_metadata looks like a workaround for a
    # downstream consumer that requires it — confirm it is still needed.
    if isinstance(out, AIMessage) and out.usage_metadata is None:
        out.usage_metadata = {"input_tokens":0,"output_tokens":0,"total_tokens":0}
    # NOTE(review): returning the full history (msgs + [out]) relies on
    # MessagesState's add_messages reducer de-duplicating by message id —
    # verify this does not duplicate earlier messages in the stored state.
    return {"messages": msgs + [out]}
109
+
110
builder.add_node("assistant", assistant)
builder.add_node("tools", ToolNode(TOOLS))

# assistant runs first; tools_condition routes to the tool node whenever the
# last AI message carries tool calls, otherwise the run ends. Tool results
# loop back into the assistant.
builder.add_edge(START, "assistant")
builder.add_conditional_edges(
    "assistant",
    tools_condition,
    {"tools": "tools", END: END}
)
builder.add_edge("tools", "assistant")

graph = builder.compile()

# ─────────────────────────────────────────────
# Mermaid diagram
# ─────────────────────────────────────────────
print("\nπŸ” Mermaid diagram:")
print(graph.get_graph().draw_mermaid())

# ─────────────────────────────────────────────
# Smoke test with multi-type questions
# ─────────────────────────────────────────────
if __name__ == "__main__":
    print("πŸ”Ή Smoke-testing QA agent")
    # One question per capability: arithmetic, general knowledge, Excel
    # lookup, and web/wiki factual lookup.
    questions = [
        "How much is 2 + 2?",
        "What is the capital of France?",
        "How many rows belong to the food category in the Excel file?",
        "Which country had the fewest athletes at the 1928 Olympics? Give the IOC code."
    ]
    for q in questions:
        res = graph.invoke({"messages": [HumanMessage(content=q)]})
        # Trim whitespace and a trailing period so answers stay "bare".
        ans = res['messages'][-1].content.strip().rstrip('.')
        print(f"Q: {q}\n→ A: {ans!r}\n")
requirements.txt CHANGED
@@ -38,3 +38,5 @@ hf-xet~=1.1.1
38
  langchain-openai
39
  tenacity
40
  openai
 
 
 
38
  langchain-openai
39
  tenacity
40
  openai
41
+ openai-whisper
42
+ openpyxl