Spaces:
Runtime error
Runtime error
import io
import json
import logging
import os
import re
import sys
import tempfile
from datetime import datetime, timedelta, timezone
from pathlib import Path

import gradio as gr
import requests
import yaml

from tools import *  # gege's multimodal helpers (presumably audio_to_str / image_to_str)
# Parsed YAML configuration; filled lazily by load_hf_config().
CONFIG = None
HF_CONFIG_PATH = Path(__file__).parent / "todogen_LLM_config.yaml"


def load_hf_config():
    """Load the YAML configuration file once and cache it in ``CONFIG``.

    Returns:
        dict: The parsed configuration. An empty dict is cached and returned
        when the file is missing, cannot be read or parsed, is empty, or
        its top level is not a mapping, so callers can always call ``.get``
        on the result.
    """
    global CONFIG
    if CONFIG is not None:
        return CONFIG
    try:
        with open(HF_CONFIG_PATH, 'r', encoding='utf-8') as f:
            loaded = yaml.safe_load(f)
    except FileNotFoundError:
        print(f"❌ 错误: 配置文件 {HF_CONFIG_PATH} 未找到。请确保它在 hf 目录下。")
        CONFIG = {}
    except Exception as e:
        print(f"❌ 加载配置文件 {HF_CONFIG_PATH} 时出错: {e}")
        CONFIG = {}
    else:
        if isinstance(loaded, dict):
            CONFIG = loaded
            print(f"✅ 配置已加载: {HF_CONFIG_PATH}")
        else:
            # safe_load() yields None for an empty file and may yield a
            # list/scalar; neither supports the .get() lookups used later.
            print(f"❌ 配置文件 {HF_CONFIG_PATH} 为空或顶层不是映射,已使用空配置。")
            CONFIG = {}
    return CONFIG
def get_hf_openai_config():
    """Return the ``openai`` section of the configuration.

    Returns:
        dict: The section contents, or an empty dict when the section is
        missing or empty (a bare ``openai:`` key parses as ``None``).
    """
    config = load_hf_config() or {}
    return config.get('openai') or {}
def get_hf_openai_filter_config():
    """Return the ``openai_filter`` section of the configuration.

    Returns:
        dict: The section contents, or an empty dict when the section is
        missing or empty (a bare ``openai_filter:`` key parses as ``None``).
    """
    config = load_hf_config() or {}
    return config.get('openai_filter') or {}
def get_hf_xunfei_config():
    """Return the ``xunfei`` section of the configuration.

    Returns:
        dict: The section contents, or an empty dict when the section is
        missing or empty (a bare ``xunfei:`` key parses as ``None``).
    """
    config = load_hf_config() or {}
    return config.get('xunfei') or {}
def get_hf_paths_config():
    """Resolve the prompt/example file paths relative to this module.

    File names come from the ``paths`` section of the configuration; a
    missing or empty entry falls back to the built-in default name.

    Returns:
        dict: ``base_dir`` plus one ``Path`` each for ``prompt_template``,
        ``true_positive_examples`` and ``false_positive_examples``.
    """
    base = Path(__file__).resolve().parent
    config = load_hf_config() or {}
    # A bare ``paths:`` key parses as None, hence the ``or {}``.
    paths_cfg = config.get('paths') or {}
    default_names = {
        'prompt_template': 'prompt_template.txt',
        'true_positive_examples': 'TruePositive_few_shot.txt',
        'false_positive_examples': 'FalsePositive_few_shot.txt',
    }
    resolved = {'base_dir': base}
    for key, default_name in default_names.items():
        # ``or`` also covers keys that are present but empty (None / "").
        resolved[key] = base / (paths_cfg.get(key) or default_name)
    return resolved
# Connection settings for the main LLM endpoint (``openai`` config section).
llm_config = get_hf_openai_config()
NVIDIA_API_BASE_URL, NVIDIA_API_KEY, NVIDIA_MODEL_NAME = (
    llm_config.get(key) for key in ('base_url', 'api_key', 'model')
)

# Connection settings for the filter endpoint (``openai_filter`` config section).
filter_config = get_hf_openai_filter_config()
Filter_API_BASE_URL, Filter_API_KEY, Filter_MODEL_NAME = (
    filter_config.get(key)
    for key in ('base_url_filter', 'api_key_filter', 'model_filter')
)

