| """検出結果の統合・スコアリング""" |
|
|
import time
from collections import Counter, deque
from dataclasses import dataclass, field
from typing import Dict, List, Optional

from core.location_matcher import LocationCandidate, MatchResult
from utils.geo_utils import haversine_distance
|
|
|
|
@dataclass
class DetectionEvent:
    """A single detection event: OCR/VLM output plus the location match at one instant."""


    # Unix epoch seconds when the event was recorded (set via time.time()).
    timestamp: float
    # Raw text strings extracted by OCR for this frame.
    ocr_texts: List[str]
    # Scene keywords produced by the vision-language model.
    vlm_keywords: List[str]
    # Result of matching against location candidates; None when no match was available.
    match_result: Optional[MatchResult]
|
|
|
|
@dataclass
class AggregatedResult:
    """Aggregated location estimate built from multiple detection events."""


    # Estimated latitude (mean of consistent candidates); None when no candidates exist.
    estimated_lat: Optional[float] = None
    # Estimated longitude; None when no candidates exist.
    estimated_lon: Optional[float] = None
    # Combined confidence in [0, 1] (weighted consistency ratio + normalized score).
    confidence: float = 0.0
    # Human-readable hint such as "<landmark>付近"; empty when unknown.
    address_hint: str = ""
    # OCR texts observed at least twice, most frequent first (max 10).
    detected_texts: List[str] = field(default_factory=list)
    # Unique VLM keywords seen across the buffered events.
    detected_landmarks: List[str] = field(default_factory=list)
    # Number of location candidates currently in the history buffer.
    match_count: int = 0
    # True when confidence passed the threshold with >= 2 consistent candidates.
    is_location_found: bool = False
