akrstova committed on
Commit
57a3c14
·
1 Parent(s): bb4ec09

Replace whisper with transformers pipeline

Browse files
Files changed (5) hide show
  1. agent.py +1 -1
  2. pyproject.toml +1 -0
  3. requirements.txt +1 -0
  4. tools/file_tools.py +12 -28
  5. uv.lock +2 -0
agent.py CHANGED
@@ -91,7 +91,7 @@ def build_graph():
91
 
92
 
93
  if __name__ == "__main__":
94
- question = "On June 6, 2023, an article by Carolyn Collins Petersen was published in Universe Today. This article mentions a team that produced a paper about their observations, linked at the bottom of the article. Find this paper. Under what NASA award number was the work performed by R. G. Arendt supported by?"
95
  # Build the graph
96
  graph = build_graph()
97
  # Run the graph
 
91
 
92
 
93
  if __name__ == "__main__":
94
+ question = "Who did the actor who played Ray in the Polish-language version of Everybody Loves Raymond play in Magda M.? Give only the first name."
95
  # Build the graph
96
  graph = build_graph()
97
  # Run the graph
pyproject.toml CHANGED
@@ -13,6 +13,7 @@ dependencies = [
13
  "langchain-google-genai>=2.1.4",
14
  "langchain-huggingface>=0.2.0",
15
  "langgraph>=0.4.3",
 
16
  "openai-whisper>=20240930",
17
  "pandas>=2.2.3",
18
  "pytesseract>=0.3.13",
 
13
  "langchain-google-genai>=2.1.4",
14
  "langchain-huggingface>=0.2.0",
15
  "langgraph>=0.4.3",
16
+ "numpy>=2.2.5",
17
  "openai-whisper>=20240930",
18
  "pandas>=2.2.3",
19
  "pytesseract>=0.3.13",
requirements.txt CHANGED
@@ -20,3 +20,4 @@ pgvector
20
  python-dotenv
21
  openai-whisper
22
  pytesseract
 
 
20
  python-dotenv
21
  openai-whisper
22
  pytesseract
23
+ transformers
tools/file_tools.py CHANGED
@@ -9,14 +9,11 @@ import contextlib
9
  from langchain_core.tools import tool
10
  from langchain_google_genai import ChatGoogleGenerativeAI
11
  import requests
12
- import whisper
13
  from PIL import Image
14
  import pytesseract
 
15
 
16
 
17
- # Load Whisper model once
18
- whisper_model = whisper.load_model("base") # or "small", "medium", "large"
19
-
20
  @tool
21
  def analyze_excel_file(file_path: str, query: str) -> str:
22
  """
@@ -46,41 +43,28 @@ def analyze_excel_file(file_path: str, query: str) -> str:
46
 
47
 
48
 
 
 
 
49
  @tool
50
- def process_mp3_file(file_path: str, query: str) -> str:
51
  """
52
- Transcribes an mp3 file and answers a question about its content.
53
 
54
  Args:
55
- file_path (str): The path to the .mp3 file
56
- query (str): The question to ask about the transcript
57
 
58
  Returns:
59
- str: The answer to the query based on audio content
60
  """
61
  try:
62
  print(f"Transcribing: {file_path}")
63
- # Whisper automatically handles MP3 input
64
- result = whisper_model.transcribe(file_path)
65
  transcript = result["text"]
66
-
67
- if not transcript.strip():
68
- return "Could not extract any meaningful text from the audio."
69
-
70
- # Ask question about transcript using Gemini
71
- llm = ChatGoogleGenerativeAI(
72
- model="gemini-2.0-flash-001",
73
- temperature=0.7,
74
- max_tokens=None,
75
- google_api_key=os.getenv("GOOGLE_API_KEY"),
76
- )
77
-
78
- prompt = f"Transcript:\n{transcript}\n\nQuestion: {query}\nAnswer only based on the transcript above."
79
- response = llm.invoke(prompt)
80
- return response.content
81
-
82
  except Exception as e:
83
- return f"Error processing mp3 file: {str(e)}"
84
 
85
 
86
 
 
9
  from langchain_core.tools import tool
10
  from langchain_google_genai import ChatGoogleGenerativeAI
11
  import requests
 
12
  from PIL import Image
13
  import pytesseract
14
+ from transformers import pipeline
15
 
16
 
 
 
 
17
  @tool
18
  def analyze_excel_file(file_path: str, query: str) -> str:
19
  """
 
43
 
44
 
45
 
46
+ # Load ASR pipeline once at module level (for efficiency)
47
+ asr_pipeline = pipeline("automatic-speech-recognition", model="openai/whisper-base", device=-1)
48
+
49
  @tool
50
+ def transcribe_audio(file_path: str, query: str = "") -> str:
51
  """
52
+ Transcribes speech from an audio file (e.g., .mp3 or .wav).
53
 
54
  Args:
55
+ file_path (str): Path to the audio file.
56
+ query (str): (Optional) Ignored; present to support LangChain tool schema.
57
 
58
  Returns:
59
+ str: Transcribed text from the audio.
60
  """
61
  try:
62
  print(f"Transcribing: {file_path}")
63
+ result = asr_pipeline(file_path)
 
64
  transcript = result["text"]
65
+ return transcript.strip() if transcript.strip() else "No speech detected."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  except Exception as e:
67
+ return f"Error transcribing audio: {str(e)}"
68
 
69
 
70
 
uv.lock CHANGED
@@ -368,6 +368,7 @@ dependencies = [
368
  { name = "langchain-google-genai" },
369
  { name = "langchain-huggingface" },
370
  { name = "langgraph" },
 
371
  { name = "openai-whisper" },
372
  { name = "pandas" },
373
  { name = "pytesseract" },
@@ -387,6 +388,7 @@ requires-dist = [
387
  { name = "langchain-google-genai", specifier = ">=2.1.4" },
388
  { name = "langchain-huggingface", specifier = ">=0.2.0" },
389
  { name = "langgraph", specifier = ">=0.4.3" },
 
390
  { name = "openai-whisper", specifier = ">=20240930" },
391
  { name = "pandas", specifier = ">=2.2.3" },
392
  { name = "pytesseract", specifier = ">=0.3.13" },
 
368
  { name = "langchain-google-genai" },
369
  { name = "langchain-huggingface" },
370
  { name = "langgraph" },
371
+ { name = "numpy" },
372
  { name = "openai-whisper" },
373
  { name = "pandas" },
374
  { name = "pytesseract" },
 
388
  { name = "langchain-google-genai", specifier = ">=2.1.4" },
389
  { name = "langchain-huggingface", specifier = ">=0.2.0" },
390
  { name = "langgraph", specifier = ">=0.4.3" },
391
+ { name = "numpy", specifier = ">=2.2.5" },
392
  { name = "openai-whisper", specifier = ">=20240930" },
393
  { name = "pandas", specifier = ">=2.2.3" },
394
  { name = "pytesseract", specifier = ">=0.3.13" },