# An incomplete group is blanked out as a whole so later checks can simply
# test the values for truthiness.
if not all((NVIDIA_API_BASE_URL, NVIDIA_API_KEY, NVIDIA_MODEL_NAME)):
    print("❌ 错误: NVIDIA API 配置不完整。请检查 todogen_LLM_config.yaml 中的 openai 部分。")
    NVIDIA_API_BASE_URL = NVIDIA_API_KEY = NVIDIA_MODEL_NAME = ""

if not all((Filter_API_BASE_URL, Filter_API_KEY, Filter_MODEL_NAME)):
    print("❌ 错误: Filter API 配置不完整。请检查 todogen_LLM_config.yaml 中的 openai_filter 部分。")
    Filter_API_BASE_URL = Filter_API_KEY = Filter_MODEL_NAME = ""

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
)
logger = logging.getLogger(__name__)
def load_single_few_shot_file_hf(file_path: Path) -> str:
    """Read one few-shot example file and double its curly braces.

    Args:
        file_path: Location of the UTF-8 encoded example file.

    Returns:
        The file content with every ``{`` / ``}`` doubled, or ``""`` when the
        file cannot be read. Loading is best-effort: failures are logged, not
        raised.
    """
    # NOTE(review): doubling braces is only needed when this text is spliced
    # into a template *before* str.format(); text passed as a format value is
    # inserted verbatim and keeps the doubled braces - confirm the intent.
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            content = f.read()
    except Exception as exc:
        logging.getLogger(__name__).warning(
            "Could not load few-shot file %s: %s", file_path, exc
        )
        return ""
    return content.replace('{', '{{').replace('}', '}}')
# Prompt assets, filled by load_prompt_data_hf() at import time.
PROMPT_TEMPLATE_CONTENT = ""
TRUE_POSITIVE_EXAMPLES_CONTENT = ""
FALSE_POSITIVE_EXAMPLES_CONTENT = ""


def load_prompt_data_hf():
    """Load the prompt template and both few-shot example files.

    The results are stored in the module-level ``*_CONTENT`` globals. When
    the template cannot be read, ``PROMPT_TEMPLATE_CONTENT`` is set to a
    message starting with ``"Error:"``, which is the marker
    ``generate_todolist_from_text`` checks for.
    """
    global PROMPT_TEMPLATE_CONTENT, TRUE_POSITIVE_EXAMPLES_CONTENT, FALSE_POSITIVE_EXAMPLES_CONTENT
    paths = get_hf_paths_config()
    try:
        with open(paths['prompt_template'], 'r', encoding='utf-8') as f:
            PROMPT_TEMPLATE_CONTENT = f.read()
    except FileNotFoundError:
        PROMPT_TEMPLATE_CONTENT = "Error: Prompt template not found."
    except (OSError, UnicodeDecodeError) as exc:
        # This function runs at import time, so an unreadable template must
        # not take the whole module down.
        PROMPT_TEMPLATE_CONTENT = f"Error: Prompt template could not be read: {exc}"
    TRUE_POSITIVE_EXAMPLES_CONTENT = load_single_few_shot_file_hf(paths['true_positive_examples'])
    FALSE_POSITIVE_EXAMPLES_CONTENT = load_single_few_shot_file_hf(paths['false_positive_examples'])


load_prompt_data_hf()
# Markdown code fence (optionally tagged ``json``) around a JSON payload.
_FENCED_JSON_RE = re.compile(r'```(?:json)?\n(.*?)```', re.DOTALL)
# A comma directly before a closing bracket/brace (invalid in strict JSON).
_TRAILING_COMMA_RE = re.compile(r',\s*([\]}])')
_JSON_START_RE = re.compile(r'[\[{]')


def _process_parsed_json(parsed_data):
    """Normalise decoded JSON into the shapes the callers expect.

    * list  -> list of dicts; non-dict items become ``{"content": str(item)}``
      and an empty list becomes ``[{}]``
    * dict  -> returned unchanged
    * other -> ``{"content": str(value)}``
    """
    try:
        if isinstance(parsed_data, dict):
            return parsed_data
        if isinstance(parsed_data, list):
            normalised = [
                item if isinstance(item, dict) else {"content": str(item)}
                for item in parsed_data
            ]
            return normalised or [{}]
        return {"content": str(parsed_data)}
    except Exception as e:
        return {"error": f"Error processing parsed JSON: {e}"}


