# Source: nullai-knowledge-system / coordinate_estimator.py
# Uploaded by kofdai via huggingface_hub ("Upload folder using huggingface_hub"), commit 5af8123 (verified)
# null_ai/coordinate_estimator.py
"""
Coordinate Auto-Estimation Module
AIを使って知識タイルの6次元座標を自動推定します。
座標: [x, y, z, c, g, v]
- medical_space [x, y, z]: ドメイン固有の3次元空間
- meta_space [c, g, v]: Certainty, Granularity, Verification
"""
import logging
import json
from typing import List, Dict, Any, Optional
import asyncio
logger = logging.getLogger(__name__)
class CoordinateEstimator:
    """
    Estimate 6-dimensional knowledge-tile coordinates with an LLM.

    Coordinates are ``[x, y, z, c, g, v]``: ``x, y, z`` follow the per-domain
    axis definitions in ``domain_schemas``; ``c, g, v`` are Certainty,
    Granularity and Verification. Every component lies in ``[0.0, 1.0]``.
    """

    def __init__(self):
        # domain_id -> {"x": ..., "y": ..., "z": ...} axis descriptions.
        self.domain_schemas = self._load_domain_schemas()

    def _load_domain_schemas(self) -> Dict[str, Dict[str, str]]:
        """
        Return the built-in axis definitions for each domain.

        Each description lists the meaning of 0.0, 0.5 and 1.0 on that axis.
        TODO: load these from a configuration file instead of hard-coding.
        """
        return {
            "medical": {
                "x": "Anatomical location (0.0=nervous system, 0.5=cardiovascular, 1.0=digestive)",
                "y": "Pathological classification (0.0=infectious, 0.5=metabolic, 1.0=trauma)",
                "z": "Treatment level (0.0=prevention, 0.5=diagnosis, 1.0=treatment)"
            },
            "general": {
                "x": "Knowledge category (0.0=science, 0.5=technology, 1.0=humanities)",
                "y": "Complexity level (0.0=basic, 0.5=intermediate, 1.0=advanced)",
                "z": "Application scope (0.0=theoretical, 0.5=practical, 1.0=applied)"
            },
            "legal": {
                "x": "Legal field (0.0=civil, 0.5=criminal, 1.0=commercial)",
                "y": "Court level (0.0=district, 0.5=high, 1.0=supreme)",
                "z": "Era (0.0=classical, 0.5=modern, 1.0=contemporary)"
            },
            "technology": {
                "x": "Technology domain (0.0=hardware, 0.5=software, 1.0=network)",
                "y": "Maturity (0.0=emerging, 0.5=established, 1.0=legacy)",
                "z": "Scale (0.0=personal, 0.5=enterprise, 1.0=global)"
            }
        }

    async def estimate_coordinates(
        self,
        prompt: str,
        response: str,
        domain_id: str,
        llm_inference_func,
        use_reasoning: bool = True
    ) -> Dict[str, Any]:
        """
        Estimate the 6-dimensional coordinates of one question/answer pair.

        Args:
            prompt: The user's question.
            response: The AI's answer.
            domain_id: Domain ID; unknown domains fall back to the "general" schema.
            llm_inference_func: Async callable that takes the estimation prompt
                (str) and returns the LLM's raw text reply.
            use_reasoning: Ask the LLM to include its reasoning in the reply.

        Returns:
            {"coordinates": [x, y, z, c, g, v], "reasoning": str, "confidence": float}
            with float coordinates in [0.0, 1.0]. On any failure (LLM error,
            unparseable or invalid output) the neutral defaults from
            ``_get_default_coordinates`` are returned instead of raising.
        """
        domain_schema = self.get_domain_schema(domain_id)
        estimation_prompt = self._build_estimation_prompt(
            prompt, response, domain_id, domain_schema, use_reasoning
        )

        try:
            llm_response = await llm_inference_func(estimation_prompt)
            result = self._parse_llm_response(llm_response)
        except Exception as e:
            # Deliberate best-effort boundary: estimation must never break the
            # caller's pipeline, so any LLM/parse failure degrades to defaults.
            logger.error("Coordinate estimation failed: %s", e)
            return self._get_default_coordinates(domain_id)

        coordinates = result["coordinates"]
        if not self._validate_coordinates(coordinates):
            logger.error("Invalid coordinates: %s", coordinates)
            return self._get_default_coordinates(domain_id)

        # Normalise JSON ints (0 / 1) and tuples to a list of floats.
        result["coordinates"] = [float(coord) for coord in coordinates]
        logger.info(
            "Estimated coordinates for domain '%s': %s", domain_id, result["coordinates"]
        )
        return result

    def _build_estimation_prompt(
        self,
        prompt: str,
        response: str,
        domain_id: str,
        domain_schema: Dict[str, str],
        use_reasoning: bool
    ) -> str:
        """
        Build the coordinate-estimation prompt sent to the LLM.

        The prompt explains the x/y/z axes from ``domain_schema`` and the c/g/v
        meta axes, embeds the question/answer pair, and requests a JSON reply
        (including a "reasoning" field when ``use_reasoning`` is true).
        """
        base_prompt = f"""You are an expert in knowledge space mapping and coordinate estimation.
Your task is to estimate the 6-dimensional coordinates that best represent the following knowledge in the domain of "{domain_id}".
**Coordinate System:**
1. **Domain-specific space [x, y, z]** (each 0.0-1.0):
- x-axis: {domain_schema['x']}
- y-axis: {domain_schema['y']}
- z-axis: {domain_schema['z']}
2. **Meta-information space [c, g, v]** (each 0.0-1.0):
- c (Certainty): How certain/verified is this knowledge?
* 0.0 = hypothesis, speculation
* 0.5 = established theory, widely accepted
* 1.0 = proven fact, empirically verified
- g (Granularity): How detailed/specific is this knowledge?
* 0.0 = high-level overview, general concept
* 0.5 = detailed explanation
* 1.0 = highly specialized, expert-level detail
- v (Verification): What is the verification status?
* 0.0 = unverified, no sources
* 0.5 = expert-reviewed, single source
* 1.0 = peer-reviewed, multiple sources confirmed
**Knowledge to estimate:**
Question: {prompt}
Answer: {response}
**Instructions:**
"""
        if use_reasoning:
            base_prompt += """
1. First, analyze the knowledge and explain your reasoning for each coordinate.
2. Then, output the final coordinates.
Format your response as JSON:
{
"reasoning": "Your detailed reasoning here...",
"coordinates": [x, y, z, c, g, v],
"confidence": 0.85
}
"""
        else:
            base_prompt += """
Output ONLY the coordinates as a JSON object:
{
"coordinates": [x, y, z, c, g, v],
"confidence": 0.85
}
"""
        base_prompt += """
**Important:**
- All coordinates must be between 0.0 and 1.0
- Use 2 decimal places (e.g., 0.75)
- confidence should reflect how confident you are in this estimation (0.0-1.0)
"""
        return base_prompt

    @staticmethod
    def _extract_json_text(llm_response: str) -> str:
        """
        Return the part of an LLM reply most likely to hold the JSON payload.

        Prefers a ```json fenced block, then any ``` fenced block (dropping a
        language tag such as "JSON" on the opening line), else the whole reply.
        An unterminated fence (e.g. a truncated reply) runs to the end of the
        text instead of losing its final character.
        """
        tagged_start = llm_response.find("```json")
        if tagged_start != -1:
            start = tagged_start + len("```json")
        else:
            fence_start = llm_response.find("```")
            if fence_start == -1:
                return llm_response.strip()
            start = fence_start + len("```")

        end = llm_response.find("```", start)
        body = llm_response[start:] if end == -1 else llm_response[start:end]

        if tagged_start == -1:
            # Generic fence: the first line may be a language tag ("JSON", "text").
            tag, newline, rest = body.partition("\n")
            if newline and tag.strip().isalpha():
                body = rest
        return body.strip()

    @staticmethod
    def _loads_json_payload(text: str) -> Any:
        """
        Decode ``text`` as JSON, retrying on the outermost ``{...}`` span.

        The retry recovers replies where the JSON object is surrounded by prose.

        Raises:
            json.JSONDecodeError: if neither attempt yields valid JSON.
        """
        try:
            return json.loads(text)
        except json.JSONDecodeError:
            start, end = text.find("{"), text.rfind("}")
            if start == -1 or end <= start:
                raise
            return json.loads(text[start:end + 1])

    @staticmethod
    def _normalize_confidence(value: Any, default: float = 0.5) -> float:
        """
        Coerce an LLM-reported confidence into a float in [0.0, 1.0].

        Numbers and numeric strings are clamped into range; missing, boolean,
        NaN or non-numeric values yield ``default``.
        """
        import math

        if isinstance(value, bool):
            return default
        if isinstance(value, str):
            try:
                value = float(value)
            except ValueError:
                return default
        if not isinstance(value, (int, float)):
            return default
        if isinstance(value, float) and math.isnan(value):
            return default
        # Clamp before float() so arbitrarily large JSON integers cannot overflow.
        return float(min(1.0, max(0.0, value)))

    def _parse_llm_response(self, llm_response: str) -> Dict[str, Any]:
        """
        Extract the coordinate result from the LLM's raw reply.

        Accepts a JSON object (optionally fenced, possibly wrapped in prose) or
        a bare JSON array of coordinates. Missing "reasoning"/"confidence"
        fields get defaults, and "confidence" is coerced into [0.0, 1.0]. If no
        JSON can be decoded, falls back to scanning for a 6-number array.
        Coordinate values themselves are validated by the caller.

        Raises:
            TypeError: if ``llm_response`` is not a string.
            ValueError: if no coordinates can be found.
        """
        if not isinstance(llm_response, str):
            raise TypeError(
                f"LLM response must be a string, got {type(llm_response).__name__}"
            )

        json_str = self._extract_json_text(llm_response)
        try:
            result = self._loads_json_payload(json_str)
        except json.JSONDecodeError as e:
            logger.error("JSON parse error: %s", e)
            logger.debug("LLM response: %s", llm_response)
            return self._fallback_parse(llm_response)

        if isinstance(result, list):
            # A bare JSON array is taken to be the coordinate list itself.
            result = {"coordinates": result}
        if not isinstance(result, dict) or "coordinates" not in result:
            raise ValueError("Missing 'coordinates' field")

        result.setdefault("reasoning", "No reasoning provided")
        result["confidence"] = self._normalize_confidence(result.get("confidence"))
        return result

    def _fallback_parse(self, llm_response: str) -> Dict[str, Any]:
        """
        Last-resort parser used when the reply contains no decodable JSON.

        Searches for a literal array of six non-negative decimal numbers such as
        ``[0.1, 0.2, 0.3, 0.4, 0.5, 0.6]``.

        Raises:
            ValueError: if no such array is present.
        """
        import re

        # Each number must contain a digit, so stray "." tokens cannot match
        # (and later crash float()); ".5" and "1." are still accepted.
        number = r"(\d+(?:\.\d*)?|\.\d+)"
        pattern = r"\[\s*" + r"\s*,\s*".join([number] * 6) + r"\s*\]"
        match = re.search(pattern, llm_response)
        if match is None:
            raise ValueError("Could not parse coordinates from LLM response")
        return {
            "coordinates": [float(group) for group in match.groups()],
            "reasoning": "Parsed from array notation",
            "confidence": 0.5,
        }

    def _validate_coordinates(self, coordinates: List[float]) -> bool:
        """
        Return True if ``coordinates`` is six real numbers, each in [0.0, 1.0].

        Failures are logged with the offending index/value.
        """
        if not isinstance(coordinates, (list, tuple)):
            logger.error("Coordinates must be a list, got %s", type(coordinates).__name__)
            return False
        if len(coordinates) != 6:
            logger.error("Expected 6 coordinates, got %d", len(coordinates))
            return False
        for i, coord in enumerate(coordinates):
            # bool is a subclass of int; reject JSON true/false explicitly.
            if isinstance(coord, bool) or not isinstance(coord, (int, float)):
                logger.error("Coordinate %d is not a number: %r", i, coord)
                return False
            # The chained comparison is False for NaN, so NaN is rejected too.
            if not 0.0 <= coord <= 1.0:
                logger.error("Coordinate %d out of range [0.0, 1.0]: %s", i, coord)
                return False
        return True

    def _get_default_coordinates(self, domain_id: str) -> Dict[str, Any]:
        """
        Return the fallback result used when estimation fails.

        All six coordinates sit at the centre (0.5) with low confidence (0.3).
        A new dict is built on every call, so callers may mutate it freely.
        """
        logger.warning("Using default coordinates for domain '%s'", domain_id)
        return {
            "coordinates": [0.5, 0.5, 0.5, 0.5, 0.5, 0.5],
            "reasoning": "Default coordinates (estimation failed)",
            "confidence": 0.3
        }

    async def estimate_batch(
        self,
        knowledge_items: List[Dict[str, str]],
        llm_inference_func,
        max_concurrent: int = 3,
        use_reasoning: bool = True
    ) -> List[Dict[str, Any]]:
        """
        Estimate coordinates for many knowledge items concurrently.

        Args:
            knowledge_items: [{"prompt": "...", "response": "...", "domain_id": "..."}, ...];
                "domain_id" is optional and defaults to "general".
            llm_inference_func: Async LLM inference callable (see estimate_coordinates).
            max_concurrent: Maximum number of estimations in flight at once (>= 1).
            use_reasoning: Forwarded to estimate_coordinates for every item.

        Returns:
            One result per item, in input order.

        Raises:
            ValueError: if ``max_concurrent`` < 1 (a zero-slot semaphore would
                otherwise block forever).
            KeyError: if an item lacks "prompt" or "response".
        """
        if not knowledge_items:
            return []
        if max_concurrent < 1:
            raise ValueError(f"max_concurrent must be >= 1, got {max_concurrent}")

        semaphore = asyncio.Semaphore(max_concurrent)

        async def estimate_one(item: Dict[str, str]) -> Dict[str, Any]:
            async with semaphore:
                return await self.estimate_coordinates(
                    prompt=item["prompt"],
                    response=item["response"],
                    domain_id=item.get("domain_id", "general"),
                    llm_inference_func=llm_inference_func,
                    use_reasoning=use_reasoning,
                )

        # gather preserves input order regardless of completion order.
        return list(await asyncio.gather(*(estimate_one(item) for item in knowledge_items)))

    def get_domain_schema(self, domain_id: str) -> Dict[str, str]:
        """
        Return the axis schema for ``domain_id`` (for UI display).

        Unknown domains fall back to the "general" schema.
        """
        return self.domain_schemas.get(domain_id, self.domain_schemas["general"])

    def add_domain_schema(self, domain_id: str, schema: Dict[str, str]):
        """
        Register (or replace) the axis schema for ``domain_id``.

        A shallow copy is stored, so later changes to ``schema`` by the caller
        do not silently alter the registered definition.

        Raises:
            ValueError: if ``schema`` lacks any of the "x", "y", "z" keys.
        """
        if not all(key in schema for key in ("x", "y", "z")):
            raise ValueError("Schema must contain 'x', 'y', 'z' definitions")
        self.domain_schemas[domain_id] = dict(schema)
        logger.info("Added domain schema for '%s'", domain_id)

    def interpolate_coordinates(
        self,
        coord1: List[float],
        coord2: List[float],
        weight: float = 0.5
    ) -> List[float]:
        """
        Linearly interpolate between two coordinates (e.g. for similar knowledge).

        Args:
            coord1: First 6-dimensional coordinate.
            coord2: Second 6-dimensional coordinate.
            weight: Interpolation weight (0.0 = coord1, 1.0 = coord2). Values
                outside [0, 1] extrapolate and are not range-checked.

        Returns:
            The interpolated coordinate as a new list.

        Raises:
            ValueError: if either coordinate is not 6-dimensional.
        """
        if len(coord1) != 6 or len(coord2) != 6:
            raise ValueError("Both coordinates must be 6-dimensional")
        return [a * (1 - weight) + b * weight for a, b in zip(coord1, coord2)]