Spaces:
Sleeping
Sleeping
hugh2023
Add multi-modal agent system with media analysis, web scraping, and enhanced configuration management
adec1cb | import os | |
| import gradio as gr | |
| import requests | |
| import pandas as pd | |
| import json | |
| import base64 | |
| import io | |
| from typing import Dict, List, Any, Optional, Union | |
| from dataclasses import dataclass | |
| from pathlib import Path | |
| import tempfile | |
| import cv2 | |
| import numpy as np | |
| from PIL import Image | |
| import torch | |
| from transformers import pipeline, AutoProcessor, AutoModel | |
| # import moviepy.editor as mp # 暂时注释掉,需要安装moviepy | |
| # from pytube import YouTube # 暂时注释掉,需要安装pytube | |
| import urllib.request | |
| from langgraph.graph import StateGraph, END | |
| from langchain_core.messages import HumanMessage, AIMessage | |
| from langchain_openai import ChatOpenAI | |
| from langchain_community.tools import DuckDuckGoSearchRun | |
| from langchain_core.tools import tool | |
| import matplotlib.pyplot as plt | |
| import seaborn as sns | |
| # 环境变量设置 | |
| from dotenv import load_dotenv | |
| load_dotenv() | |
| # 导入自定义模块 | |
| from config import Config | |
| from tools import ToolManager | |
| from prompts import get_answer_prompt, ERROR_ANSWER_TEMPLATE | |
| # 常量定义 | |
| DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space" | |
@dataclass
class AgentState:
    """Mutable state handed from node to node in the agent workflow.

    The class must be a dataclass: it is instantiated with keyword
    arguments (``AgentState(question=...)``) and relies on
    ``__post_init__`` to replace the ``None`` placeholders with fresh
    containers, so that instances never share a mutable default.
    """

    # The user's question / task text.
    question: str
    # Media kind; the workflow assigns 'image', 'video', 'pdf', 'webpage',
    # 'youtube', 'wikipedia' or 'text'.
    media_type: Optional[str] = None
    # Local path or URL of the media being analysed, if any.
    media_path: Optional[str] = None
    # Output of the media-analysis step ({} after initialisation).
    extracted_info: Dict[str, Any] = None
    # Web search results ([] after initialisation).
    search_results: List[str] = None
    # Output of the tool step ({} after initialisation).
    analysis_results: Dict[str, Any] = None
    # Workflow plan: list of step dictionaries ([] after initialisation).
    workflow_plan: List[Dict[str, Any]] = None
    # Index of the plan step currently being executed.
    current_step: int = 0
    final_answer: str = ""
    error: Optional[str] = None

    def __post_init__(self):
        """Replace ``None`` placeholders with new, per-instance containers."""
        if self.extracted_info is None:
            self.extracted_info = {}
        if self.search_results is None:
            self.search_results = []
        if self.analysis_results is None:
            self.analysis_results = {}
        if self.workflow_plan is None:
            self.workflow_plan = []
