kofdai committed on
Commit
4c1ab44
·
verified ·
1 Parent(s): 5903ccf

Upload runner_engine.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. runner_engine.py +139 -0
runner_engine.py ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # runner_engine.py
2
+
3
+ import asyncio
4
+ from hot_cache import LRUCache
5
+ from web_search_autonomy import WebSearchAutonomySystem
6
+
7
+ # NOTE: The mock objects previously in this file have been moved to `mock_objects.py`
8
+ # for centralized test management. The main `RunnerEngine` class below is the
9
+ # actual implementation.
10
+
11
+
12
+
13
class RunnerEngine:
    """Inference execution engine (Layer 3).

    Streams LLM response tokens while concurrently:
      * fetching knowledge tiles for DB coordinates (through an LRU hot cache),
      * optionally running a web search when the autonomy system says so —
        either up-front or dynamically, triggered mid-inference.

    Collaborators (duck-typed; assumed interfaces — TODO confirm against
    the real implementations):
      llm_client        -- exposes ``async generate_streaming(prompt)`` yielding
                           dicts with ``type`` in {'response_token', 'completion'}.
      db_interface      -- exposes ``async fetch_async(coord)`` returning a tile
                           dict with ``certainty`` and ``content`` keys, or falsy.
      web_search_system -- exposes ``should_search(question, inference_state=None)``
                           returning a dict with a ``should_search`` bool.
    """

    def __init__(self, llm_client, db_interface, web_search_system):
        self.llm = llm_client
        self.db = db_interface
        self.web_search_system = web_search_system
        # Small per-engine cache of recently fetched DB tiles.
        self.hot_cache = LRUCache(max_size=20)

    async def _fetch_db_coordinates(self, db_coordinates: list) -> dict:
        """Fetch knowledge tiles for up to 5 coordinates, via the hot cache.

        Returns a dict mapping coordinate -> tile. Coordinates whose fetch
        returns a falsy tile are silently omitted.
        """
        results = {}
        # Cap at 5 coordinates to bound per-request latency.
        for coord in db_coordinates[:5]:
            if coord in self.hot_cache:
                print(f"Cache: Hit for {coord}")
                results[coord] = self.hot_cache[coord]
                continue

            print(f"Cache: Miss for {coord}")
            tile = await self.db.fetch_async(coord)
            if tile:
                self.hot_cache[coord] = tile
                results[coord] = tile
        return results

    def _build_context(self, question: str, db_results: dict, session_context) -> str:
        """Build the context string for the LLM prompt.

        ``question`` is accepted for interface stability but not used here.
        Parts are joined with blank lines; returns "" when there is nothing.
        """
        context_parts = []
        if session_context:  # unused in the demo flow, kept for callers that pass it
            context_parts.append(f"セッション履歴: {session_context}")

        if db_results:
            for coord, tile in db_results.items():
                context_parts.append(f"【確実性{tile['certainty']}%】{tile['content']}")

        return "\n\n".join(context_parts)

    def _format_prompt(self, question: str, context: str) -> str:
        """Assemble the final prompt: context, question, then the instruction."""
        return f"情報: {context}\n\n質問: {question}\n\n指示: 提供された情報に基づき回答してください。"

    async def generate_response_streaming(self, question: str, db_coordinates: list, session_context=None):
        """Stream response tokens and decide dynamically whether to web-search.

        Yields, in order: the LLM's ``response_token`` events (relayed as-is),
        an optional ``web_results`` event, and a final
        ``final_structured_response`` event carrying the full text plus the
        completion metadata (for the Judge layer).
        """
        # FIX: the mock web-search coroutine lives in mock_objects.py (see the
        # NOTE at the top of this file). Without this import, any triggered
        # web search raised NameError.
        from mock_objects import mock_web_search_api

        web_decision = self.web_search_system.should_search(question)

        db_task = asyncio.create_task(self._fetch_db_coordinates(db_coordinates))
        web_task = asyncio.create_task(mock_web_search_api(question)) if web_decision["should_search"] else None

        # DB fetch gets a tight budget; on timeout we answer without DB context.
        try:
            db_results = await asyncio.wait_for(db_task, timeout=0.5)
        except asyncio.TimeoutError:
            db_results = {}

        context = self._build_context(question, db_results, session_context)
        prompt = self._format_prompt(question, context)

        partial_response = ""
        final_metadata = {}

        async for result in self.llm.generate_streaming(prompt):
            if result['type'] == 'response_token':
                token = result['token']
                partial_response += token
                yield result  # relay the token unchanged

                # Dynamic web-search check during inference.
                # NOTE(review): the trigger fires only when the accumulated
                # length is an exact multiple of 20 at a token boundary, so
                # multi-char tokens can skip it entirely — confirm intended.
                if len(partial_response) > 5 and len(partial_response) % 20 == 0 and not web_task:
                    class MockInferenceState: partial_response = ""
                    inference_state = MockInferenceState()
                    inference_state.partial_response = partial_response
                    dynamic_decision = self.web_search_system.should_search(question, inference_state=inference_state)
                    if dynamic_decision["should_search"]:
                        print("\n*** Dynamic Web Search Triggered! ***\n")
                        web_task = asyncio.create_task(mock_web_search_api(question))

            elif result['type'] == 'completion':
                # Structured metadata the Judge layer will need.
                final_metadata = result['metadata']

        web_results_content = []
        if web_task:
            try:
                web_results_content = await asyncio.wait_for(web_task, timeout=2.0)
                yield {"type": "web_results", "results": web_results_content}
            except asyncio.TimeoutError:
                yield {"type": "web_results", "results": [], "error": "timeout"}

        # Emit the final, fully structured response and finish.
        final_metadata["referenced_coords"] = db_coordinates
        final_metadata["web_results"] = web_results_content
        yield {
            "type": "final_structured_response",
            "is_complete": True,
            "main_response": partial_response,
            **final_metadata  # expands thinking_process, key_points, etc.
        }
108
+
109
# --- Example run ---
async def main():
    """Demo: run the streaming pipeline end to end with mock components."""
    import json
    # FIX: the mock components live in mock_objects.py (see the NOTE at the
    # top of this file). Without these imports, main() raised NameError.
    from mock_objects import MockLLMClient, MockDBInterface

    # Initialize mock components.
    llm = MockLLMClient()
    db = MockDBInterface()
    web_search = WebSearchAutonomySystem()

    runner = RunnerEngine(llm, db, web_search)

    question = "最新の心筋梗塞の診断について"
    # Coordinates assumed to have been extracted by Layer 1.
    db_coordinates = [(28, 35, 15)]

    print(f"--- Running pipeline for question: '{question}' ---")
    final_response = {}
    async for event in runner.generate_response_streaming(question, db_coordinates):
        if event['type'] == 'response_token':
            print(event['token'], end='', flush=True)
        elif event['type'] == 'web_results':
            print("\n\n--- Web Results Received ---")
            print(event['results'])
        elif event['type'] == 'final_structured_response':
            final_response = event

    print("\n\n--- Final Structured Response (for Judge Layer) ---")
    print(json.dumps(final_response, indent=2, ensure_ascii=False))
136
+
137
+
138
+ if __name__ == "__main__":
139
+ asyncio.run(main())