|
|
import gradio as gr |
|
|
import json |
|
|
from pathlib import Path |
|
|
import yaml |
|
|
import re |
|
|
import logging |
|
|
import io |
|
|
import sys |
|
|
import re |
|
|
from datetime import datetime, timezone, timedelta |
|
|
import requests |
|
|
from tools import * |
|
|
|
|
|
|
|
|
# Module-level cache for the parsed YAML configuration.
# None = not yet loaded; {} = load failed; populated lazily by load_hf_config().
CONFIG = None

# Path to the YAML config file, resolved relative to this source file's directory.
HF_CONFIG_PATH = Path(__file__).parent / "todogen_LLM_config.yaml"
|
|
|
|
|
def load_hf_config():
    """Load the YAML configuration file, caching the result module-wide.

    Subsequent calls return the cached dict.  On any load failure the
    cache is set to an empty dict so callers always receive a dict and
    the failed load is not retried.
    """
    global CONFIG
    if CONFIG is not None:
        return CONFIG
    try:
        with open(HF_CONFIG_PATH, 'r', encoding='utf-8') as f:
            CONFIG = yaml.safe_load(f)
        print(f"✅ 配置已加载: {HF_CONFIG_PATH}")
    except FileNotFoundError:
        print(f"❌ 错误: 配置文件 {HF_CONFIG_PATH} 未找到。请确保它在 hf 目录下。")
        CONFIG = {}
    except Exception as e:
        print(f"❌ 加载配置文件 {HF_CONFIG_PATH} 时出错: {e}")
        CONFIG = {}
    return CONFIG
|
|
|
|
|
def get_hf_openai_config():
    """Return the 'openai' section of the config (empty dict if absent)."""
    return load_hf_config().get('openai', {})
|
|
|
|
|
def get_hf_openai_filter_config():
    """Return the 'openai_filter' section of the config (empty dict if absent)."""
    return load_hf_config().get('openai_filter', {})
|
|
|
|
|
def get_hf_xunfei_config():
    """Return the 'xunfei' (iFlytek) section of the config (empty dict if absent)."""
    return load_hf_config().get('xunfei', {})
|
|
|
|
|
def get_hf_paths_config():
    """Resolve prompt/example file paths from the config.

    Paths are resolved relative to this source file's directory; missing
    keys in the 'paths' section fall back to default filenames.
    """
    paths_cfg = load_hf_config().get('paths', {})
    base = Path(__file__).resolve().parent
    defaults = (
        ('prompt_template', 'prompt_template.txt'),
        ('true_positive_examples', 'TruePositive_few_shot.txt'),
        ('false_positive_examples', 'FalsePositive_few_shot.txt'),
    )
    resolved = {'base_dir': base}
    for key, fallback in defaults:
        resolved[key] = base / paths_cfg.get(key, fallback)
    return resolved
|
|
|
|
|
# Resolve the NVIDIA (todo-generation) LLM credentials once at import time.
llm_config = get_hf_openai_config()
NVIDIA_API_BASE_URL = llm_config.get('base_url')
NVIDIA_API_KEY = llm_config.get('api_key')
NVIDIA_MODEL_NAME = llm_config.get('model')

# Resolve the filter (classification/chat) LLM credentials.
filter_config = get_hf_openai_filter_config()
Filter_API_BASE_URL = filter_config.get('base_url_filter')
Filter_API_KEY = filter_config.get('api_key_filter')
Filter_MODEL_NAME = filter_config.get('model_filter')

# Degrade missing settings to empty strings (instead of None) so later
# f-string URL construction never raises; downstream code checks truthiness.
if not NVIDIA_API_BASE_URL or not NVIDIA_API_KEY or not NVIDIA_MODEL_NAME:
    print("❌ 错误: NVIDIA API 配置不完整。请检查 todogen_LLM_config.yaml 中的 openai 部分。")
    NVIDIA_API_BASE_URL = ""
    NVIDIA_API_KEY = ""
    NVIDIA_MODEL_NAME = ""

if not Filter_API_BASE_URL or not Filter_API_KEY or not Filter_MODEL_NAME:
    print("❌ 错误: Filter API 配置不完整。请检查 todogen_LLM_config.yaml 中的 openai_filter 部分。")
    Filter_API_BASE_URL = ""
    Filter_API_KEY = ""
    Filter_MODEL_NAME = ""

# Module-wide logging setup; all handlers in this file use this logger.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
|
|
|
|
|
def load_single_few_shot_file_hf(file_path: Path) -> str:
    """Read one few-shot example file and escape literal braces.

    Braces are doubled so the content can later be embedded in a
    str.format() template without being treated as placeholders.
    Any read failure (missing file, decode error, ...) yields "".
    """
    try:
        raw = file_path.read_text(encoding='utf-8')
    except Exception:
        return ""
    return raw.replace('{', '{{').replace('}', '}}')