class MediaAnalyzer:
    """Image and video analysis built on Hugging Face pipelines and OpenCV.

    All models are loaded in the constructor, so an instance should be
    created once and reused.
    """

    # Caption recorded for a frame whose captioning raised an exception.
    _FAILED_FRAME_CAPTION = "无法分析此帧"

    def __init__(self):
        """Load the captioning, classification and detection models."""
        # Use the first GPU when CUDA is available, otherwise the CPU.
        device = 0 if torch.cuda.is_available() else -1

        # Image analysis model (processor + model).
        self.image_processor = AutoProcessor.from_pretrained("microsoft/git-base")
        self.image_model = AutoModel.from_pretrained("microsoft/git-base")

        # Image captioning model.
        self.image_caption_pipeline = pipeline(
            "image-to-text",
            model="Salesforce/blip-image-captioning-base",
            device=device,
        )
        # Image classification model.
        self.image_classification_pipeline = pipeline(
            "image-classification",
            model="microsoft/resnet-50",
            device=device,
        )
        # Object detection model.
        self.object_detection_pipeline = pipeline(
            "object-detection",
            model="facebook/detr-resnet-50",
            device=device,
        )
        print("MediaAnalyzer initialized successfully")

    def analyze_image(self, image_path: str) -> Dict[str, Any]:
        """Caption, classify and run object detection on one image.

        Args:
            image_path: Path of the image file to open.

        Returns:
            A dict with ``caption``, ``classification`` (top five entries),
            ``detected_objects`` and ``image_info``; or ``{'error': ...}``
            when anything fails.
        """
        try:
            # The context manager closes the underlying file once every
            # pipeline has consumed the image.
            with Image.open(image_path) as image:
                caption = self.image_caption_pipeline(image)[0]['generated_text']
                top_classes = self.image_classification_pipeline(image)[:5]
                detected_objects = [
                    {
                        'label': detection['label'],
                        'confidence': detection['score'],
                        'box': detection['box'],
                    }
                    for detection in self.object_detection_pipeline(image)
                ]
                image_info = {
                    'size': image.size,
                    'mode': image.mode,
                    'format': image.format,
                }
            return {
                'caption': caption,
                'classification': top_classes,
                'detected_objects': detected_objects,
                'image_info': image_info,
            }
        except Exception as e:
            return {'error': f"图像分析失败: {str(e)}"}

    def analyze_video(self, video_path: str) -> Dict[str, Any]:
        """Caption roughly one frame per second and summarise the video.

        Args:
            video_path: Path (or anything OpenCV can open) of the video.

        Returns:
            A dict describing the video (``video_info``, the first ten
            ``frames_analyzed``, ``video_summary``, ``summary`` ...), or
            ``{'error': ...}`` when the video cannot be opened or the
            analysis fails.
        """
        try:
            cap = cv2.VideoCapture(video_path)
            try:
                if not cap.isOpened():
                    return {'error': "无法打开视频文件"}

                # Basic video properties.
                fps = cap.get(cv2.CAP_PROP_FPS)
                frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
                duration = frame_count / fps if fps > 0 else 0
                width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
                height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
                print(f"🎬 开始分析视频: {frame_count}帧, {fps}fps, 时长{duration:.1f}秒")

                # Sample one frame per second of video.
                frames_analyzed = []
                frame_interval = max(1, int(fps))
                for i in range(0, frame_count, frame_interval):
                    cap.set(cv2.CAP_PROP_POS_FRAMES, i)
                    ret, frame = cap.read()
                    if not ret:
                        continue
                    # Computed once and guarded, so a zero fps can neither
                    # raise nor cause the frame to be recorded twice.
                    timestamp = i / fps if fps > 0 else 0
                    # OpenCV delivers BGR; the captioning pipeline gets RGB.
                    pil_image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
                    try:
                        caption = self.image_caption_pipeline(pil_image)[0]['generated_text']
                    except Exception as e:
                        print(f"帧分析失败: {e}")
                        caption = self._FAILED_FRAME_CAPTION
                    else:
                        print(f"📸 第{i//frame_interval}帧 ({timestamp:.1f}s): {caption}")
                    frames_analyzed.append({
                        "frame_number": i,
                        "timestamp": timestamp,
                        "caption": caption,
                    })
            finally:
                # Always free the capture handle, including on early return
                # and on exceptions.
                cap.release()

            video_summary = self._summarize_frames(frames_analyzed)
            return {
                'type': 'video',
                'video_info': {
                    'duration': duration,
                    'fps': fps,
                    'frame_count': frame_count,
                    'resolution': f"{width}x{height}",
                },
                'frames_analyzed': frames_analyzed[:10],  # only the first 10 frames
                'video_summary': video_summary,
                'analysis_method': 'OpenCV + VLLM',
                'summary': f"视频时长{duration:.1f}秒,分析了{len(frames_analyzed)}个关键帧,内容:{video_summary}",
            }
        except Exception as e:
            return {'error': f"视频分析失败: {str(e)}"}

    def _summarize_frames(self, frames_analyzed: List[Dict[str, Any]]) -> str:
        """Ask the LLM for a summary of the captioned frames.

        Falls back to fixed messages when there are no frames, no usable
        captions, or the LLM call fails.
        """
        if not frames_analyzed:
            return "视频分析失败"

        # Only frames that were actually captioned are useful to the LLM.
        captioned = [
            frame for frame in frames_analyzed
            if frame['caption'] != self._FAILED_FRAME_CAPTION
        ]
        if not captioned:
            return "无法分析视频内容"

        frame_lines = chr(10).join(
            f"时间 {frame['timestamp']:.1f}s: {frame['caption']}"
            for frame in captioned[:10]
        )
        summary_prompt = f"""
基于以下视频帧描述,总结这个视频的主要内容:
{frame_lines}
请用中文总结这个视频的主要内容:
"""
        try:
            llm = ChatOpenAI(
                model="gpt-3.5-turbo",
                temperature=0.7,
                api_key=Config.OPENAI_API_KEY,
            )
            return llm.invoke(summary_prompt).content
        except Exception:
            # Best effort: the summary is optional, so degrade to a generic
            # description instead of failing the whole analysis.
            return f"视频包含{len(frames_analyzed)}个场景,主要展示了各种视觉内容"

    def download_media(self, url: str, media_type: str) -> str:
        """Fetch a media file and return where it can be read from.

        Args:
            url: Source URL.
            media_type: ``'video'`` returns *url* unchanged (no download);
                any other value downloads the resource to a temporary
                ``.jpg`` file.

        Returns:
            The URL for videos, otherwise the temporary file path.

        Raises:
            Exception: When the download fails (original error chained).
        """
        try:
            if media_type == 'video':
                # Simplified: videos are not downloaded, only the URL is kept.
                print("⚠️ 视频下载功能需要安装moviepy和pytube")
                return url

            # mkstemp creates the file atomically, unlike the race-prone
            # and deprecated tempfile.mktemp.
            fd, temp_path = tempfile.mkstemp(suffix='.jpg')
            os.close(fd)
            try:
                urllib.request.urlretrieve(url, temp_path)
            except Exception:
                # Do not leave an empty temporary file behind.
                try:
                    os.remove(temp_path)
                except OSError:
                    pass
                raise
            return temp_path
        except Exception as e:
            raise Exception(f"媒体下载失败: {str(e)}") from e
class SearchEngine:
    """Thin wrapper around the DuckDuckGo search tool."""

    def __init__(self):
        """Create the underlying DuckDuckGo search tool."""
        self.search_tool = DuckDuckGoSearchRun()

    def search(self, query: str) -> List[str]:
        """Run *query* and return the results as a list.

        A plain-string result is wrapped in a one-element list; any other
        result is returned unchanged. On failure a one-element list holding
        the error message is returned instead of raising.
        """
        try:
            raw = self.search_tool.run(query)
        except Exception as exc:
            return [f"搜索失败: {str(exc)}"]
        if isinstance(raw, str):
            return [raw]
        return raw
