TruDecide.ai / inference.py
shuyouxingjie's picture
Upload 5 files
332b2aa verified
import json
import os
import re
import random
try:
from opencc import OpenCC
cc = OpenCC('s2t') # 简体到繁体
cc_t2s = OpenCC('t2s') # 繁体到简体
has_opencc = True
except:
has_opencc = False
class TruDecide:
def __init__(self, model_path="."):
# 加载知识库
kb_path = os.path.join(model_path, "knowledge_base.json")
with open(kb_path, "r", encoding="utf-8") as f:
self.knowledge_base = json.load(f)
self.general_responses = [
"作为TruDecide,我专注于市场分析和金融概念解释。请问您有什么具体的金融市场问题需要了解?",
"我是TruDecide,您的市场分析助手。我可以解答关于股票、债券、ETF、市场趋势等金融问题。请问您想了解什么具体的金融知识?",
"很抱歉,我没有足够的信息来回答这个问题。作为TruDecide,我主要提供金融市场方面的分析和建议。请问您有什么关于投资或金融市场的问题吗?",
"这个问题超出了我的专业范围。作为TruDecide,我主要关注金融市场分析和投资概念解释。您有任何金融相关的问题,我很乐意为您解答。",
]
def answer(self, query, threshold=0.5):
"""给定问题返回最佳匹配的回答"""
query = query.lower().strip("??")
# 1. 直接匹配
if query in self.knowledge_base:
return self.knowledge_base[query]
# 2. 简繁转换后匹配
if has_opencc:
t_query = cc.convert(query)
if t_query in self.knowledge_base:
return self.knowledge_base[t_query]
s_query = cc_t2s.convert(query)
if s_query in self.knowledge_base:
return self.knowledge_base[s_query]
# 3. 关键词提取后匹配
keyword_query = re.sub(r'^(什么是|解释一下|请告诉我|你能介绍|介绍|说明|你是|你能|你的|你有什么|你叫什么|你叫)', '', query).strip()
if keyword_query in self.knowledge_base:
return self.knowledge_base[keyword_query]
# 4. 部分匹配
best_match = None
highest_score = 0
for key in self.knowledge_base.keys():
# 计算相似度 (简单版)
words_q = set(query.split())
words_k = set(key.split())
common_words = words_q.intersection(words_k)
if len(words_q) == 0 or len(words_k) == 0:
continue
score = len(common_words) / max(len(words_q), len(words_k))
if score > highest_score:
highest_score = score
best_match = key
# 如果最佳匹配超过阈值,返回对应回答
if best_match and highest_score >= threshold:
return self.knowledge_base[best_match]
# 5. 特殊处理身份问题
if any(word in query for word in ["你是谁", "你的名字", "你是什么", "你能做什么", "你叫什么"]):
return self.knowledge_base.get("你是谁", "我是TruDecide,你的智能市场分析助手。")
# 6. 返回通用回复
return random.choice(self.general_responses)
# 用法示例
if __name__ == "__main__":
model = TruDecide()
print(model.answer("什么是股票市场?"))
print(model.answer("你是谁?"))