|
|
| import json |
| import os |
| import re |
| import random |
|
|
| try: |
| from opencc import OpenCC |
| cc = OpenCC('s2t') |
| cc_t2s = OpenCC('t2s') |
| has_opencc = True |
| except: |
| has_opencc = False |
|
|
| class TruDecide: |
| def __init__(self, model_path="."): |
| |
| kb_path = os.path.join(model_path, "knowledge_base.json") |
| with open(kb_path, "r", encoding="utf-8") as f: |
| self.knowledge_base = json.load(f) |
| |
| self.general_responses = [ |
| "作为TruDecide,我专注于市场分析和金融概念解释。请问您有什么具体的金融市场问题需要了解?", |
| "我是TruDecide,您的市场分析助手。我可以解答关于股票、债券、ETF、市场趋势等金融问题。请问您想了解什么具体的金融知识?", |
| "很抱歉,我没有足够的信息来回答这个问题。作为TruDecide,我主要提供金融市场方面的分析和建议。请问您有什么关于投资或金融市场的问题吗?", |
| "这个问题超出了我的专业范围。作为TruDecide,我主要关注金融市场分析和投资概念解释。您有任何金融相关的问题,我很乐意为您解答。", |
| ] |
| |
| def answer(self, query, threshold=0.5): |
| """给定问题返回最佳匹配的回答""" |
| query = query.lower().strip("??") |
| |
| |
| if query in self.knowledge_base: |
| return self.knowledge_base[query] |
| |
| |
| if has_opencc: |
| t_query = cc.convert(query) |
| if t_query in self.knowledge_base: |
| return self.knowledge_base[t_query] |
| |
| s_query = cc_t2s.convert(query) |
| if s_query in self.knowledge_base: |
| return self.knowledge_base[s_query] |
| |
| |
| keyword_query = re.sub(r'^(什么是|解释一下|请告诉我|你能介绍|介绍|说明|你是|你能|你的|你有什么|你叫什么|你叫)', '', query).strip() |
| if keyword_query in self.knowledge_base: |
| return self.knowledge_base[keyword_query] |
| |
| |
| best_match = None |
| highest_score = 0 |
| |
| for key in self.knowledge_base.keys(): |
| |
| words_q = set(query.split()) |
| words_k = set(key.split()) |
| common_words = words_q.intersection(words_k) |
| |
| if len(words_q) == 0 or len(words_k) == 0: |
| continue |
| |
| score = len(common_words) / max(len(words_q), len(words_k)) |
| |
| if score > highest_score: |
| highest_score = score |
| best_match = key |
| |
| |
| if best_match and highest_score >= threshold: |
| return self.knowledge_base[best_match] |
| |
| |
| if any(word in query for word in ["你是谁", "你的名字", "你是什么", "你能做什么", "你叫什么"]): |
| return self.knowledge_base.get("你是谁", "我是TruDecide,你的智能市场分析助手。") |
| |
| |
| return random.choice(self.general_responses) |
|
|
| |
| if __name__ == "__main__": |
| model = TruDecide() |
| print(model.answer("什么是股票市场?")) |
| print(model.answer("你是谁?")) |
|
|