|
|
import mlx.core as mx |
|
|
from mlx_lm import load, generate |
|
|
from sentence_transformers import SentenceTransformer, util |
|
|
from transformers import pipeline |
|
|
from collections import deque |
|
|
from janome.tokenizer import Tokenizer |
|
|
import json |
|
|
import sqlite3 |
|
|
from datetime import datetime |
|
|
|
|
|
class IlmApp: |
|
|
|
|
|
BASE_PROMPT = """### 指示: |
|
|
あなたは、文脈を理解し、自然な応答を生成するAIアシスタントです。 |
|
|
以下の状況を考慮して、最適な応答を生成してください。 |
|
|
{intent_instruction} |
|
|
{style_instruction} |
|
|
{transition_instruction} |
|
|
### ユーザーからの入力: |
|
|
{user_input} |
|
|
|
|
|
### 応答:""" |
|
|
FLASHBACK_PROMPT = """... (省略) ...""" |
|
|
TRANSITION_CONTEXT = """### 直前の会話のトピック: |
|
|
{previous_response} |
|
|
""" |
|
|
|
|
|
|
|
|
INTENT_INSTRUCTIONS = {"質問": "...", "アイデアの要求": "...", "感想": "...", "雑談": "...", "デフォルト": "..."} |
|
|
STYLE_INSTRUCTIONS = {"丁寧": "...", "簡潔": "...", "創造的": "...", "ユーモラス": "...", "デフォルト": "..."} |
|
|
|
|
|
def __init__(self): |
|
|
|
|
|
self.db_path = "experience.db" |
|
|
self.llm_model_path = "./merged_model" |
|
|
self.sentence_model_name = "paraphrase-multilingual-MiniLM-L12-v2" |
|
|
self.classifier_name = "MoritzLaurer/mDeBERTa-v3-base-mnli-xnli" |
|
|
self.intent_labels = ["質問", "アイデアの要求", "感想", "雑談"] |
|
|
self.style_labels = ["丁寧", "簡潔", "創造的", "ユーモラス"] |
|
|
self.topic_labels = ["テクノロジー", "ビジネス", "健康", "芸術", "食事", "地理", "歴史", "科学"] |
|
|
self.similarity_threshold = 0.4 |
|
|
self.common_transition_threshold = 5 |
|
|
|
|
|
|
|
|
self.flashback_buffer = deque(maxlen=5) |
|
|
self.previous_response = None |
|
|
|
|
|
|
|
|
self._init_db() |
|
|
self._load_models() |
|
|
|
|
|
def _init_db(self): |
|
|
with sqlite3.connect(self.db_path) as conn: |
|
|
cursor = conn.cursor() |
|
|
cursor.execute(""" |
|
|
CREATE TABLE IF NOT EXISTS topic_transitions ( |
|
|
id INTEGER PRIMARY KEY AUTOINCREMENT, |
|
|
source_topic TEXT NOT NULL, |
|
|
destination_topic TEXT NOT NULL, |
|
|
count INTEGER NOT NULL DEFAULT 1, |
|
|
last_occurred TIMESTAMP NOT NULL, |
|
|
UNIQUE(source_topic, destination_topic) |
|
|
) |
|
|
""") |
|
|
conn.commit() |
|
|
|
|
|
def _load_models(self): |
|
|
print("各モデルとツールを読み込んでいます...") |
|
|
self.model, self.tokenizer = load(self.llm_model_path) |
|
|
self.sentence_model = SentenceTransformer(self.sentence_model_name) |
|
|
self.classifier = pipeline("zero-shot-classification", model=self.classifier_name) |
|
|
self.janome_tokenizer = Tokenizer() |
|
|
print("モデルの読み込みが完了しました。") |
|
|
|
|
|
def _get_transition_experience(self, source, dest): |
|
|
with sqlite3.connect(self.db_path) as conn: |
|
|
cursor = conn.cursor() |
|
|
cursor.execute("SELECT count FROM topic_transitions WHERE source_topic = ? AND destination_topic = ?", (source, dest)) |
|
|
result = cursor.fetchone() |
|
|
return result[0] if result else 0 |
|
|
|
|
|
def _update_l2_memory(self, source, dest): |
|
|
timestamp = datetime.now().isoformat() |
|
|
with sqlite3.connect(self.db_path) as conn: |
|
|
cursor = conn.cursor() |
|
|
cursor.execute(""" |
|
|
INSERT INTO topic_transitions (source_topic, destination_topic, last_occurred) VALUES (?, ?, ?) |
|
|
ON CONFLICT(source_topic, destination_topic) DO UPDATE SET count = count + 1, last_occurred = excluded.last_occurred |
|
|
""", (source, dest, timestamp)) |
|
|
conn.commit() |
|
|
|
|
|
def _build_prompt(self, user_input): |
|
|
|
|
|
for item in self.flashback_buffer: |
|
|
if item['keyword'] in user_input: |
|
|
print(f"(フラッシュバックを検知: {item['keyword']})", end="") |
|
|
return self.FLASHBACK_PROMPT.format(keyword=item['keyword'], original_sentence=item['sentence'], user_input=user_input) |
|
|
|
|
|
|
|
|
intent = self.classifier(user_input, self.intent_labels, multi_label=False)['labels'][0] |
|
|
style = self.classifier(user_input, self.style_labels, multi_label=False)['labels'][0] |
|
|
print(f"(意図: {intent} | スタイル: {style})", end="") |
|
|
|
|
|
intent_instruction = f"\n### ユーザーの意図: {intent}\n{self.INTENT_INSTRUCTIONS.get(intent, self.INTENT_INSTRUCTIONS['デフォルト'])}" |
|
|
style_instruction = f"\n### ユーザーの対話スタイル: {style}\n{self.STYLE_INSTRUCTIONS.get(style, self.STYLE_INSTRUCTIONS['デフォルト'])}" |
|
|
transition_instruction = "" |
|
|
|
|
|
|
|
|
if self.previous_response: |
|
|
emb_prev = self.sentence_model.encode(self.previous_response, convert_to_tensor=True) |
|
|
emb_curr = self.sentence_model.encode(user_input, convert_to_tensor=True) |
|
|
sim = util.pytorch_cos_sim(emb_prev, emb_curr).item() |
|
|
print(f" (類似度: {sim:.2f})", end="") |
|
|
|
|
|
if sim < self.similarity_threshold: |
|
|
source_topic = self.classifier(self.previous_response, self.topic_labels, multi_label=False)['labels'][0] |
|
|
dest_topic = self.classifier(user_input, self.topic_labels, multi_label=False)['labels'][0] |
|
|
|
|
|
|
|
|
experience_count = self._get_transition_experience(source_topic, dest_topic) |
|
|
|
|
|
if experience_count > self.common_transition_threshold: |
|
|
transition_judgment = f"これは過去に{experience_count}回経験した、自然な話題の遷移です。その流れを汲み取って応答してください。" |
|
|
else: |
|
|
transition_judgment = f"これは斬新な話題の飛躍です。その面白さに触れつつ、応答を返してください。" |
|
|
|
|
|
print(f" (L2判断: {transition_judgment})") |
|
|
transition_instruction = f"\n### 話題遷移の分析:\n{transition_judgment}\nスムーズな移行文を生成してください。\n" + self.TRANSITION_CONTEXT.format(previous_response=self.previous_response) |
|
|
|
|
|
|
|
|
if source_topic != dest_topic: |
|
|
self._update_l2_memory(source_topic, dest_topic) |
|
|
else: |
|
|
print() |
|
|
|
|
|
return self.BASE_PROMPT.format(intent_instruction=intent_instruction, style_instruction=style_instruction, transition_instruction=transition_instruction, user_input=user_input) |
|
|
|
|
|
def _update_memory(self, response): |
|
|
self.previous_response = response |
|
|
try: |
|
|
for token in self.janome_tokenizer.tokenize(response): |
|
|
if token.part_of_speech.startswith('名詞'): |
|
|
self.flashback_buffer.append({'keyword': token.surface, 'sentence': response}); break |
|
|
except Exception: pass |
|
|
|
|
|
def run(self): |
|
|
print("\nIlmチャットを開始します。終了するには 'exit' と入力してください。") |
|
|
while True: |
|
|
user_input = input("\nあなた: ") |
|
|
if user_input.lower() == 'exit': break |
|
|
prompt = self._build_prompt(user_input) |
|
|
print("\nIlm: ", end="", flush=True) |
|
|
current_response = "" |
|
|
for token in generate(self.model, self.tokenizer, prompt=prompt, verbose=False): |
|
|
current_response += token |
|
|
print(token, end="", flush=True) |
|
|
print() |
|
|
self._update_memory(current_response) |
|
|
|
|
|
if __name__ == "__main__": |
|
|
|
|
|
IlmApp.FLASHBACK_PROMPT = """### 指示: |
|
|
あなたは、以前の会話の断片を思い出すことができます。 |
|
|
ユーザーが、あなたが以前言及したキーワード「{keyword}」に触れました。 |
|
|
あなたはそのキーワードについて「{original_sentence}」と発言しています。 |
|
|
この「記憶の断片」を思い出したかのように自然な前置きを述べてから、ユーザーの現在の入力に答えてください。 |
|
|
|
|
|
### ユーザーからの現在の入力: |
|
|
{user_input} |
|
|
|
|
|
### 応答:""" |
|
|
IlmApp.INTENT_INSTRUCTIONS = { |
|
|
"質問": "ユーザーは具体的な情報を求めています。明確かつ簡潔に回答してください。", |
|
|
"アイデアの要求": "ユーザーは創造的な発想を求めています。斬新で多様なアイデアを提案してください。", |
|
|
"感想": "ユーザーは共感を求めています。同意や補足情報を提供し、会話を広げてください。", |
|
|
"雑談": "ユーザーは気軽な対話を望んでいます。親しみやすいトーンで応答してください。", |
|
|
"デフォルト": "ユーザーの入力に対して、適切に応答してください。" |
|
|
} |
|
|
IlmApp.STYLE_INSTRUCTIONS = { |
|
|
"丁寧": "ユーザーは丁寧な言葉遣いを好みます。あなたも敬体(です・ます調)で応答してください。", |
|
|
"簡潔": "ユーザーは要点をまとめて話しています。あなたも簡潔に応答してください。", |
|
|
"創造的": "ユーザーは比喩や創造的な表現を使っています。あなたも表現を工夫して応答してください。", |
|
|
"ユーモラス": "ユーザーはユーモアを交えて話しています。あなたも遊び心のある応答をしてください。", |
|
|
"デフォルト": "ユーザーのスタイルに合わせて、自然に応答してください。" |
|
|
} |
|
|
app = IlmApp() |
|
|
app.run() |
|
|
|