def _find_embedded_json(text):
    """Return the first JSON object / array of objects embedded in *text*.

    Every ``{`` or ``[`` is tried as a starting point with
    ``JSONDecoder.raw_decode``, which consumes exactly one complete value.
    Unlike a non-greedy regex this handles nested objects. Returns ``None``
    when nothing decodes.
    """
    decoder = json.JSONDecoder()
    for start in _JSON_START_RE.finditer(text):
        try:
            value, _ = decoder.raw_decode(text, start.start())
        except json.JSONDecodeError:
            continue
        if isinstance(value, dict):
            return value
        # Only arrays of objects count; this skips bracketed prose like "[1]".
        if isinstance(value, list) and value and all(isinstance(v, dict) for v in value):
            return value
    return None


def json_parser(text: str) -> "dict | list":
    """Extract JSON data from an LLM reply.

    Strategies, in order: parse the whole text; parse the first fenced code
    block (tolerating trailing commas); decode the first JSON object or
    array of objects embedded anywhere in the text.

    Args:
        text: Raw reply text.

    Returns:
        A dict or a list of dicts (see ``_process_parsed_json``). On failure
        a dict with an ``"error"`` message and the original ``"raw_text"``.
    """
    try:
        try:
            return _process_parsed_json(json.loads(text))
        except json.JSONDecodeError:
            pass

        fenced = _FENCED_JSON_RE.search(text)
        if fenced:
            candidate = _TRAILING_COMMA_RE.sub(r'\1', fenced.group(1).strip())
            try:
                return _process_parsed_json(json.loads(candidate))
            except json.JSONDecodeError:
                pass

        embedded = _find_embedded_json(text)
        if embedded is not None:
            return _process_parsed_json(embedded)

        return {"error": "No valid JSON block found or failed to parse", "raw_text": text}
    except Exception as e:
        # E.g. a non-string ``text``; report instead of raising.
        return {"error": f"Unexpected error in json_parser: {e}", "raw_text": text}
_FILTER_SYSTEM_PROMPT = """
# 角色
你是一个专业的短信内容分析助手,根据输入判断内容的类型及可信度,为用户使用信息提供依据和便利。
# 任务
对于输入的多条数据,分析每一条数据内容(主键:`message_id`)属于【物流取件、缴费充值、待付(还)款、会议邀约、其他】的可能性百分比。
主要对于聊天、问候、回执、结果通知、上月账单等信息不需要收件人进行下一步处理的信息,直接归到其他类进行忽略
# 要求
1. 以json格式输出
2. content简洁提炼关键词,字符数<20以内
3. 输入条数和输出条数完全一样
# 输出示例
```
[
{"message_id":"1111111","content":"账单805.57元待还","物流取件":0,"欠费缴纳":99,"待付(还)款":1,"会议邀约":0,"其他":0, "分类":"欠费缴纳"},
{"message_id":"222222","content":"邀请你加入飞书视频会议","物流取件":0,"欠费缴纳":0,"待付(还)款":1,"会议邀约":100,"其他":0, "分类":"会议邀约"}
]
```
"""


def _default_filter_row(message_id, text_input, error=None):
    """Build the fallback classification row (100% "其他")."""
    row = {
        "message_id": message_id,
        "content": text_input[:20],
        "物流取件": 0,
        "欠费缴纳": 0,
        "待付(还)款": 0,
        "会议邀约": 0,
        "其他": 100,
        "分类": "其他",
    }
    if error is not None:
        row["error"] = error
    return row