|
|
|
|
|
# Prompt assets; populated by load_prompt_data_hf() at import time.
PROMPT_TEMPLATE_CONTENT = ""
TRUE_POSITIVE_EXAMPLES_CONTENT = ""
FALSE_POSITIVE_EXAMPLES_CONTENT = ""
|
|
|
|
|
def load_prompt_data_hf():
    """Populate the module-level prompt template and few-shot example strings.

    A missing template file is recorded as an error sentinel string, which
    generate_todolist_from_text() checks for before use.
    """
    global PROMPT_TEMPLATE_CONTENT, TRUE_POSITIVE_EXAMPLES_CONTENT, FALSE_POSITIVE_EXAMPLES_CONTENT
    paths = get_hf_paths_config()
    try:
        PROMPT_TEMPLATE_CONTENT = paths['prompt_template'].read_text(encoding='utf-8')
    except FileNotFoundError:
        PROMPT_TEMPLATE_CONTENT = "Error: Prompt template not found."

    TRUE_POSITIVE_EXAMPLES_CONTENT = load_single_few_shot_file_hf(paths['true_positive_examples'])
    FALSE_POSITIVE_EXAMPLES_CONTENT = load_single_few_shot_file_hf(paths['false_positive_examples'])
|
|
|
|
|
load_prompt_data_hf() |
|
|
|
|
|
def _process_parsed_json(parsed_data): |
|
|
"""处理解析后的JSON数据,确保格式正确""" |
|
|
try: |
|
|
if isinstance(parsed_data, list): |
|
|
if not parsed_data: |
|
|
return [{}] |
|
|
|
|
|
processed_list = [] |
|
|
for item in parsed_data: |
|
|
if isinstance(item, dict): |
|
|
processed_list.append(item) |
|
|
else: |
|
|
try: |
|
|
processed_list.append({"content": str(item)}) |
|
|
except: |
|
|
processed_list.append({"content": "无法转换的项目"}) |
|
|
|
|
|
if not processed_list: |
|
|
return [{}] |
|
|
|
|
|
return processed_list |
|
|
|
|
|
elif isinstance(parsed_data, dict): |
|
|
return parsed_data |
|
|
|
|
|
else: |
|
|
return {"content": str(parsed_data)} |
|
|
|
|
|
except Exception as e: |
|
|
return {"error": f"Error processing parsed JSON: {e}"} |
|
|
|
|
|
def json_parser(text: str) -> dict:
    """Extract and parse the first JSON payload found in *text*.

    Candidates are tried in order: the whole text, a fenced ```json```
    block, a bare JSON array of objects, then a bare JSON object.  Each
    candidate is first parsed as-is; if that fails, trailing commas
    before ']' or '}' (a common LLM formatting mistake) are stripped and
    parsing is retried.  (The original only applied the comma cleanup to
    the fenced-block branch.)

    Returns the normalized data (via _process_parsed_json), or an
    {"error": ..., "raw_text": ...} dict when nothing parses.
    """
    def _attempt(candidate: str):
        # Try the candidate verbatim first so valid JSON is never altered.
        try:
            return _process_parsed_json(json.loads(candidate))
        except json.JSONDecodeError:
            pass
        # Retry with trailing commas removed.
        cleaned = re.sub(r',\s*]', ']', candidate)
        cleaned = re.sub(r',\s*}', '}', cleaned)
        try:
            return _process_parsed_json(json.loads(cleaned))
        except json.JSONDecodeError:
            return None

    try:
        result = _attempt(text)
        if result is not None:
            return result

        # Fenced code block, optionally tagged "json".
        match = re.search(r'```(?:json)?\n(.*?)```', text, re.DOTALL)
        if match:
            result = _attempt(match.group(1).strip())
            if result is not None:
                return result

        # Bare JSON array of objects embedded in surrounding prose.
        array_match = re.search(r'\[\s*\{.*?\}\s*(?:,\s*\{.*?\}\s*)*\]', text, re.DOTALL)
        if array_match:
            result = _attempt(array_match.group(0).strip())
            if result is not None:
                return result

        # Bare JSON object.  NOTE(review): the non-greedy match stops at the
        # first '}' and can truncate nested objects — kept for compatibility.
        object_match = re.search(r'\{.*?\}', text, re.DOTALL)
        if object_match:
            result = _attempt(object_match.group(0).strip())
            if result is not None:
                return result

        return {"error": "No valid JSON block found or failed to parse", "raw_text": text}

    except Exception as e:
        return {"error": f"Unexpected error in json_parser: {e}", "raw_text": text}