|
|
|
|
class ResultAggregator:
    """
    Aggregate multiple detection results along the time axis and
    derive a high-confidence location estimate.
    """

    # A candidate score at or above this value counts as a "full" match
    # when normalized into the [0, 1] confidence contribution.
    _SCORE_FULL_MATCH = 20.0

    def __init__(
        self,
        buffer_size: int = 10,
        confidence_threshold: float = 0.6,
        consistency_window_sec: float = 10.0,
        consistency_distance_m: float = 100.0,
    ):
        """
        Args:
            buffer_size: Maximum number of detection events (and location
                candidates) retained.
            confidence_threshold: Minimum confidence required before a
                location is reported as found.
            consistency_window_sec: Intended time window for the consistency
                check. NOTE(review): currently unused -- history is bounded
                only by ``buffer_size``; TODO wire this into
                ``get_aggregated_result`` or remove the parameter.
            consistency_distance_m: Candidates within this many meters of the
                latest candidate are treated as mutually consistent
                (previously a hard-coded 100 m).
        """
        self.buffer_size = buffer_size
        self.confidence_threshold = confidence_threshold
        self.consistency_window_sec = consistency_window_sec
        self.consistency_distance_m = consistency_distance_m

        # Bounded buffers: deque(maxlen=...) evicts the oldest entry
        # automatically, so no manual trimming is needed.
        self._events: deque = deque(maxlen=buffer_size)
        self._candidate_history: deque = deque(maxlen=buffer_size)
        # Occurrence count of every OCR text seen since the last reset()
        # (intentionally unbounded; cleared by reset()).
        self._detected_texts: Counter = Counter()

    def add_detection(
        self,
        ocr_texts: List[str],
        vlm_keywords: List[str],
        match_result: Optional[MatchResult],
    ) -> None:
        """Record one detection event (OCR texts, VLM keywords, match result)."""
        self._events.append(
            DetectionEvent(
                timestamp=time.time(),
                ocr_texts=ocr_texts,
                vlm_keywords=vlm_keywords,
                match_result=match_result,
            )
        )

        # Accumulate per-text occurrence counts across events.
        self._detected_texts.update(ocr_texts)

        if match_result and match_result.best_candidate:
            # Oldest candidate is evicted automatically by the deque.
            self._candidate_history.append(match_result.best_candidate)

    def get_aggregated_result(self) -> AggregatedResult:
        """
        Combine the buffered events into a single estimate.

        Returns:
            AggregatedResult with an averaged position over candidates that
            lie within ``consistency_distance_m`` of the latest candidate,
            plus frequent OCR texts and collected VLM keywords. An empty
            result is returned when no events have been recorded.
        """
        if not self._events:
            return AggregatedResult()

        # OCR texts seen at least twice, most frequent first, capped at 10.
        frequent_texts = [
            text
            for text, count in self._detected_texts.most_common()
            if count >= 2
        ][:10]

        # Union of VLM keywords across all buffered events.
        vlm_keywords = set()
        for event in self._events:
            vlm_keywords.update(event.vlm_keywords)

        # No location candidates yet: report the textual evidence only.
        if not self._candidate_history:
            return AggregatedResult(
                detected_texts=frequent_texts,
                detected_landmarks=list(vlm_keywords),
            )

        # Candidates close to the most recent one count as consistent.
        latest = self._candidate_history[-1]
        consistent_candidates = [
            candidate
            for candidate in self._candidate_history
            if haversine_distance(latest.lat, latest.lon, candidate.lat, candidate.lon)
            < self.consistency_distance_m
        ]

        consistency_ratio = len(consistent_candidates) / len(self._candidate_history)
        avg_score = sum(c.score for c in consistent_candidates) / max(
            len(consistent_candidates), 1
        )

        # Confidence: 60% spatial consistency, 40% normalized match score.
        normalized_score = min(avg_score / self._SCORE_FULL_MATCH, 1.0)
        confidence = consistency_ratio * 0.6 + normalized_score * 0.4

        # Require corroboration: at least two mutually consistent candidates.
        is_found = (
            confidence >= self.confidence_threshold
            and len(consistent_candidates) >= 2
        )

        # ``latest`` is always within 0 m of itself, so the list is non-empty
        # here; the fallback branch is purely defensive.
        if consistent_candidates:
            count = len(consistent_candidates)
            avg_lat = sum(c.lat for c in consistent_candidates) / count
            avg_lon = sum(c.lon for c in consistent_candidates) / count
        else:
            avg_lat, avg_lon = latest.lat, latest.lon

        return AggregatedResult(
            estimated_lat=avg_lat,
            estimated_lon=avg_lon,
            confidence=confidence,
            address_hint=self._generate_address_hint(consistent_candidates),
            detected_texts=frequent_texts,
            detected_landmarks=list(vlm_keywords),
            match_count=len(self._candidate_history),
            is_location_found=is_found,
        )

    def _generate_address_hint(
        self, candidates: List[LocationCandidate]
    ) -> str:
        """
        Build a human-readable address hint ("<landmark>付近") from the
        candidates' name-match reasons; empty string when none apply.
        """
        if not candidates:
            return ""

        # Collect unique landmark names from "名前マッチ: <name>" reasons,
        # preserving first-seen order.
        landmarks: List[str] = []
        for candidate in candidates:
            for reason in candidate.match_reasons:
                if "名前マッチ" in reason:
                    name = reason.replace("名前マッチ: ", "")
                    if name not in landmarks:
                        landmarks.append(name)

        return f"{landmarks[0]}付近" if landmarks else ""

    def reset(self) -> None:
        """Clear all buffered state (events, text counts, candidates)."""
        self._events.clear()
        self._detected_texts.clear()
        self._candidate_history.clear()

    def get_recent_texts(self, limit: int = 5) -> List[str]:
        """
        Return up to ``limit`` unique OCR texts, most recent event first
        (order of texts within an event is preserved).
        """
        texts: List[str] = []
        for event in reversed(list(self._events)):
            for text in event.ocr_texts:
                if text not in texts:
                    texts.append(text)
                    if len(texts) >= limit:
                        return texts
        return texts

    def get_detection_stats(self) -> Dict:
        """Return a summary of buffer sizes and the 5 most frequent texts."""
        return {
            "event_count": len(self._events),
            "unique_texts": len(self._detected_texts),
            "candidate_count": len(self._candidate_history),
            "top_texts": self._detected_texts.most_common(5),
        }
|
|