class MultiModalAgent:
    """Question-answering agent combining media analysis, search and tools.

    A fixed LangGraph pipeline is run over an ``AgentState``:
    plan_workflow -> classify_media -> analyze_media -> search_info ->
    use_tools -> synthesize_answer. Calling the instance returns the
    synthesized answer text.
    """

    # Extensions that mark a media URL as a directly downloadable image.
    _IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.gif', '.bmp', '.webp')

    def __init__(self):
        """Validate configuration and build all collaborators.

        Raises:
            ValueError: If ``Config.validate()`` reports a problem.
        """
        if not Config.validate():
            raise ValueError("配置验证失败,请检查环境变量")
        self.media_analyzer = MediaAnalyzer()
        self.search_engine = SearchEngine()
        self.tool_manager = ToolManager()
        self.llm = ChatOpenAI(
            model=Config.OPENAI_MODEL,
            temperature=Config.OPENAI_TEMPERATURE,
            api_key=Config.OPENAI_API_KEY,
        )
        # Build the LangGraph workflow.
        self.workflow = self._build_workflow()
        print("MultiModalAgent initialized successfully")

    # ------------------------------------------------------------------
    # Small shared helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _mentions(text: str, keywords) -> bool:
        """Return True when any of *keywords* is a substring of *text*."""
        return any(keyword in text for keyword in keywords)

    @staticmethod
    def _extract_json_payload(text: str) -> str:
        """Return the JSON text of an LLM reply, unwrapping a ```json fence."""
        marker = "```json"
        start = text.find(marker)
        if start == -1:
            # No fence: try to parse the reply as-is.
            return text.strip()
        start += len(marker)
        end = text.find("```", start)
        if end == -1:
            # Unterminated fence: keep everything after the marker rather
            # than silently dropping the final character.
            end = len(text)
        return text[start:end].strip()

    @staticmethod
    def _default_workflow_plan() -> List[Dict[str, Any]]:
        """Return the five-step plan used when the LLM plan is not valid JSON."""
        return [
            {
                "step": 1,
                "name": "媒体分类",
                "description": "分析任务中的媒体类型",
                "needs_search": False,
                "tools": [],
                "expected_output": "确定媒体类型"
            },
            {
                "step": 2,
                "name": "媒体分析",
                "description": "分析媒体内容",
                "needs_search": False,
                "tools": ["媒体分析工具"],
                "expected_output": "提取媒体信息"
            },
            {
                "step": 3,
                "name": "信息搜索",
                "description": "搜索相关信息",
                "needs_search": True,
                "tools": ["搜索引擎"],
                "expected_output": "搜索结果"
            },
            {
                "step": 4,
                "name": "工具使用",
                "description": "使用专业工具",
                "needs_search": False,
                "tools": ["各种专业工具"],
                "expected_output": "工具分析结果"
            },
            {
                "step": 5,
                "name": "答案合成",
                "description": "综合所有信息生成答案",
                "needs_search": False,
                "tools": [],
                "expected_output": "最终答案"
            }
        ]

    @staticmethod
    def _search_terms(question_lower: str, ignored) -> str:
        """Pick up to three words (longer than 2 chars, not in *ignored*)."""
        terms = [
            word for word in question_lower.split()
            if len(word) > 2 and word not in ignored
        ]
        return ' '.join(terms[:3])

    @staticmethod
    def _extract_code(question: str) -> str:
        """Pull Python code out of *question*.

        Prefers the contents of a ```python fence; otherwise collects the
        lines that start like Python statements.
        """
        fence = '```python'
        start = question.find(fence)
        if start != -1:
            # Skip the whole fence marker; a shorter hard-coded offset would
            # leave its last letter at the front of the extracted code.
            body_start = start + len(fence)
            end = question.find('```', body_start)
            if end != -1:
                return question[body_start:end].strip()
            return question[body_start:].strip()
        prefixes = ('def ', 'import ', 'class ', 'if ', 'for ', 'while ')
        return '\n'.join(
            line for line in question.split('\n')
            if line.strip().startswith(prefixes)
        )

    # ------------------------------------------------------------------
    # Workflow construction
    # ------------------------------------------------------------------
    def _build_workflow(self) -> "StateGraph":
        """Build and compile the linear LangGraph workflow."""
        workflow = StateGraph(AgentState)

        # Nodes.
        workflow.add_node("plan_workflow", self._plan_workflow)
        workflow.add_node("classify_media", self._classify_media)
        workflow.add_node("analyze_media", self._analyze_media)
        workflow.add_node("search_info", self._search_info)
        workflow.add_node("use_tools", self._use_tools)
        workflow.add_node("synthesize_answer", self._synthesize_answer)

        # Entry point.
        workflow.set_entry_point("plan_workflow")

        # Edges: a straight line through all nodes.
        workflow.add_edge("plan_workflow", "classify_media")
        workflow.add_edge("classify_media", "analyze_media")
        workflow.add_edge("analyze_media", "search_info")
        workflow.add_edge("search_info", "use_tools")
        workflow.add_edge("use_tools", "synthesize_answer")
        workflow.add_edge("synthesize_answer", END)
        return workflow.compile()

    # ------------------------------------------------------------------
    # Workflow nodes
    # ------------------------------------------------------------------
    def _plan_workflow(self, state: "AgentState") -> "AgentState":
        """Ask the LLM for a step plan and store it in ``state.workflow_plan``.

        Falls back to a five-step default plan when the reply is not valid
        JSON, and to a single generic step on any other failure.
        """
        try:
            # Let the LLM analyse the task and draft a workflow plan.
            planning_prompt = f"""
你是一个智能工作流规划专家。请分析以下任务,并制定一个详细的工作流计划。
任务: {state.question}
请根据任务类型和需求,设计一个合适的工作流。工作流应该包含以下信息:
1. 步骤编号
2. 步骤名称
3. 步骤描述
4. 是否需要搜索网络
5. 需要使用哪些工具
6. 预期输出
请以JSON格式返回工作流计划,格式如下:
{{
"workflow": [
{{
"step": 1,
"name": "步骤名称",
"description": "步骤描述",
"needs_search": true/false,
"tools": ["工具1", "工具2"],
"expected_output": "预期输出"
}}
]
}}
请确保工作流是合理的、高效的,并且能够完成任务。
"""
            response = self.llm.invoke(planning_prompt)
            try:
                workflow_data = json.loads(self._extract_json_payload(response.content))
                state.workflow_plan = workflow_data.get("workflow", [])
                print(f"🤖 工作流规划完成,共 {len(state.workflow_plan)} 个步骤:")
                for step in state.workflow_plan:
                    print(f" 📋 步骤 {step.get('step', '?')}: {step.get('name', 'Unknown')}")
                    print(f" {step.get('description', 'No description')}")
                    if step.get('needs_search', False):
                        print(f" 🔍 需要搜索: 是")
                    if step.get('tools'):
                        # str() each entry: tool names come from the LLM and
                        # a non-string must not discard an otherwise valid plan.
                        print(f" 🛠️ 工具: {', '.join(map(str, step['tools']))}")
                    print()
            except json.JSONDecodeError:
                # The reply was not valid JSON: use the default workflow.
                print("⚠️ 工作流规划解析失败,使用默认工作流")
                state.workflow_plan = self._default_workflow_plan()
        except Exception as e:
            print(f"❌ 工作流规划失败: {e}")
            # Use a minimal default workflow.
            state.workflow_plan = [
                {
                    "step": 1,
                    "name": "默认工作流",
                    "description": "使用默认工作流处理任务",
                    "needs_search": True,
                    "tools": [],
                    "expected_output": "任务完成"
                }
            ]
        return state

    def _classify_media(self, state: "AgentState") -> "AgentState":
        """Infer ``media_type`` / ``media_path`` from the question text.

        Values already present on the state (set by ``__call__`` when a
        media URL was supplied) are kept; only missing ones are inferred.
        """
        import re

        question = state.question.lower()
        urls = re.findall(r'https?://[^\s]+', state.question)

        if state.media_type is None:
            # YouTube is tested first; behind the generic video / webpage
            # keyword checks it could never be selected.
            if 'youtube.com' in question or 'youtu.be' in question:
                state.media_type = 'youtube'
            elif self._mentions(question, ['图片', '图像', 'image', 'photo', 'img']):
                state.media_type = 'image'
            elif self._mentions(question, ['视频', 'video', 'movie', 'clip']):
                state.media_type = 'video'
            elif self._mentions(question, ['pdf', '文档', 'document', '报告', 'report']):
                state.media_type = 'pdf'
            elif self._mentions(question, ['网页', '网站', 'webpage', 'website', 'url', 'http', 'https']):
                state.media_type = 'webpage'
            elif self._mentions(question, ['wikipedia', 'wiki', '维基', '百科']):
                state.media_type = 'wikipedia'
            else:
                state.media_type = 'text'

        # Use the first URL of the question unless a path was already given.
        if state.media_path is None and urls:
            state.media_path = urls[0]
        return state

    def _analyze_media(self, state: "AgentState") -> "AgentState":
        """Fill ``state.extracted_info`` according to the media type."""
        media_type = state.media_type
        path = state.media_path
        run = self.tool_manager.execute_tool

        if media_type == 'image' and path:
            state.extracted_info = self.media_analyzer.analyze_image(path)
        elif media_type == 'video' and path:
            state.extracted_info = self.media_analyzer.analyze_video(path)
        elif media_type == 'pdf' and path:
            # PDF analysis.
            pdf_info = run('analyze_pdf_structure', pdf_path=path)
            pdf_text = run('extract_text_from_pdf', pdf_path=path)
            state.extracted_info = {
                'type': 'pdf',
                'pdf_info': pdf_info,
                # Cap the text length.
                'text_content': pdf_text[:2000] if len(pdf_text) > 2000 else pdf_text,
            }
        elif media_type == 'webpage' and path:
            # Web page analysis.
            webpage_content = run('fetch_webpage_content', url=path)
            webpage_structure = run('analyze_webpage_structure', url=path)
            state.extracted_info = {
                'type': 'webpage',
                'webpage_content': webpage_content,
                'webpage_structure': webpage_structure,
            }
        elif media_type == 'youtube' and path:
            # YouTube analysis.
            youtube_info = run('get_youtube_info', url=path)
            youtube_thumbnail = run('download_youtube_thumbnail', url=path)
            state.extracted_info = {
                'type': 'youtube',
                'youtube_info': youtube_info,
                'thumbnail_path': youtube_thumbnail,
            }
        elif media_type == 'wikipedia':
            state.extracted_info = self._analyze_wikipedia(state.question)
        else:
            state.extracted_info = {'type': 'text', 'content': state.question}
        return state

    def _analyze_wikipedia(self, question: str) -> Dict[str, Any]:
        """Derive a search term from *question* and look it up on Wikipedia."""
        import re

        # "<wikipedia word> [filler] : <title>". The class accepts both the
        # ASCII and the full-width colon used in Chinese text.
        wiki_pattern = r'(?:wikipedia|wiki|维基|百科)\s*(?:关于|的|页面|词条)?\s*[::]\s*(.+)'
        match = re.search(wiki_pattern, question, re.IGNORECASE)
        if match:
            search_term = match.group(1).strip()
        else:
            # No explicit format: keep the words that are not filler.
            filler = ['wikipedia', 'wiki', '维基', '百科', '的', '是', '什么', '关于']
            search_term = ' '.join([w for w in question.split() if w not in filler])

        if not search_term:
            return {'type': 'wikipedia', 'error': '无法提取搜索词'}

        wiki_search = self.tool_manager.execute_tool(
            'search_wikipedia', query=search_term, max_results=3
        )
        if wiki_search and 'error' not in wiki_search[0]:
            # Fetch the details of the first hit.
            first_result = wiki_search[0]
            wiki_page = self.tool_manager.execute_tool(
                'get_wikipedia_page', title=first_result['title']
            )
            return {
                'type': 'wikipedia',
                'search_term': search_term,
                'search_results': wiki_search,
                'page_content': wiki_page,
            }
        return {
            'type': 'wikipedia',
            'search_term': search_term,
            'error': '未找到相关Wikipedia页面',
        }

    def _search_info(self, state: "AgentState") -> "AgentState":
        """Run a web search when the plan (or the tool manager) asks for it.

        Always advances ``state.current_step`` by one.
        """
        should_search = False
        # Does the current plan step require a search?
        if state.workflow_plan and state.current_step < len(state.workflow_plan):
            current_step_plan = state.workflow_plan[state.current_step]
            should_search = current_step_plan.get('needs_search', False)
        # Without a plan, fall back to the tool manager's heuristic.
        if not state.workflow_plan:
            should_search = self.tool_manager.should_use_search(
                state.question, {'extracted_info': state.extracted_info}
            )

        if should_search:
            print(f"🔍 执行搜索 (步骤 {state.current_step + 1})")
            # Build the query; append the image caption when there is one.
            search_query = state.question
            if state.extracted_info and 'caption' in state.extracted_info:
                search_query += f" {state.extracted_info['caption']}"
            state.search_results = self.search_engine.search(search_query)
            print(f"✅ 搜索完成,找到 {len(state.search_results)} 个结果")
        else:
            print(f"⏭️ 跳过搜索 (步骤 {state.current_step + 1})")
            # No search needed: leave the results empty.
            state.search_results = []

        # Advance to the next step.
        state.current_step += 1
        return state

    def _use_tools(self, state: "AgentState") -> "AgentState":
        """Run keyword-selected tools and store results in ``analysis_results``.

        NOTE(review): when the current plan step lists tools, none of the
        keyword-driven tools below are executed and the result stays empty;
        confirm whether plan-named tools were meant to be dispatched.
        """
        try:
            tool_results = {}

            # Tools named by the current plan step, if any.
            current_tools = []
            if state.workflow_plan and state.current_step < len(state.workflow_plan):
                current_step_plan = state.workflow_plan[state.current_step]
                current_tools = current_step_plan.get('tools', [])
                print(f"🛠️ 使用工具 (步骤 {state.current_step + 1}): {', '.join(map(str, current_tools)) if current_tools else '无'}")

            # No plan or empty tool list: keyword-driven selection.
            if not current_tools:
                question_lower = state.question.lower()
                self._run_code_tools(state, question_lower, tool_results)
                self._run_video_content_tools(state, question_lower, tool_results)
                self._run_pdf_tools(state, question_lower, tool_results)
                self._run_webpage_tools(state, question_lower, tool_results)
                self._run_youtube_tools(state, question_lower, tool_results)
                self._run_wikipedia_tools(state, question_lower, tool_results)
                self._run_text_tools(state, question_lower, tool_results)
                self._run_media_extras(state, question_lower, tool_results)
                self._run_math_tool(state, question_lower, tool_results)
                self._run_translation_tool(state, question_lower, tool_results)

            state.analysis_results = tool_results
        except Exception as e:
            state.error = f"工具使用失败: {str(e)}"
            state.analysis_results = {}
        return state

    # ------------------------------------------------------------------
    # Tool-selection helpers used by _use_tools (each fills *results*)
    # ------------------------------------------------------------------
    def _run_code_tools(self, state, question_lower, results):
        """Analyse, explain and optionally execute code found in the question."""
        if not self._mentions(question_lower, ['代码', 'code', 'python', '程序', 'program']):
            return
        question = state.question
        if not ('```python' in question or 'def ' in question or 'import ' in question):
            return
        code = self._extract_code(question)
        if not code.strip():
            return
        run = self.tool_manager.execute_tool
        results['code_analysis'] = run('analyze_python_code', code=code)
        results['code_explanation'] = run('explain_code', code=code)
        if self._mentions(question_lower, ['运行', '执行', 'execute', 'run']):
            # SECURITY NOTE: this hands user-supplied code to the
            # 'execute_python_code' tool; make sure that tool is sandboxed.
            results['code_execution'] = run('execute_python_code', code=code)

    def _run_video_content_tools(self, state, question_lower, results):
        """Run the video content analysis tool for video questions."""
        if state.media_type == 'video' and state.media_path:
            if self._mentions(question_lower, ['视频', 'video', '内容', 'content']):
                results['video_analysis'] = self.tool_manager.execute_tool(
                    'analyze_video_content', video_path=state.media_path
                )

    def _run_pdf_tools(self, state, question_lower, results):
        """Summarise, search and extract images from a PDF."""
        if not (state.media_type == 'pdf' and state.media_path):
            return
        run = self.tool_manager.execute_tool
        pdf_path = state.media_path
        if self._mentions(question_lower, ['pdf', '文档', 'document', '内容', 'content', '总结', 'summary']):
            results['pdf_summary'] = run('summarize_pdf_content', pdf_path=pdf_path)
        # Text search inside the PDF.
        if self._mentions(question_lower, ['搜索', '查找', 'search', 'find']):
            search_term = self._search_terms(
                question_lower, ['搜索', '查找', 'search', 'find', 'pdf', '文档']
            )
            if search_term:
                results['pdf_search'] = run(
                    'search_text_in_pdf', pdf_path=pdf_path, search_term=search_term
                )
        # Image extraction.
        if self._mentions(question_lower, ['图像', '图片', 'image', '图', '图表']):
            results['pdf_images'] = run('extract_images_from_pdf', pdf_path=pdf_path)

    def _run_webpage_tools(self, state, question_lower, results):
        """Summarise, search, list links of and check a web page."""
        if not (state.media_type == 'webpage' and state.media_path):
            return
        run = self.tool_manager.execute_tool
        url = state.media_path
        if self._mentions(question_lower, ['网页', '网站', 'webpage', 'website', '内容', 'content', '总结', 'summary']):
            results['webpage_summary'] = run('summarize_webpage_content', url=url)
        # Text search inside the page.
        if self._mentions(question_lower, ['搜索', '查找', 'search', 'find']):
            search_term = self._search_terms(
                question_lower, ['搜索', '查找', 'search', 'find', '网页', '网站']
            )
            if search_term:
                results['webpage_search'] = run(
                    'search_content_in_webpage', url=url, search_term=search_term
                )
        # Link extraction.
        if self._mentions(question_lower, ['链接', 'link', 'url', '地址']):
            results['webpage_links'] = run('extract_links_from_webpage', url=url)
        # Accessibility check.
        if self._mentions(question_lower, ['可访问性', 'accessibility', '无障碍', '检查']):
            results['webpage_accessibility'] = run('check_webpage_accessibility', url=url)

    def _run_youtube_tools(self, state, question_lower, results):
        """Download, extract audio from or analyse comments of a YouTube video.

        The basic video information is already gathered in _analyze_media.
        """
        if not (state.media_type == 'youtube' and state.media_path):
            return
        run = self.tool_manager.execute_tool
        url = state.media_path
        if self._mentions(question_lower, ['下载', 'download', '保存', 'save']):
            results['youtube_download'] = run('download_youtube_video', url=url)
        if self._mentions(question_lower, ['音频', 'audio', '声音', 'sound', '提取', 'extract']):
            results['youtube_audio'] = run('extract_youtube_audio', url=url)
        if self._mentions(question_lower, ['评论', 'comment', '反馈', 'feedback']):
            results['youtube_comments'] = run('analyze_youtube_comments', url=url)

    def _run_wikipedia_tools(self, state, question_lower, results):
        """Fetch categories, links, suggestions etc. for a Wikipedia topic.

        The Wikipedia search itself is already done in _analyze_media.
        """
        if state.media_type != 'wikipedia':
            return
        run = self.tool_manager.execute_tool
        info = state.extracted_info
        has_title = bool(
            info and 'page_content' in info and 'title' in info['page_content']
        )
        has_term = bool(info and 'search_term' in info)

        # Page categories.
        if self._mentions(question_lower, ['分类', 'category', '类别']) and has_title:
            results['wikipedia_categories'] = run(
                'get_wikipedia_categories', title=info['page_content']['title']
            )
        # Page links.
        if self._mentions(question_lower, ['链接', 'link', '相关', 'related']) and has_title:
            results['wikipedia_links'] = run(
                'get_wikipedia_links', title=info['page_content']['title']
            )
        # Search suggestions.
        if self._mentions(question_lower, ['建议', 'suggestion', '推荐', 'recommend']) and has_term:
            results['wikipedia_suggestions'] = run(
                'get_wikipedia_suggestions', query=info['search_term']
            )
        # English Wikipedia search.
        if self._mentions(question_lower, ['英文', 'english', '英文版']) and has_term:
            results['wikipedia_english_search'] = run(
                'search_wikipedia_english', query=info['search_term']
            )
        # Random page.
        if self._mentions(question_lower, ['随机', 'random', '随便', '任意']):
            results['wikipedia_random'] = run('get_wikipedia_random_page')

    def _run_text_tools(self, state, question_lower, results):
        """Sentiment, keyword extraction and summarisation."""
        run = self.tool_manager.execute_tool
        # Sentiment of the image caption, when one exists.
        if self._mentions(question_lower, ['情感', '情绪', 'sentiment', 'emotion']):
            if state.extracted_info and 'caption' in state.extracted_info:
                results['sentiment'] = run(
                    'analyze_text_sentiment', text=state.extracted_info['caption']
                )
        # Keyword extraction from the question.
        if self._mentions(question_lower, ['关键词', '关键', 'keywords', 'key']):
            results['keywords'] = run('extract_keywords', text=state.question)
        # Summary of the search results.
        if self._mentions(question_lower, ['摘要', '总结', 'summary', 'summarize']):
            if state.search_results:
                results['summary'] = run(
                    'summarize_text', text=" ".join(state.search_results)
                )

    def _run_media_extras(self, state, question_lower, results):
        """OCR for images and audio extraction for videos."""
        run = self.tool_manager.execute_tool
        if state.media_type == 'image' and state.media_path:
            if self._mentions(question_lower, ['文字', '文本', 'text', 'ocr']):
                results['ocr_text'] = run(
                    'extract_text_from_image', image_path=state.media_path
                )
        if state.media_type == 'video' and state.media_path:
            if self._mentions(question_lower, ['音频', '声音', 'audio', 'sound']):
                results['audio_info'] = run(
                    'extract_video_audio', video_path=state.media_path
                )

    def _run_math_tool(self, state, question_lower, results):
        """Evaluate the first arithmetic-looking expression in the question."""
        if not self._mentions(question_lower, ['计算', 'calculate', 'math', '数学']):
            return
        import re

        for match in re.findall(r'[\d\+\-\*\/\(\)\.\s]+', state.question):
            expression = match.strip()
            if len(expression) > 3:  # skip fragments of three characters or fewer
                try:
                    results['math_calculation'] = self.tool_manager.execute_tool(
                        'calculate_math_expression', expression=expression
                    )
                    break
                except Exception:
                    # This candidate could not be evaluated; try the next one.
                    continue

    def _run_translation_tool(self, state, question_lower, results):
        """Translate the text following the word '翻译' / 'translate'."""
        if not self._mentions(question_lower, ['翻译', 'translate']):
            return
        text_to_translate = state.question
        if '翻译' in text_to_translate:
            text_to_translate = text_to_translate.split('翻译')[-1].strip()
        elif 'translate' in text_to_translate:
            text_to_translate = text_to_translate.split('translate')[-1].strip()
        if text_to_translate and len(text_to_translate) > 2:
            results['translation'] = self.tool_manager.execute_tool(
                'translate_text', text=text_to_translate
            )

    # ------------------------------------------------------------------
    # Answer synthesis and entry point
    # ------------------------------------------------------------------
    def _synthesize_answer(self, state: "AgentState") -> "AgentState":
        """Combine all gathered information into the final answer via the LLM."""
        try:
            # Build the prompt from the collected material.
            prompt = get_answer_prompt(
                question=state.question,
                media_analysis=json.dumps(state.extracted_info, ensure_ascii=False, indent=2),
                search_results=json.dumps(state.search_results, ensure_ascii=False, indent=2),
                tool_analysis=json.dumps(state.analysis_results, ensure_ascii=False, indent=2)
            )
            response = self.llm.invoke([HumanMessage(content=prompt)])
            state.final_answer = response.content
        except Exception as e:
            state.error = f"答案生成失败: {str(e)}"
            state.final_answer = ERROR_ANSWER_TEMPLATE
        return state

    def __call__(self, question: str, media_url: Optional[str] = None) -> str:
        """Answer *question*, optionally using the media at *media_url*.

        Args:
            question: The user's question.
            media_url: Optional URL of a PDF, YouTube video, video file,
                image or web page.

        Returns:
            The final answer, or an error message string when the run
            fails (exceptions are not propagated).
        """
        try:
            state = AgentState(question=question)

            # Classify the supplied media and fetch it where needed.
            if media_url:
                url_lower = media_url.lower()
                if '.pdf' in url_lower:
                    media_type = 'pdf'
                    state.media_path = self.tool_manager.execute_tool(
                        'download_pdf_from_url', url=media_url
                    )
                elif 'youtube.com' in url_lower or 'youtu.be' in url_lower:
                    media_type = 'youtube'
                    state.media_path = media_url  # use the URL directly
                elif self._mentions(url_lower, ['.mp4', '.avi', '.mov']):
                    media_type = 'video'
                    state.media_path = self.media_analyzer.download_media(media_url, media_type)
                elif self._mentions(url_lower, self._IMAGE_EXTENSIONS):
                    # Checked before the generic web-page test; otherwise
                    # every http(s) image URL is treated as a web page.
                    media_type = 'image'
                    state.media_path = self.media_analyzer.download_media(media_url, media_type)
                elif self._mentions(url_lower, ['http://', 'https://', 'www.']):
                    media_type = 'webpage'
                    state.media_path = media_url  # use the URL directly
                else:
                    media_type = 'image'
                    state.media_path = self.media_analyzer.download_media(media_url, media_type)
                state.media_type = media_type

            # Run the workflow; LangGraph returns a dict, hence key access.
            final_state = self.workflow.invoke(state)
            return final_state['final_answer']
        except Exception as e:
            return f"智能体执行失败: {str(e)}"