|
|
|
|
|
def filter_message_with_llm(text_input: str, message_id: str = "user_input_001"):
    """Classify one message with the filter LLM.

    The model scores the message against five categories (delivery pickup,
    bill payment, pending payment, meeting invite, other) and returns a
    list of classification dicts — one per input message.  On any failure
    a list containing an error/fallback dict is returned, never an
    exception, so callers can always iterate the result.
    """
    # The API prompt expects a batch; wrap the single message as one (text, id) pair.
    mock_data = [(text_input, message_id)]

    # System prompt is a runtime string sent to the model — kept verbatim.
    system_prompt = """
# 角色
你是一个专业的短信内容分析助手,根据输入判断内容的类型及可信度,为用户使用信息提供依据和便利。

# 任务
对于输入的多条数据,分析每一条数据内容(主键:`message_id`)属于【物流取件、缴费充值、待付(还)款、会议邀约、其他】的可能性百分比。
主要对于聊天、问候、回执、结果通知、上月账单等信息不需要收件人进行下一步处理的信息,直接归到其他类进行忽略

# 要求
1. 以json格式输出
2. content简洁提炼关键词,字符数<20以内
3. 输入条数和输出条数完全一样

# 输出示例
```
[
{"message_id":"1111111","content":"账单805.57元待还","物流取件":0,"欠费缴纳":99,"待付(还)款":1,"会议邀约":0,"其他":0, "分类":"欠费缴纳"},
{"message_id":"222222","content":"邀请你加入飞书视频会议","物流取件":0,"欠费缴纳":0,"待付(还)款":1,"会议邀约":100,"其他":0, "分类":"会议邀约"}
]
```
"""

    llm_messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": str(mock_data)}
    ]

    try:
        if not Filter_API_BASE_URL or not Filter_API_KEY or not Filter_MODEL_NAME:
            return [{"error": "Filter API configuration incomplete", "-": "-"}]

        headers = {
            "Authorization": f"Bearer {Filter_API_KEY}",
            "Accept": "application/json"
        }
        payload = {
            "model": Filter_MODEL_NAME,
            "messages": llm_messages,
            "temperature": 0.0,  # deterministic classification
            "top_p": 0.95,
            "max_tokens": 1024,
            "stream": False
        }

        api_url = f"{Filter_API_BASE_URL}/chat/completions"

        # NOTE(review): no timeout= on this request — a hung server blocks
        # the handler indefinitely; consider adding one.
        try:
            response = requests.post(api_url, headers=headers, json=payload)
            response.raise_for_status()
            raw_llm_response = response.json()["choices"][0]["message"]["content"]
        except requests.exceptions.RequestException as e:
            return [{"error": f"Filter API call failed: {e}", "-": "-"}]

        # Strip code fences before parsing the model output.
        raw_llm_response = raw_llm_response.replace("```json", "").replace("```", "")
        parsed_filter_data = json_parser(raw_llm_response)

        # json_parser signals failure with a dict containing "error";
        # `in` also works if a list was returned (membership test).
        if "error" in parsed_filter_data:
            return [{"error": f"Filter LLM response parsing error: {parsed_filter_data['error']}"}]

        if isinstance(parsed_filter_data, list) and parsed_filter_data:
            # Post-hoc correction: "缴费支出" (payment receipts) misclassified
            # as bill payment are demoted to "其他".
            for item in parsed_filter_data:
                if isinstance(item, dict) and item.get("分类") == "欠费缴纳" and "缴费支出" in item.get("content", ""):
                    item["分类"] = "其他"

            # Ensure every requested message_id appears in the response;
            # synthesize a default "其他" entry for any the model dropped.
            request_id_list = {message_id}
            response_id_list = {item.get('message_id') for item in parsed_filter_data if isinstance(item, dict)}
            diff = request_id_list - response_id_list

            if diff:
                for missed_id in diff:
                    parsed_filter_data.append({
                        "message_id": missed_id,
                        "content": text_input[:20],
                        "物流取件": 0,
                        "欠费缴纳": 0,
                        "待付(还)款": 0,
                        "会议邀约": 0,
                        "其他": 100,
                        "分类": "其他"
                    })

            return parsed_filter_data
        else:
            # Empty or non-list response: fall back to a single "其他" entry.
            return [{
                "message_id": message_id,
                "content": text_input[:20],
                "物流取件": 0,
                "欠费缴纳": 0,
                "待付(还)款": 0,
                "会议邀约": 0,
                "其他": 100,
                "分类": "其他",
                "error": "Filter LLM returned empty or unexpected format"
            }]

    except Exception as e:
        # Catch-all so the UI handler never sees an exception from here.
        return [{
            "message_id": message_id,
            "content": text_input[:20],
            "物流取件": 0,
            "欠费缴纳": 0,
            "待付(还)款": 0,
            "会议邀约": 0,
            "其他": 100,
            "分类": "其他",
            "error": f"Filter LLM call/parse error: {str(e)}"
        }]