def filter_message_with_llm(text_input: str, message_id: str = "user_input_001", timeout: float = 60):
    """Classify a message with the filter LLM.

    Args:
        text_input: Message text to classify.
        message_id: Identifier echoed back by the model for this message.
        timeout: Seconds to wait for the HTTP call; without one a stalled
            endpoint would block the request forever.

    Returns:
        list[dict]: One classification row per message (keys as in the system
        prompt example). Configuration, HTTP and parsing failures are
        reported as a single dict containing an ``"error"`` key; the function
        does not raise.
    """
    mock_data = [(text_input, message_id)]
    llm_messages = [
        {"role": "system", "content": _FILTER_SYSTEM_PROMPT},
        {"role": "user", "content": str(mock_data)},
    ]
    try:
        if not Filter_API_BASE_URL or not Filter_API_KEY or not Filter_MODEL_NAME:
            return [{"error": "Filter API configuration incomplete", "-": "-"}]
        headers = {
            "Authorization": f"Bearer {Filter_API_KEY}",
            "Accept": "application/json",
        }
        payload = {
            "model": Filter_MODEL_NAME,
            "messages": llm_messages,
            "temperature": 0.0,
            "top_p": 0.95,
            "max_tokens": 1024,
            "stream": False,
        }
        api_url = f"{Filter_API_BASE_URL}/chat/completions"
        try:
            response = requests.post(api_url, headers=headers, json=payload, timeout=timeout)
            response.raise_for_status()
            raw_llm_response = response.json()["choices"][0]["message"]["content"]
        except requests.exceptions.RequestException as e:
            return [{"error": f"Filter API call failed: {e}", "-": "-"}]

        # Drop markdown fences so the payload can be parsed directly.
        raw_llm_response = raw_llm_response.replace("```json", "").replace("```", "")
        parsed_filter_data = json_parser(raw_llm_response)
        if isinstance(parsed_filter_data, dict):
            if "error" in parsed_filter_data:
                return [{"error": f"Filter LLM response parsing error: {parsed_filter_data['error']}"}]
            if "分类" in parsed_filter_data:
                # The model answered with a single object instead of an array.
                parsed_filter_data = [parsed_filter_data]

        if not (isinstance(parsed_filter_data, list) and parsed_filter_data):
            return [_default_filter_row(
                message_id, text_input,
                error="Filter LLM returned empty or unexpected format",
            )]

        for item in parsed_filter_data:
            # "缴费支出" marks a payment that was already made, so there is
            # nothing left to do.
            if isinstance(item, dict) and item.get("分类") == "欠费缴纳" and "缴费支出" in item.get("content", ""):
                item["分类"] = "其他"

        # Make sure the requested message is present in the answer.
        returned_ids = {item.get('message_id') for item in parsed_filter_data if isinstance(item, dict)}
        if message_id not in returned_ids:
            parsed_filter_data.append(_default_filter_row(message_id, text_input))
        return parsed_filter_data
    except Exception as e:
        return [_default_filter_row(
            message_id, text_input,
            error=f"Filter LLM call/parse error: {str(e)}",
        )]
# Model urgency value -> label shown in the table; anything else is "一般".
_URGENCY_LABELS = {"urgent": "紧急", "important": "重要"}
_MEETING_TIME_RE = re.compile(
    r'(\d{1,2}[月/-]\d{1,2}[日号]?\s*\d{1,2}[点:]\d{0,2}|\d{4}[年/-]\d{1,2}[月/-]\d{1,2}[日号]?\s*\d{1,2}[点:]\d{0,2})'
)
_PICKUP_CODE_RE = re.compile(r'取件码[是为:]?\s*(\d{4,6})')
_JSON_ONLY_REMINDER = """
# 重要提示
请确保你的回复是有效的JSON格式,并且只包含JSON内容。不要添加任何额外的解释或文本。
你的回复应该严格按照上面的输出示例格式,只包含JSON对象,不要有任何其他文本。
"""


def _format_todo_text(todo_content, end_time=None, location=None):
    """Append the optional deadline (date part only) and location to the task text."""
    details = []
    if end_time and end_time != "null":
        if isinstance(end_time, str):
            end_time = end_time.split("T")[0]  # keep only the date of an ISO timestamp
        details.append(f"截止时间: {end_time}")
    if location and location != "null":
        details.append(f"地点: {location}")
    if not details:
        return f"{todo_content}"
    return f"{todo_content} ({', '.join(details)})"


def _rule_based_todo(text_input):
    """Return a one-row result for recognised message types, else ``None``.

    Handles phone-bill top-ups, meeting invitations and parcel pick-ups by
    keyword, without calling the LLM.
    """
    def mentions(*words):
        return any(word in text_input for word in words)

    now = datetime.now(timezone.utc)
    if mentions("充值", "缴费") and mentions("移动", "话费", "余额"):
        deadline = now + timedelta(days=3)
        return [[1, _format_todo_text("缴纳话费", deadline.isoformat(), "线上:中国移动APP"), "重要"]]
    if "会议" in text_input and mentions("邀请", "参加"):
        # NOTE(review): the matched time is not parsed; it only switches the
        # offset from +1 day to +1 day 2 hours, as in the original code.
        offset = timedelta(days=1, hours=2) if _MEETING_TIME_RE.search(text_input) else timedelta(days=1)
        return [[1, _format_todo_text("参加会议", (now + offset).isoformat(), "线上:会议软件"), "重要"]]
    if mentions("快递", "物流", "取件") and mentions("到达", "取件码", "柜"):
        code_match = _PICKUP_CODE_RE.search(text_input)
        todo_content = f"取快递(取件码:{code_match.group(1)})" if code_match else "取快递"
        deadline = now + timedelta(days=2)
        return [[1, _format_todo_text(todo_content, deadline.isoformat(), "线下:快递柜"), "重要"]]
    return None