def run_and_submit_all(profile: gr.OAuthProfile | None):
    """Run the agent on every evaluation question and submit the answers.

    Args:
        profile: The logged-in Hugging Face profile, or ``None``.

    Returns:
        A ``(status_message, results)`` tuple; ``results`` is a DataFrame
        of the per-question log, or ``None`` when the run stopped before
        any question was processed.
    """
    # A logged-in user is required.
    if not profile:
        print("User not logged in.")
        return "Please Login to Hugging Face with the button.", None
    username = f"{profile.username}"
    print(f"User logged in: {username}")

    space_id = os.getenv("SPACE_ID")
    api_url = DEFAULT_API_URL
    questions_url = f"{api_url}/questions"
    submit_url = f"{api_url}/submit"

    # Create the multi-modal agent.
    try:
        agent = MultiModalAgent()
    except Exception as exc:
        print(f"Error instantiating agent: {exc}")
        return f"Error initializing agent: {exc}", None

    agent_code = f"https://huggingface.co/spaces/{space_id}/tree/main"
    print(agent_code)

    # Fetch the questions.
    print(f"Fetching questions from: {questions_url}")
    try:
        reply = requests.get(questions_url, timeout=15)
        reply.raise_for_status()
        questions_data = reply.json()
        if not questions_data:
            print("Fetched questions list is empty.")
            return "Fetched questions list is empty or invalid format.", None
        print(f"Fetched {len(questions_data)} questions.")
    except Exception as exc:
        print(f"Error fetching questions: {exc}")
        return f"Error fetching questions: {exc}", None

    # Run the agent on each question.
    results_log = []
    answers_payload = []
    print(f"Running agent on {len(questions_data)} questions...")
    for item in questions_data:
        task_id = item.get("task_id")
        question_text = item.get("question")
        if not task_id or question_text is None:
            print(f"Skipping item with missing task_id or question: {item}")
            continue
        try:
            submitted_answer = agent(question_text)
        except Exception as exc:
            print(f"Error running agent on task {task_id}: {exc}")
            results_log.append({"Task ID": task_id, "Question": question_text, "Submitted Answer": f"AGENT ERROR: {exc}"})
        else:
            answers_payload.append({"task_id": task_id, "submitted_answer": submitted_answer})
            results_log.append({"Task ID": task_id, "Question": question_text, "Submitted Answer": submitted_answer})

    if not answers_payload:
        print("Agent did not produce any answers to submit.")
        return "Agent did not produce any answers to submit.", pd.DataFrame(results_log)

    # Prepare the submission.
    submission_data = {
        "username": username.strip(),
        "agent_code": agent_code,
        "answers": answers_payload,
    }
    status_update = f"Agent finished. Submitting {len(answers_payload)} answers for user '{username}'..."
    print(status_update)

    # Submit the answers.
    print(f"Submitting {len(answers_payload)} answers to: {submit_url}")
    try:
        reply = requests.post(submit_url, json=submission_data, timeout=60)
        reply.raise_for_status()
        result_data = reply.json()
        final_status = (
            f"Submission Successful!\n"
            f"User: {result_data.get('username')}\n"
            f"Overall Score: {result_data.get('score', 'N/A')}% "
            f"({result_data.get('correct_count', '?')}/{result_data.get('total_attempted', '?')} correct)\n"
            f"Message: {result_data.get('message', 'No message received.')}"
        )
        print("Submission successful.")
        return final_status, pd.DataFrame(results_log)
    except Exception as exc:
        status_message = f"Submission Failed: {exc}"
        print(status_message)
        return status_message, pd.DataFrame(results_log)