|
|
|
|
|
def generate_todolist_from_text(text_input: str, message_id: str = "user_input_001"):
    """Generate todo rows for the DataFrame widget from a message text.

    Three keyword-rule fast paths (phone top-up, meeting invite, package
    pickup) short-circuit without any network call; otherwise the prompt
    template is filled and sent to the filter LLM, and its JSON response
    is converted to rows.

    Returns a list of [index, content, urgency] rows; errors are reported
    as rows (never raised) so the UI can render them.
    """
    # Bail out early if the prompt template failed to load at import time.
    if not PROMPT_TEMPLATE_CONTENT or "Error:" in PROMPT_TEMPLATE_CONTENT:
        return [["error", "Prompt template not loaded", "-"]]

    current_time_iso = datetime.now(timezone.utc).isoformat()
    # Double braces so user text survives str.format() below.
    content_escaped = text_input.replace('{', '{{').replace('}', '}}')

    formatted_prompt = PROMPT_TEMPLATE_CONTENT.format(
        true_positive_examples=TRUE_POSITIVE_EXAMPLES_CONTENT,
        false_positive_examples=FALSE_POSITIVE_EXAMPLES_CONTENT,
        current_time=current_time_iso,
        message_id=message_id,
        content_escaped=content_escaped
    )

    # Extra instruction appended to force JSON-only output.
    enhanced_prompt = formatted_prompt + """

# 重要提示
请确保你的回复是有效的JSON格式,并且只包含JSON内容。不要添加任何额外的解释或文本。
你的回复应该严格按照上面的输出示例格式,只包含JSON对象,不要有任何其他文本。
"""

    llm_messages = [
        {"role": "user", "content": enhanced_prompt}
    ]

    try:
        # --- Fast path 1: phone-bill top-up keywords -> fixed 3-day todo ---
        if ("充值" in text_input or "缴费" in text_input) and ("移动" in text_input or "话费" in text_input or "余额" in text_input):
            todo_item = {
                message_id: {
                    "is_todo": True,
                    "end_time": (datetime.now(timezone.utc) + timedelta(days=3)).isoformat(),
                    "location": "线上:中国移动APP",
                    "todo_content": "缴纳话费",
                    "urgency": "important"
                }
            }

            todo_content = "缴纳话费"
            # Keep only the date portion of the ISO timestamp.
            end_time = todo_item[message_id]["end_time"].split("T")[0]
            location = todo_item[message_id]["location"]

            combined_content = f"{todo_content} (截止时间: {end_time}, 地点: {location})"

            output_for_df = []
            output_for_df.append([1, combined_content, "重要"])

            return output_for_df

        # --- Fast path 2: meeting invite keywords -> ~1-day todo ---
        elif "会议" in text_input and ("邀请" in text_input or "参加" in text_input):
            meeting_time = None
            # Matches "M月D日 HH:MM" / "YYYY-M-D HH点MM" style timestamps.
            meeting_pattern = r'(\d{1,2}[月/-]\d{1,2}[日号]?\s*\d{1,2}[点:]\d{0,2}|\d{4}[年/-]\d{1,2}[月/-]\d{1,2}[日号]?\s*\d{1,2}[点:]\d{0,2})'
            meeting_match = re.search(meeting_pattern, text_input)

            # NOTE(review): the matched time text is not actually parsed —
            # both branches just pick an offset from "now".
            if meeting_match:
                meeting_time = (datetime.now(timezone.utc) + timedelta(days=1, hours=2)).isoformat()
            else:
                meeting_time = (datetime.now(timezone.utc) + timedelta(days=1)).isoformat()

            todo_item = {
                message_id: {
                    "is_todo": True,
                    "end_time": meeting_time,
                    "location": "线上:会议软件",
                    "todo_content": "参加会议",
                    "urgency": "important"
                }
            }

            todo_content = "参加会议"
            end_time = todo_item[message_id]["end_time"].split("T")[0]
            location = todo_item[message_id]["location"]

            combined_content = f"{todo_content} (截止时间: {end_time}, 地点: {location})"

            output_for_df = []
            output_for_df.append([1, combined_content, "重要"])

            return output_for_df

        # --- Fast path 3: package pickup keywords -> 2-day todo with code ---
        elif ("快递" in text_input or "物流" in text_input or "取件" in text_input) and ("到达" in text_input or "取件码" in text_input or "柜" in text_input):
            pickup_code = None
            code_pattern = r'取件码[是为:]?\s*(\d{4,6})'
            code_match = re.search(code_pattern, text_input)

            todo_content = "取快递"
            if code_match:
                pickup_code = code_match.group(1)
                todo_content = f"取快递(取件码:{pickup_code})"

            todo_item = {
                message_id: {
                    "is_todo": True,
                    "end_time": (datetime.now(timezone.utc) + timedelta(days=2)).isoformat(),
                    "location": "线下:快递柜",
                    "todo_content": todo_content,
                    "urgency": "important"
                }
            }

            end_time = todo_item[message_id]["end_time"].split("T")[0]
            location = todo_item[message_id]["location"]

            combined_content = f"{todo_content} (截止时间: {end_time}, 地点: {location})"

            output_for_df = []
            output_for_df.append([1, combined_content, "重要"])

            return output_for_df

        # --- Fallback: ask the LLM (uses the filter endpoint credentials) ---
        if not Filter_API_BASE_URL or not Filter_API_KEY or not Filter_MODEL_NAME:
            return [["error", "Filter API configuration incomplete", "-"]]

        headers = {
            "Authorization": f"Bearer {Filter_API_KEY}",
            "Accept": "application/json"
        }
        payload = {
            "model": Filter_MODEL_NAME,
            "messages": llm_messages,
            "temperature": 0.2,
            "top_p": 0.95,
            "max_tokens": 1024,
            "stream": False
        }

        api_url = f"{Filter_API_BASE_URL}/chat/completions"

        # NOTE(review): no timeout= on this request; a hung server blocks the UI.
        try:
            response = requests.post(api_url, headers=headers, json=payload)
            response.raise_for_status()
            raw_llm_response = response.json()['choices'][0]['message']['content']
        except requests.exceptions.RequestException as e:
            return [["error", f"Filter API call failed: {e}", "-"]]

        parsed_todos_data = json_parser(raw_llm_response)

        # json_parser signals failure with a dict containing "error".
        if "error" in parsed_todos_data:
            return [["error", f"LLM response parsing error: {parsed_todos_data['error']}", parsed_todos_data.get('raw_text', '')[:50] + "..."]]

        output_for_df = []

        # Shape A: dict keyed by message_id -> single todo entry.
        if isinstance(parsed_todos_data, dict):
            todo_info = None
            for key, value in parsed_todos_data.items():
                if key == message_id or key == str(message_id):
                    todo_info = value
                    break

            if todo_info and isinstance(todo_info, dict) and todo_info.get("is_todo", False):
                todo_content = todo_info.get("todo_content", "未指定待办内容")
                end_time = todo_info.get("end_time")
                location = todo_info.get("location")
                urgency = todo_info.get("urgency", "unimportant")

                combined_content = todo_content

                # The model may return the string "null" as well as real None.
                if end_time and end_time != "null":
                    try:
                        date_part = end_time.split("T")[0] if "T" in end_time else end_time
                        combined_content += f" (截止时间: {date_part}"
                    except:
                        # NOTE(review): bare except — kept as-is in this doc pass.
                        combined_content += f" (截止时间: {end_time}"
                else:
                    combined_content += " ("

                if location and location != "null":
                    combined_content += f", 地点: {location})"
                else:
                    combined_content += ")"

                urgency_display = "一般"
                if urgency == "urgent":
                    urgency_display = "紧急"
                elif urgency == "important":
                    urgency_display = "重要"

                output_for_df = []
                output_for_df.append([1, combined_content, urgency_display])
            else:
                output_for_df = []
                output_for_df.append([1, "此消息不包含待办事项", "-"])

        # Shape B: list of todo dicts -> one row each.
        elif isinstance(parsed_todos_data, list):
            output_for_df = []

            if not parsed_todos_data:
                return [[1, "未能生成待办事项", "-"]]

            for i, item in enumerate(parsed_todos_data):
                if isinstance(item, dict):
                    todo_content = item.get('todo_content', item.get('content', 'N/A'))
                    status = item.get('status', '未完成')
                    urgency = item.get('urgency', 'normal')

                    combined_content = todo_content

                    if 'end_time' in item and item['end_time']:
                        try:
                            if isinstance(item['end_time'], str):
                                date_part = item['end_time'].split("T")[0] if "T" in item['end_time'] else item['end_time']
                                combined_content += f" (截止时间: {date_part}"
                            else:
                                combined_content += f" (截止时间: {str(item['end_time'])}"
                        except Exception:
                            combined_content += " ("
                    else:
                        combined_content += " ("

                    if 'location' in item and item['location']:
                        combined_content += f", 地点: {item['location']})"
                    else:
                        combined_content += ")"

                    importance = "一般"
                    if urgency == "urgent":
                        importance = "紧急"
                    elif urgency == "important":
                        importance = "重要"

                    output_for_df.append([i + 1, combined_content, importance])
                else:
                    # Non-dict entries are stringified best-effort.
                    try:
                        item_str = str(item) if item is not None else "未知项目"
                        output_for_df.append([i + 1, item_str, "一般"])
                    except Exception:
                        output_for_df.append([i + 1, "处理错误的项目", "一般"])

        if not output_for_df:
            return [["info", "未发现待办事项", "-"]]

        return output_for_df

    except Exception as e:
        # Catch-all so the UI always gets renderable rows.
        return [["error", f"LLM call/parse error: {str(e)}", "-"]]