def _rows_from_llm_result(parsed, message_id):
    """Convert the parsed LLM answer (dict or list) into table rows."""
    if isinstance(parsed, dict):
        todo_info = parsed.get(message_id, parsed.get(str(message_id)))
        if todo_info is None and "is_todo" in parsed:
            # The model returned the todo object without the message-id wrapper.
            todo_info = parsed
        if isinstance(todo_info, dict) and todo_info.get("is_todo", False):
            text = _format_todo_text(
                todo_info.get("todo_content", "未指定待办内容"),
                todo_info.get("end_time"),
                todo_info.get("location"),
            )
            urgency = todo_info.get("urgency", "unimportant")
            return [[1, text, _URGENCY_LABELS.get(urgency, "一般")]]
        return [[1, "此消息不包含待办事项", "-"]]

    rows = []
    if isinstance(parsed, list):
        for item in parsed:
            if isinstance(item, dict):
                if not item:
                    continue  # json_parser turns an empty array into [{}]
                text = _format_todo_text(
                    item.get('todo_content', item.get('content', 'N/A')),
                    item.get('end_time'),
                    item.get('location'),
                )
                label = _URGENCY_LABELS.get(item.get('urgency', 'normal'), "一般")
                rows.append([len(rows) + 1, text, label])
            else:
                rows.append([len(rows) + 1, str(item) if item is not None else "未知项目", "一般"])
    return rows or [["info", "未发现待办事项", "-"]]


def generate_todolist_from_text(text_input: str, message_id: str = "user_input_001", timeout: float = 60):
    """Turn a message into ToDo rows for the DataFrame.

    Recognised message types (top-up, meeting, parcel) are answered by
    keyword rules; everything else is sent to the LLM with the prompt
    template.

    Args:
        text_input: Message text.
        message_id: Key under which the model is asked to return its answer.
        timeout: Seconds to wait for the HTTP call.

    Returns:
        list[list]: Rows of ``[id, task text, urgency label]``. Failures are
        reported as a single ``["error", message, ...]`` row; the function
        does not raise.
    """
    if not PROMPT_TEMPLATE_CONTENT or "Error:" in PROMPT_TEMPLATE_CONTENT:
        return [["error", "Prompt template not loaded", "-"]]
    try:
        shortcut = _rule_based_todo(text_input)
        if shortcut is not None:
            return shortcut

        if not Filter_API_BASE_URL or not Filter_API_KEY or not Filter_MODEL_NAME:
            return [["error", "Filter API configuration incomplete", "-"]]

        # Formatting happens inside the try block so that a template with
        # unknown or malformed placeholders yields an error row.
        # str.format() inserts values verbatim, so the user text must NOT be
        # brace-escaped here; escaping would send doubled braces to the model.
        # NOTE(review): the few-shot globals are brace-escaped by their
        # loader and reach the prompt doubled - confirm that is intended.
        formatted_prompt = PROMPT_TEMPLATE_CONTENT.format(
            true_positive_examples=TRUE_POSITIVE_EXAMPLES_CONTENT,
            false_positive_examples=FALSE_POSITIVE_EXAMPLES_CONTENT,
            current_time=datetime.now(timezone.utc).isoformat(),
            message_id=message_id,
            content_escaped=text_input,
        )
        llm_messages = [{"role": "user", "content": formatted_prompt + _JSON_ONLY_REMINDER}]
        headers = {
            "Authorization": f"Bearer {Filter_API_KEY}",
            "Accept": "application/json",
        }
        payload = {
            "model": Filter_MODEL_NAME,
            "messages": llm_messages,
            "temperature": 0.2,
            "top_p": 0.95,
            "max_tokens": 1024,
            "stream": False,
        }
        api_url = f"{Filter_API_BASE_URL}/chat/completions"
        try:
            response = requests.post(api_url, headers=headers, json=payload, timeout=timeout)
            response.raise_for_status()
            raw_llm_response = response.json()['choices'][0]['message']['content']
        except requests.exceptions.RequestException as e:
            return [["error", f"Filter API call failed: {e}", "-"]]

        parsed_todos_data = json_parser(raw_llm_response)
        if isinstance(parsed_todos_data, dict) and "error" in parsed_todos_data:
            raw_preview = str(parsed_todos_data.get('raw_text', ''))[:50] + "..."
            return [["error", f"LLM response parsing error: {parsed_todos_data['error']}", raw_preview]]
        return _rows_from_llm_result(parsed_todos_data, message_id)
    except Exception as e:
        return [["error", f"LLM call/parse error: {str(e)}", "-"]]