| def test_agent(question: str, media_url: str = ""): | |
| """测试智能体功能""" | |
| try: | |
| agent = MultiModalAgent() | |
| answer = agent(question, media_url if media_url else None) | |
| return answer | |
| except Exception as e: | |
| return f"测试失败: {str(e)}" | |
# Build the Gradio interface: a single-question test tab and a batch
# evaluation tab, behind a Hugging Face login button.
with gr.Blocks() as demo:
    gr.Markdown("# 多模态智能体系统")
    gr.Markdown(
        """
**功能特性:**
- 🎥 视频理解与分析
- 🖼️ 图像识别与描述
- 🔍 智能搜索引擎
- 🤖 LangGraph工作流编排
- 🧠 多模态信息融合
**使用说明:**
1. 登录你的Hugging Face账户
2. 在测试区域输入问题(可选媒体URL)
3. 点击"运行评估"进行批量测试
"""
    )
    gr.LoginButton()

    # Tab 1: ask the agent one question, optionally with a media URL.
    with gr.Tab("智能体测试"):
        with gr.Row():
            with gr.Column():
                test_question = gr.Textbox(label="问题", placeholder="请输入你的问题...")
                test_media_url = gr.Textbox(label="媒体URL(可选)", placeholder="图片或视频URL...")
                test_button = gr.Button("测试智能体")
            with gr.Column():
                test_output = gr.Textbox(label="智能体回答", lines=10)
        test_button.click(test_agent, [test_question, test_media_url], test_output)

    # Tab 2: run the full evaluation and submit every answer.
    with gr.Tab("批量评估"):
        run_button = gr.Button("运行评估 & 提交所有答案")
        status_output = gr.Textbox(label="运行状态 / 提交结果", lines=5, interactive=False)
        results_table = gr.DataFrame(label="问题和智能体答案", wrap=True)
        run_button.click(fn=run_and_submit_all, outputs=[status_output, results_table])
if __name__ == "__main__":
    # Startup banner, followed by the Space-related environment variables.
    print("\n" + "-"*30 + " 多模态智能体系统启动 " + "-"*30)
    space_host_startup = os.getenv("SPACE_HOST")
    space_id_startup = os.getenv("SPACE_ID")

    if not space_host_startup:
        print("ℹ️ SPACE_HOST environment variable not found (running locally?).")
    else:
        print(f"✅ SPACE_HOST found: {space_host_startup}")
        print(f" Runtime URL: https://{space_host_startup}.hf.space")

    if not space_id_startup:
        print("ℹ️ SPACE_ID environment variable not found (running locally?).")
    else:
        print(f"✅ SPACE_ID found: {space_id_startup}")
        print(f" Repo URL: https://huggingface.co/spaces/{space_id_startup}")

    # Closing rule, as wide as the banner line above.
    print("-"*(60 + len(" 多模态智能体系统启动 ")) + "\n")
    print("启动多模态智能体系统...")
    demo.launch(debug=True, share=False)