|
|
|
|
|
def process(audio, image):
    """Summarize raw audio/image inputs as two human-readable info strings.

    *audio* is a (sample_rate, data) tuple or None; *image* is an array
    with a .shape attribute or None.  Returns (audio_info, image_info).
    """
    if audio is None:
        audio_info = "未收到音频"
    else:
        sample_rate, audio_data = audio
        audio_info = f"音频采样率: {sample_rate}Hz, 数据长度: {len(audio_data)}"

    image_info = "未收到图片" if image is None else f"图片尺寸: {image.shape}"

    return audio_info, image_info
|
|
|
|
|
def respond(message, history, system_message, max_tokens, temperature, top_p, audio, image):
    """Stream a chat completion from the filter LLM endpoint.

    *history* is a list of (user, assistant) text pairs.  Yields
    (accumulated_text, []) tuples as tokens arrive; the second element is
    a placeholder for the todo-list output.  *audio*/*image* are accepted
    for signature compatibility but not used here.
    """
    # Rebuild the OpenAI-style message list from the pairwise history.
    chat_messages = [{"role": "system", "content": system_message}]
    for val in history:
        if val[0]:
            chat_messages.append({"role": "user", "content": val[0]})
        if val[1]:
            chat_messages.append({"role": "assistant", "content": val[1]})
    chat_messages.append({"role": "user", "content": message})

    chat_response_stream = ""
    if not Filter_API_BASE_URL or not Filter_API_KEY or not Filter_MODEL_NAME:
        yield "Filter API 配置不完整,无法提供聊天回复。", []
        return

    headers = {
        "Authorization": f"Bearer {Filter_API_KEY}",
        "Accept": "application/json"
    }
    payload = {
        "model": Filter_MODEL_NAME,
        "messages": chat_messages,
        "temperature": temperature,
        "top_p": top_p,
        "max_tokens": max_tokens,
        "stream": True  # server-sent events; parsed line-by-line below
    }
    api_url = f"{Filter_API_BASE_URL}/chat/completions"

    try:
        response = requests.post(api_url, headers=headers, json=payload, stream=True)
        response.raise_for_status()

        # Parse the SSE stream: each "data: {...}" line carries one delta.
        # NOTE(review): a JSON object split across two chunks would fail to
        # parse and be silently dropped by the JSONDecodeError handler.
        for chunk in response.iter_content(chunk_size=None):
            if chunk:
                try:
                    for line in chunk.decode('utf-8').splitlines():
                        if line.startswith('data: '):
                            json_data = line[len('data: '):]
                            if json_data.strip() == '[DONE]':
                                break
                            data = json.loads(json_data)
                            token = data['choices'][0]['delta'].get('content', '')
                            if token:
                                chat_response_stream += token
                                yield chat_response_stream, []
                except json.JSONDecodeError:
                    pass
                except Exception as e:
                    yield chat_response_stream + f"\n\n错误: {e}", []

    except requests.exceptions.RequestException as e:
        yield f"调用 NVIDIA API 失败: {e}", []