# Multimodal (audio / image) inputs are handled from here on.
def process(audio, image):
    """Describe the received audio and image inputs.

    Args:
        audio: ``(sample_rate, samples)`` pair, or ``None``.
        image: Object exposing ``.shape``, or ``None``.

    Returns:
        tuple[str, str]: Audio description and image description.
    """
    audio_info = "未收到音频"
    if audio is not None:
        rate, samples = audio
        audio_info = f"音频采样率: {rate}Hz, 数据长度: {len(samples)}"
    image_info = "未收到图片" if image is None else f"图片尺寸: {image.shape}"
    return audio_info, image_info
def respond(message, history, system_message, max_tokens, temperature, top_p, audio, image, timeout=(10, 120)):
    """Stream a chat completion from the filter endpoint.

    Args:
        message: Latest user message.
        history: Earlier turns as ``(user_text, assistant_text)`` pairs;
            empty entries are skipped.
        system_message: System prompt.
        max_tokens, temperature, top_p: Sampling settings forwarded to the API.
        audio, image: Accepted for interface compatibility; not used here.
        timeout: ``(connect, read)`` timeout in seconds for the HTTP request.

    Yields:
        tuple[str, list]: The reply accumulated so far and an empty list.
        Errors are yielded as text instead of being raised.
    """
    chat_messages = [{"role": "system", "content": system_message}]
    for val in history:
        if val[0]:
            chat_messages.append({"role": "user", "content": val[0]})
        if val[1]:
            chat_messages.append({"role": "assistant", "content": val[1]})
    chat_messages.append({"role": "user", "content": message})

    chat_response_stream = ""
    if not Filter_API_BASE_URL or not Filter_API_KEY or not Filter_MODEL_NAME:
        yield "Filter API 配置不完整,无法提供聊天回复。", []
        return

    headers = {
        "Authorization": f"Bearer {Filter_API_KEY}",
        "Accept": "application/json",
    }
    payload = {
        "model": Filter_MODEL_NAME,
        "messages": chat_messages,
        "temperature": temperature,
        "top_p": top_p,
        "max_tokens": max_tokens,
        "stream": True,
    }
    api_url = f"{Filter_API_BASE_URL}/chat/completions"
    try:
        # ``with`` closes the streamed connection even if the consumer stops early.
        with requests.post(api_url, headers=headers, json=payload, stream=True, timeout=timeout) as response:
            response.raise_for_status()
            # iter_lines() reassembles lines split across network chunks, so
            # each event is decoded whole (raw chunks could cut a multi-byte
            # UTF-8 character or a JSON event in half).
            for raw_line in response.iter_lines():
                if not raw_line:
                    continue
                line = raw_line.decode('utf-8', errors='replace')
                if not line.startswith('data:'):
                    continue
                json_data = line[len('data:'):].strip()
                if json_data == '[DONE]':
                    break
                try:
                    data = json.loads(json_data)
                except json.JSONDecodeError:
                    continue
                choices = data.get('choices') or []
                if not choices:
                    continue  # e.g. a usage-only chunk
                token = (choices[0].get('delta') or {}).get('content') or ''
                if token:
                    chat_response_stream += token
                    yield chat_response_stream, []
    except requests.exceptions.RequestException as e:
        if chat_response_stream:
            yield chat_response_stream + f"\n\n错误: {e}", []
        else:
            yield f"调用 Filter API 失败: {e}", []
    except Exception as e:
        yield chat_response_stream + f"\n\n错误: {e}", []
# Entry point for multimodal (image / audio) uploads: chat + ToDo UI.
# NOTE(review): the layout nesting below is reconstructed; the original
# indentation was lost - confirm against the deployed UI.


def _transcribe_audio(audio_f, appid, apikey, apisecret):
    """Save the audio to a temporary WAV file and return the recognised text."""
    import soundfile as sf  # imported lazily, as in the original code

    audio_sample_rate, audio_data = audio_f
    logger.info(f"音频信息: 采样率 {audio_sample_rate}Hz, 数据长度 {len(audio_data)}")
    # Close the handle before writing: the file is re-opened by path.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_audio:
        temp_audio_path = temp_audio.name
    try:
        sf.write(temp_audio_path, audio_data, audio_sample_rate)
        logger.info(f"音频临时文件已保存: {temp_audio_path}")
        return audio_to_str(appid, apikey, apisecret, temp_audio_path)
    finally:
        # Always remove the temp file, also when recognition raises.
        os.unlink(temp_audio_path)


