GAIA / extension_tools.py
hapda12's picture
Upload 12 files
358eb7e verified
"""
扩展工具模块 - GAIA Agent 扩展功能
包含:parse_pdf, parse_excel, image_ocr, transcribe_audio
注意:这些工具需要额外的依赖库,如果导入失败会优雅降级。
"""
import os
from typing import Optional, List
from langchain_core.tools import tool
from config import MAX_FILE_SIZE, TOOL_TIMEOUT
# ========================================
# PDF 解析工具
# ========================================
@tool
def parse_pdf(file_path: str, page_numbers: str = "all") -> str:
"""
解析 PDF 文件,提取文本内容。
Args:
file_path: PDF 文件路径
page_numbers: 页码范围
- "all": 所有页面
- "1": 第 1 页
- "1-5": 第 1 到 5 页
- "1,3,5": 第 1、3、5 页
Returns:
PDF 文本内容
限制:
- 扫描版 PDF 需配合 OCR
- 复杂排版可能顺序错乱
"""
try:
import pdfplumber
except ImportError:
return "PDF 解析不可用:请安装 pdfplumber 库 (pip install pdfplumber)"
if not os.path.exists(file_path):
return f"文件不存在: {file_path}"
if not file_path.lower().endswith('.pdf'):
return f"不是 PDF 文件: {file_path}"
try:
with pdfplumber.open(file_path) as pdf:
total_pages = len(pdf.pages)
# 解析页码范围
if page_numbers == "all":
pages_to_read = range(total_pages)
elif "-" in page_numbers:
start, end = map(int, page_numbers.split("-"))
pages_to_read = range(start - 1, min(end, total_pages))
elif "," in page_numbers:
pages_to_read = [int(p) - 1 for p in page_numbers.split(",")]
pages_to_read = [p for p in pages_to_read if 0 <= p < total_pages]
else:
page_num = int(page_numbers) - 1
if 0 <= page_num < total_pages:
pages_to_read = [page_num]
else:
return f"页码超出范围,PDF 共有 {total_pages} 页"
# 提取文本
text_parts = []
for i in pages_to_read:
page = pdf.pages[i]
text = page.extract_text()
if text:
text_parts.append(f"--- 第 {i + 1} 页 ---\n{text}")
if not text_parts:
return "PDF 中没有提取到文本内容(可能是扫描版,请尝试使用 OCR)"
result = "\n\n".join(text_parts)
# 限制长度
if len(result) > MAX_FILE_SIZE:
return result[:MAX_FILE_SIZE] + f"\n\n... [内容已截断,共 {len(result)} 字符]"
return result
except Exception as e:
return f"PDF 解析出错: {type(e).__name__}: {str(e)}"
# ========================================
# Excel 解析工具
# ========================================
@tool
def parse_excel(file_path: str, sheet_name: str = None, max_rows: int = 100) -> str:
"""
解析 Excel 文件内容。
Args:
file_path: Excel 文件路径(.xlsx, .xls)
sheet_name: 工作表名称,默认第一个
max_rows: 最大读取行数,默认 100
Returns:
表格内容(Markdown 格式)
"""
try:
import pandas as pd
except ImportError:
return "Excel 解析不可用:请安装 pandas 和 openpyxl 库"
if not os.path.exists(file_path):
return f"文件不存在: {file_path}"
try:
# 读取 Excel
if sheet_name:
df = pd.read_excel(file_path, sheet_name=sheet_name, nrows=max_rows)
else:
df = pd.read_excel(file_path, nrows=max_rows)
# 获取工作表信息
excel_file = pd.ExcelFile(file_path)
sheet_names = excel_file.sheet_names
# 构建输出
output = []
output.append(f"工作表: {sheet_names}")
output.append(f"当前读取: {sheet_name or sheet_names[0]}")
output.append(f"数据形状: {df.shape[0]} 行 x {df.shape[1]} 列")
output.append("")
# 转换为 Markdown 表格
output.append(df.to_markdown(index=False))
result = "\n".join(output)
# 限制长度
if len(result) > MAX_FILE_SIZE:
return result[:MAX_FILE_SIZE] + f"\n\n... [内容已截断]"
return result
except Exception as e:
return f"Excel 解析出错: {type(e).__name__}: {str(e)}"
# ========================================
# 图片 OCR 工具
# ========================================
@tool
def image_ocr(file_path: str, language: str = "eng") -> str:
"""
对图片进行 OCR 文字识别。
Args:
file_path: 图片路径(png/jpg/jpeg/bmp/gif/tiff)
language: 识别语言
- "eng": 英文
- "chi_sim": 简体中文
- "chi_tra": 繁体中文
- "eng+chi_sim": 多语言
Returns:
识别出的文字
注意:
需要安装 Tesseract OCR 引擎
"""
try:
import pytesseract
from PIL import Image
except ImportError:
return "OCR 不可用:请安装 pytesseract 和 Pillow 库"
if not os.path.exists(file_path):
return f"文件不存在: {file_path}"
# 检查文件格式
valid_extensions = {'.png', '.jpg', '.jpeg', '.bmp', '.gif', '.tiff', '.tif'}
ext = os.path.splitext(file_path)[1].lower()
if ext not in valid_extensions:
return f"不支持的图片格式: {ext},支持: {', '.join(valid_extensions)}"
try:
# 打开图片
image = Image.open(file_path)
# 执行 OCR
text = pytesseract.image_to_string(image, lang=language)
if not text.strip():
return "图片中没有识别到文字内容"
# 清理文本
text = text.strip()
# 限制长度
if len(text) > MAX_FILE_SIZE:
return text[:MAX_FILE_SIZE] + f"\n\n... [内容已截断]"
return text
except pytesseract.TesseractNotFoundError:
return "OCR 引擎未安装:请安装 Tesseract OCR (https://github.com/tesseract-ocr/tesseract)"
except Exception as e:
return f"OCR 识别出错: {type(e).__name__}: {str(e)}"
# ========================================
# 音频转写工具
# ========================================
@tool
def transcribe_audio(file_path: str, language: str = "auto") -> str:
"""
将音频文件转写为文字。
使用 OpenAI Whisper 模型进行转写。
Args:
file_path: 音频路径(mp3/wav/m4a/ogg/flac)
language: 语言代码
- "auto": 自动检测
- "en": 英文
- "zh": 中文
- "ja": 日文
等等
Returns:
转写的文字内容
"""
try:
import whisper
except ImportError:
return "音频转写不可用:请安装 openai-whisper 库 (pip install openai-whisper)"
if not os.path.exists(file_path):
return f"文件不存在: {file_path}"
# 检查文件格式
valid_extensions = {'.mp3', '.wav', '.m4a', '.ogg', '.flac', '.wma', '.aac'}
ext = os.path.splitext(file_path)[1].lower()
if ext not in valid_extensions:
return f"不支持的音频格式: {ext},支持: {', '.join(valid_extensions)}"
try:
# 加载模型(使用 base 模型平衡速度和准确性)
model = whisper.load_model("base")
# 转写配置
options = {}
if language != "auto":
options["language"] = language
# 执行转写
result = model.transcribe(file_path, **options)
text = result.get("text", "").strip()
if not text:
return "音频中没有识别到语音内容"
# 添加语言检测信息
detected_lang = result.get("language", "unknown")
output = f"[检测到语言: {detected_lang}]\n\n{text}"
# 限制长度
if len(output) > MAX_FILE_SIZE:
return output[:MAX_FILE_SIZE] + f"\n\n... [内容已截断]"
return output
except Exception as e:
return f"音频转写出错: {type(e).__name__}: {str(e)}"
# ========================================
# 视觉分析工具(可选,基于多模态 LLM)
# ========================================
@tool
def analyze_image(file_path: str, question: str = "请描述这张图片的内容") -> str:
"""
使用多模态 LLM 分析图片内容。
适用于:
- 图片内容描述
- 图表数据提取
- 图片中的文字识别(比 OCR 更智能)
Args:
file_path: 图片路径
question: 关于图片的问题
Returns:
LLM 对图片的分析结果
"""
try:
import base64
from langchain_openai import ChatOpenAI
from langchain_core.messages import HumanMessage
from config import OPENAI_BASE_URL, OPENAI_API_KEY, MODEL
except ImportError:
return "图片分析不可用:缺少必要的依赖"
if not os.path.exists(file_path):
return f"文件不存在: {file_path}"
try:
# 读取图片并编码
with open(file_path, "rb") as f:
image_data = base64.b64encode(f.read()).decode("utf-8")
# 检测图片格式
ext = os.path.splitext(file_path)[1].lower()
mime_types = {
'.png': 'image/png',
'.jpg': 'image/jpeg',
'.jpeg': 'image/jpeg',
'.gif': 'image/gif',
'.webp': 'image/webp',
}
mime_type = mime_types.get(ext, 'image/png')
# 构建多模态消息
message = HumanMessage(
content=[
{"type": "text", "text": question},
{
"type": "image_url",
"image_url": {
"url": f"data:{mime_type};base64,{image_data}"
}
}
]
)
# 调用 LLM(添加超时保护)
llm = ChatOpenAI(
model=MODEL,
base_url=OPENAI_BASE_URL,
api_key=OPENAI_API_KEY,
timeout=60, # 60秒超时
max_retries=1,
)
response = llm.invoke([message])
return response.content
except Exception as e:
return f"图片分析出错: {type(e).__name__}: {str(e)}"
# ========================================
# 导出扩展工具列表
# ========================================
EXTENSION_TOOLS = [
parse_pdf,
parse_excel,
image_ocr,
transcribe_audio,
analyze_image,
]