|
|
|
|
|
# --- UI definition: chat column (left) + generated todo list (right) ---
with gr.Blocks() as app:
    gr.Markdown("# ToDoAgent Multi-Modal Interface with ToDo List")

    with gr.Row():
        with gr.Column(scale=2):
            gr.Markdown("## Chat Interface")
            chatbot = gr.Chatbot(height=450, label="聊天记录", type="messages")
            msg = gr.Textbox(label="输入消息", placeholder="输入您的问题或待办事项...")

            # Multi-modal inputs feeding Xunfei ASR/OCR in handle_submit.
            with gr.Row():
                audio_input = gr.Audio(label="上传语音", type="numpy", sources=["upload", "microphone"])
                image_input = gr.Image(label="上传图片", type="numpy")

            # Generation parameters forwarded to respond().
            with gr.Accordion("高级设置", open=False):
                system_msg = gr.Textbox(value="You are a friendly Chatbot.", label="系统提示")
                max_tokens_slider = gr.Slider(minimum=1, maximum=4096, value=1024, step=1, label="最大生成长度(聊天)")
                temperature_slider = gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="温度(聊天)")
                top_p_slider = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p(聊天)")

            with gr.Row():
                submit_btn = gr.Button("发送", variant="primary")
                clear_btn = gr.Button("清除聊天和ToDo")

        with gr.Column(scale=1):
            gr.Markdown("## Generated ToDo List")
            # Receives [index, content, urgency] rows from handle_submit.
            todolist_df = gr.DataFrame(headers=["ID", "任务内容", "状态"],
                                       datatype=["number", "str", "str"],
                                       row_count=(0, "dynamic"),
                                       col_count=(3, "fixed"),
                                       label="待办事项列表")