def _recognise_image(image_f, appid, apikey, apisecret):
    """Save the image to a temporary JPEG file and return the recognised text."""
    from PIL import Image  # imported lazily, as in the original code

    logger.info(f"图像信息: 形状 {image_f.shape}, 数据类型 {image_f.dtype}")
    # Let Pillow infer the mode from the array shape, then convert anything
    # JPEG cannot store (e.g. RGBA) to RGB.
    pil_image = Image.fromarray(image_f.astype('uint8'))
    if pil_image.mode not in ('RGB', 'L'):
        pil_image = pil_image.convert('RGB')
    with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as temp_image:
        temp_image_path = temp_image.name
    try:
        pil_image.save(temp_image_path, 'JPEG')
        logger.info(f"图像临时文件已保存: {temp_image_path}")
        return image_to_str(appid, apikey, apisecret, temp_image_path)
    finally:
        os.unlink(temp_image_path)


def _collect_multimodal_text(audio_f, image_f):
    """Run recognition on the uploaded audio/image and merge the resulting text.

    Each input is processed independently; a failure is logged and skipped.
    Returns ``""`` when nothing was recognised.
    """
    xunfei_config = get_hf_xunfei_config()
    xunfei_appid = xunfei_config.get('appid')
    xunfei_apikey = xunfei_config.get('apikey')
    xunfei_apisecret = xunfei_config.get('apisecret')
    credentials = (xunfei_appid, xunfei_apikey, xunfei_apisecret)
    configured = all(credentials)

    logger.info(f"开始多模态处理 - 音频: {audio_f is not None}, 图像: {image_f is not None}")
    logger.info(f"讯飞配置状态 - appid: {bool(xunfei_appid)}, apikey: {bool(xunfei_apikey)}, apisecret: {bool(xunfei_apisecret)}")

    parts = []
    if audio_f is not None:
        if configured:
            logger.info("开始处理音频输入...")
            try:
                audio_text = _transcribe_audio(audio_f, *credentials)
                logger.info(f"音频识别结果: {audio_text}")
                if audio_text:
                    parts.append(f"音频内容: {audio_text}")
                logger.info("音频处理完成")
            except Exception as e:
                logger.error(f"音频处理错误: {str(e)}")
        else:
            logger.warning("音频文件存在但讯飞配置不完整,跳过音频处理")

    if image_f is not None:
        if configured:
            logger.info("开始处理图像输入...")
            try:
                image_text = _recognise_image(image_f, *credentials)
                logger.info(f"图像识别结果: {image_text}")
                if image_text:
                    parts.append(f"图像内容: {image_text}")
                logger.info("图像处理完成")
            except Exception as e:
                logger.error(f"图像处理错误: {str(e)}")
        else:
            logger.warning("图像文件存在但讯飞配置不完整,跳过图像处理")

    return "\n".join(parts)


def _todos_for_message(text_for_todo):
    """Classify the message and, unless it is filtered out, build its ToDo rows."""
    filtered_result = filter_message_with_llm(text_for_todo)
    category = None
    if isinstance(filtered_result, dict):
        if "error" in filtered_result:
            return [["Error", filtered_result['error'], "Filter Failed"]]
        category = filtered_result.get("分类")
    elif isinstance(filtered_result, list):
        # Only the first dict row is consulted, as there is one input message.
        first_row = next((item for item in filtered_result if isinstance(item, dict)), None)
        if first_row is not None:
            category = first_row.get("分类")

    if category == "其他":
        return [["Info", "消息被归类为 '其他',无需生成 ToDo。", "Filtered"]]
    if not text_for_todo:
        return []
    msg_id_todo = f"hf_app_todo_{datetime.now().strftime('%Y%m%d%H%M%S%f')}"
    return generate_todolist_from_text(text_for_todo, msg_id_todo)


def handle_submit(user_msg_content, ch_history, sys_msg, max_t, temp, t_p, audio_f, image_f):
    """Handle one submission: stream the chat reply, then produce the ToDo rows.

    Yields:
        tuple[list, list]: The chat history (``role``/``content`` dicts) and
        the ToDo rows; the rows stay empty until the reply has finished.
    """
    multimodal_text_content = _collect_multimodal_text(audio_f, image_f)

    # Typed text comes first; recognised media text is used alone or appended.
    typed_text = user_msg_content.strip() if user_msg_content else ""
    final_user_content = "\n".join(part for part in (typed_text, multimodal_text_content) if part)
    if not typed_text and multimodal_text_content:
        logger.info(f"用户无文本输入,使用多模态内容作为用户输入: {final_user_content}")
    elif typed_text and multimodal_text_content:
        logger.info("用户有文本输入,多模态内容作为补充")
    if not final_user_content:
        final_user_content = "[无输入内容]"
        logger.warning("用户没有提供任何输入内容(文本、音频或图像)")
    logger.info(f"最终用户输入内容: {final_user_content}")

    # 1. Show the user's message immediately.
    ch_history = list(ch_history) if ch_history else []
    ch_history.append({"role": "user", "content": final_user_content})
    yield ch_history, []

    # 2. Convert the earlier turns into the (user, assistant) pairs that
    #    respond() expects, then stream the reply into a new assistant entry.
    formatted_hist_for_respond = []
    pending_user_msg = None
    for item_hist in ch_history[:-1]:
        if item_hist["role"] == "user":
            pending_user_msg = item_hist["content"]
        elif item_hist["role"] == "assistant":
            user_part = pending_user_msg if pending_user_msg is not None else ""
            formatted_hist_for_respond.append((user_part, item_hist["content"]))
            pending_user_msg = None

    ch_history.append({"role": "assistant", "content": ""})
    for bot_response_so_far, _ in respond(final_user_content, formatted_hist_for_respond, sys_msg, max_t, temp, t_p, audio_f, image_f):
        ch_history[-1]["content"] = bot_response_so_far
        yield ch_history, []

    # 3. Generate the ToDo list from the same final user content.
    logger.info(f"用于ToDo生成的内容: {final_user_content}")
    yield ch_history, _todos_for_message(final_user_content)


def clear_all():
    """Reset the chat history, the ToDo table and the message box."""
    return None, None, ""


with gr.Blocks() as app:
    gr.Markdown("# ToDoAgent Multi-Modal Interface with ToDo List")
    with gr.Row():
        with gr.Column(scale=2):
            gr.Markdown("## Chat Interface")
            chatbot = gr.Chatbot(height=450, label="聊天记录", type="messages")
            msg = gr.Textbox(label="输入消息", placeholder="输入您的问题或待办事项...")
            with gr.Row():
                audio_input = gr.Audio(label="上传语音", type="numpy", sources=["upload", "microphone"])
                image_input = gr.Image(label="上传图片", type="numpy")
            with gr.Accordion("高级设置", open=False):
                system_msg = gr.Textbox(value="You are a friendly Chatbot.", label="系统提示")
                max_tokens_slider = gr.Slider(minimum=1, maximum=4096, value=1024, step=1, label="最大生成长度(聊天)")
                temperature_slider = gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="温度(聊天)")
                top_p_slider = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p(聊天)")
            with gr.Row():
                submit_btn = gr.Button("发送", variant="primary")
                clear_btn = gr.Button("清除聊天和ToDo")
        with gr.Column(scale=1):
            gr.Markdown("## Generated ToDo List")
            todolist_df = gr.DataFrame(
                headers=["ID", "任务内容", "状态"],
                datatype=["number", "str", "str"],
                row_count=(0, "dynamic"),
                col_count=(3, "fixed"),
                label="待办事项列表",
            )

    # The button and the Enter key trigger the same handler.
    submit_inputs = [msg, chatbot, system_msg, max_tokens_slider, temperature_slider, top_p_slider, audio_input, image_input]
    submit_btn.click(handle_submit, submit_inputs, [chatbot, todolist_df])
    msg.submit(handle_submit, submit_inputs, [chatbot, todolist_df])
    clear_btn.click(clear_all, None, [chatbot, todolist_df, msg], queue=False)

    # Original multimodal tab: reports basic audio/image information.
    with gr.Tab("Audio/Image Processing (Original)"):
        gr.Markdown("## 处理音频和图片")
        audio_processor = gr.Audio(label="上传音频", type="numpy")
        image_processor = gr.Image(label="上传图片", type="numpy")
        process_btn = gr.Button("处理", variant="primary")
        audio_output = gr.Textbox(label="音频信息")
        image_output = gr.Textbox(label="图片信息")
        process_btn.click(
            process,
            inputs=[audio_processor, image_processor],
            outputs=[audio_output, image_output],
        )

if __name__ == "__main__":
    app.launch(debug=False)