|
|
|
|
|
    def handle_submit(user_msg_content, ch_history, sys_msg, max_t, temp, t_p, audio_f, image_f):
        """Handle one submission: transcribe audio / OCR image via Xunfei,
        stream the chat reply, then classify the text and build the todo list.

        Generator: yields (chat_history, todo_rows) pairs so Gradio can
        update the chatbot incrementally before the todo table is filled.
        """
        multimodal_text_content = ""
        xunfei_config = get_hf_xunfei_config()
        xunfei_appid = xunfei_config.get('appid')
        xunfei_apikey = xunfei_config.get('apikey')
        xunfei_apisecret = xunfei_config.get('apisecret')

        logger.info(f"开始多模态处理 - 音频: {audio_f is not None}, 图像: {image_f is not None}")
        logger.info(f"讯飞配置状态 - appid: {bool(xunfei_appid)}, apikey: {bool(xunfei_apikey)}, apisecret: {bool(xunfei_apisecret)}")

        # --- Audio: write to a temp WAV, run Xunfei ASR, append transcript ---
        if audio_f is not None and xunfei_appid and xunfei_apikey and xunfei_apisecret:
            logger.info("开始处理音频输入...")
            try:
                import tempfile
                import soundfile as sf
                import os

                audio_sample_rate, audio_data = audio_f
                logger.info(f"音频信息: 采样率 {audio_sample_rate}Hz, 数据长度 {len(audio_data)}")

                # delete=False so the file survives for audio_to_str; removed manually below.
                with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_audio:
                    sf.write(temp_audio.name, audio_data, audio_sample_rate)
                    temp_audio_path = temp_audio.name
                logger.info(f"音频临时文件已保存: {temp_audio_path}")

                # audio_to_str comes from `from tools import *`.
                audio_text = audio_to_str(xunfei_appid, xunfei_apikey, xunfei_apisecret, temp_audio_path)
                logger.info(f"音频识别结果: {audio_text}")
                if audio_text:
                    multimodal_text_content += f"音频内容: {audio_text}"

                os.unlink(temp_audio_path)
                logger.info("音频处理完成")
            except Exception as e:
                # Best-effort: a failed transcription never blocks the chat.
                logger.error(f"音频处理错误: {str(e)}")
        elif audio_f is not None:
            logger.warning("音频文件存在但讯飞配置不完整,跳过音频处理")

        # --- Image: write to a temp JPEG, run Xunfei OCR, append text ---
        if image_f is not None and xunfei_appid and xunfei_apikey and xunfei_apisecret:
            logger.info("开始处理图像输入...")
            try:
                import tempfile
                from PIL import Image
                import os

                logger.info(f"图像信息: 形状 {image_f.shape}, 数据类型 {image_f.dtype}")

                with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as temp_image:
                    # 3-D array -> RGB, otherwise treat as grayscale.
                    if len(image_f.shape) == 3:
                        pil_image = Image.fromarray(image_f.astype('uint8'), 'RGB')
                    else:
                        pil_image = Image.fromarray(image_f.astype('uint8'), 'L')

                    pil_image.save(temp_image.name, 'JPEG')
                    temp_image_path = temp_image.name
                logger.info(f"图像临时文件已保存: {temp_image_path}")

                # image_to_str comes from `from tools import *`.
                image_text = image_to_str(xunfei_appid, xunfei_apikey, xunfei_apisecret, temp_image_path)
                logger.info(f"图像识别结果: {image_text}")
                if image_text:
                    if multimodal_text_content:
                        multimodal_text_content += "\n"
                    multimodal_text_content += f"图像内容: {image_text}"

                os.unlink(temp_image_path)
                logger.info("图像处理完成")
            except Exception as e:
                logger.error(f"图像处理错误: {str(e)}")
        elif image_f is not None:
            logger.warning("图像文件存在但讯飞配置不完整,跳过图像处理")

        # --- Merge typed text with transcribed/OCR'd content ---
        final_user_content = user_msg_content.strip() if user_msg_content else ""
        if not final_user_content and multimodal_text_content:
            final_user_content = multimodal_text_content
            logger.info(f"用户无文本输入,使用多模态内容作为用户输入: {final_user_content}")
        elif final_user_content and multimodal_text_content:
            final_user_content = f"{final_user_content}\n{multimodal_text_content}"
            logger.info(f"用户有文本输入,多模态内容作为补充")

        if not final_user_content:
            final_user_content = "[无输入内容]"
            logger.warning("用户没有提供任何输入内容(文本、音频或图像)")

        logger.info(f"最终用户输入内容: {final_user_content}")

        # Show the user's message immediately, before the LLM responds.
        if not ch_history: ch_history = []
        ch_history.append({"role": "user", "content": final_user_content})
        yield ch_history, []

        # Convert messages-format history to the (user, assistant) pairs
        # that respond() expects; the last (current) user message is excluded.
        formatted_hist_for_respond = []
        temp_user_msg_for_hist = None
        for item_hist in ch_history[:-1]:
            if item_hist["role"] == "user":
                temp_user_msg_for_hist = item_hist["content"]
            elif item_hist["role"] == "assistant" and temp_user_msg_for_hist is not None:
                formatted_hist_for_respond.append((temp_user_msg_for_hist, item_hist["content"]))
                temp_user_msg_for_hist = None
            elif item_hist["role"] == "assistant" and temp_user_msg_for_hist is None:
                # Assistant message without a preceding user turn.
                formatted_hist_for_respond.append(("", item_hist["content"]))

        # Placeholder assistant entry, filled in as tokens stream.
        ch_history.append({"role": "assistant", "content": ""})

        full_bot_response = ""

        for bot_response_token, _ in respond(final_user_content, formatted_hist_for_respond, sys_msg, max_t, temp, t_p, audio_f, image_f):
            full_bot_response = bot_response_token
            ch_history[-1]["content"] = full_bot_response
            yield ch_history, []

        # --- Todo generation: classify first, then generate if relevant ---
        text_for_todo = final_user_content

        logger.info(f"用于ToDo生成的内容: {text_for_todo}")
        current_todos_list = []

        filtered_result = filter_message_with_llm(text_for_todo)

        if isinstance(filtered_result, dict) and "error" in filtered_result:
            current_todos_list = [["Error", filtered_result['error'], "Filter Failed"]]
        elif isinstance(filtered_result, dict) and filtered_result.get("分类") == "其他":
            current_todos_list = [["Info", "消息被归类为 '其他',无需生成 ToDo。", "Filtered"]]
        elif isinstance(filtered_result, list):
            category = None

            # Empty filter result: skip classification, generate directly.
            if not filtered_result:
                if text_for_todo:
                    msg_id_todo = f"hf_app_todo_{datetime.now().strftime('%Y%m%d%H%M%S%f')}"
                    current_todos_list = generate_todolist_from_text(text_for_todo, msg_id_todo)
                yield ch_history, current_todos_list
                return

            # Take the category from the first dict entry, if any.
            valid_item = None
            for item in filtered_result:
                if isinstance(item, dict):
                    valid_item = item
                    if "分类" in item:
                        category = item["分类"]
                    break

            if valid_item is None:
                if text_for_todo:
                    msg_id_todo = f"hf_app_todo_{datetime.now().strftime('%Y%m%d%H%M%S%f')}"
                    current_todos_list = generate_todolist_from_text(text_for_todo, msg_id_todo)
                yield ch_history, current_todos_list
                return

            if category == "其他":
                # Classified as irrelevant: no todo generated.
                current_todos_list = [["Info", "消息被归类为 '其他',无需生成 ToDo。", "Filtered"]]
            else:
                if text_for_todo:
                    msg_id_todo = f"hf_app_todo_{datetime.now().strftime('%Y%m%d%H%M%S%f')}"
                    current_todos_list = generate_todolist_from_text(text_for_todo, msg_id_todo)
        else:
            # Unexpected filter result type: fall back to direct generation.
            if text_for_todo:
                msg_id_todo = f"hf_app_todo_{datetime.now().strftime('%Y%m%d%H%M%S%f')}"
                current_todos_list = generate_todolist_from_text(text_for_todo, msg_id_todo)

        yield ch_history, current_todos_list
|
|
|
|
|
    # Both the send button and pressing Enter in the textbox trigger the
    # same handler with identical inputs/outputs.
    submit_btn.click(
        handle_submit,
        [msg, chatbot, system_msg, max_tokens_slider, temperature_slider, top_p_slider, audio_input, image_input],
        [chatbot, todolist_df]
    )
    msg.submit(
        handle_submit,
        [msg, chatbot, system_msg, max_tokens_slider, temperature_slider, top_p_slider, audio_input, image_input],
        [chatbot, todolist_df]
    )

    def clear_all():
        """Reset the chatbot, todo table, and message box to empty."""
        return None, None, ""

    clear_btn.click(clear_all, None, [chatbot, todolist_df, msg], queue=False)
|
|
|
|
|
    # Secondary tab: standalone audio/image inspection (no LLM involved).
    with gr.Tab("Audio/Image Processing (Original)"):
        gr.Markdown("## 处理音频和图片")
        audio_processor = gr.Audio(label="上传音频", type="numpy")
        image_processor = gr.Image(label="上传图片", type="numpy")
        process_btn = gr.Button("处理", variant="primary")
        audio_output = gr.Textbox(label="音频信息")
        image_output = gr.Textbox(label="图片信息")

        process_btn.click(
            process,
            inputs=[audio_processor, image_processor],
            outputs=[audio_output, image_output]
        )
|
|
|
|
|
# Script entry point: launch the Gradio app (debug disabled).
if __name__ == "__main__":
    app.launch(debug=False)