File size: 3,768 Bytes
2758540
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
"""
nlp/qa.py - Question Answering over Surveillance Logs using deepset/roberta-base-squad2
"""
import time
from typing import Optional, Dict, List
from transformers import pipeline, Pipeline
from loguru import logger
from config import settings, DEVICE


class SurveillanceQA:
    """
    Extractive QA system. Given a question and a context built from
    surveillance logs/events, extracts the most relevant answer span.

    The model is loaded once in ``__init__`` via the ``transformers``
    question-answering pipeline; ``answer`` is the public entry point.
    """

    # Hard cap on the number of context characters handed to the pipeline.
    # The head of the context is kept and the remainder is dropped.
    MAX_CONTEXT_CHARS = 4000
    # Length of the context excerpt echoed back to the caller.
    CONTEXT_PREVIEW_CHARS = 500
    # Only the first few clothing entries are mentioned per person.
    MAX_CLOTHING_LABELS = 2
    # Person identifiers are shortened to this many characters in the context.
    PERSON_ID_CHARS = 8

    def __init__(self):
        """Load the QA model named by ``settings.QA_MODEL`` onto ``DEVICE``."""
        logger.info(f"Loading QA model: {settings.QA_MODEL}")
        self.qa_pipeline: Pipeline = pipeline(
            "question-answering",
            model=settings.QA_MODEL,
            tokenizer=settings.QA_MODEL,
            device=self._resolve_device_id(DEVICE),
        )
        logger.info("✅ SurveillanceQA ready.")

    @staticmethod
    def _resolve_device_id(device) -> int:
        """
        Map a device spec to the integer index passed to the pipeline.

        ``"cuda"`` maps to 0, ``"cuda:N"`` maps to N, anything else to -1.
        The previous exact comparison against ``"cuda"`` sent indexed
        devices such as ``"cuda:0"`` to -1.
        """
        name = str(device).strip().lower()
        if not name.startswith("cuda"):
            return -1
        _, _, index = name.partition(":")
        return int(index) if index.isdigit() else 0

    @classmethod
    def _describe_attributes(cls, attrs: Optional[Dict]) -> str:
        """
        Render a person's attributes as ``"(gender, color clothing, labels)"``.

        Missing or empty fields are skipped so the result never contains
        dangling commas; returns ``""`` when nothing is known.
        """
        if not attrs:
            return ""

        parts = []
        gender = attrs.get("gender", "")
        if gender:
            parts.append(f"{gender}")
        color = attrs.get("color", "")
        if color:
            parts.append(f"{color} clothing")

        # ``clothing`` may be absent or explicitly None; treat both as empty.
        clothing_items = (attrs.get("clothing") or [])[: cls.MAX_CLOTHING_LABELS]
        labels = [
            item.get("label", "") if isinstance(item, dict) else item
            for item in clothing_items
        ]
        clothing = ", ".join(str(label) for label in labels if label)
        if clothing:
            parts.append(clothing)

        return f"({', '.join(parts)})" if parts else ""

    def _build_context(self, events: List[Dict]) -> str:
        """
        Build a natural language context string from event records.

        Each event becomes one sentence; sentences are joined with a single
        space. Missing keys fall back to placeholder text
        (``"unknown time"``, ``"unknown camera"``, ``"unknown"``, ``"detected"``).
        """
        lines = []
        for e in events:
            ts = e.get("timestamp", "unknown time")
            cam = e.get("camera_id", "unknown camera")
            activity = e.get("activity_type", "detected")
            person_id = str(e.get("person_id", "unknown"))[: self.PERSON_ID_CHARS]
            desc = e.get("description", "")

            subject = f"person {person_id}"
            attr_str = self._describe_attributes(e.get("attributes"))
            if attr_str:
                # Only add the separator when there is something to separate,
                # otherwise the sentence would contain a double space.
                subject += f" {attr_str}"

            line = f"At {ts}, camera {cam} detected {subject} with activity: {activity}."
            if desc:
                line += f" {desc}"
            lines.append(line)
        return " ".join(lines)

    @classmethod
    def _preview(cls, context: str) -> str:
        """Return the context itself, or its head plus ``"..."`` when long."""
        if len(context) > cls.CONTEXT_PREVIEW_CHARS:
            return context[: cls.CONTEXT_PREVIEW_CHARS] + "..."
        return context

    @staticmethod
    def _no_answer(message: str) -> Dict:
        """Build a response with the full key set for cases where the model is not run."""
        return {
            "answer": message,
            "score": 0.0,
            "start": 0,
            "end": 0,
            "context_used": "",
            "latency_ms": 0.0,
            "all_answers": [],
        }

    @classmethod
    def _package_result(cls, result, context: str, latency_ms: float) -> Dict:
        """
        Convert raw pipeline output into the response dict.

        ``result`` may be a single answer dict or a list of answer dicts;
        the first candidate is reported as the best one. An empty list
        yields an empty answer with score 0.0 instead of raising.
        """
        candidates = result if isinstance(result, list) else [result]
        best = candidates[0] if candidates else {}
        return {
            "answer": best.get("answer", ""),
            "score": round(float(best.get("score", 0.0)), 4),
            "start": best.get("start", 0),
            "end": best.get("end", 0),
            "context_used": cls._preview(context),
            "latency_ms": round(latency_ms, 2),
            "all_answers": candidates,
        }

    def answer(
        self,
        question: str,
        events: Optional[List[Dict]] = None,
        context: Optional[str] = None,
        top_k: int = 3,
    ) -> Dict:
        """
        Answer a natural language question about surveillance data.

        Args:
            question: User's question
            events: List of event dicts (auto-builds context); ignored when
                ``context`` is given
            context: Pre-built context string
            top_k: Number of answer candidates

        Returns:
            {"answer": str, "score": float, "start": int, "end": int,
             "context_used": str, "latency_ms": float, "all_answers": list}

            The same keys are present when no answer can be produced
            (no data, blank context, blank question); in that case
            ``answer`` carries an explanatory message and ``score`` is 0.0.
        """
        if context is None:
            if not events:
                return self._no_answer("No surveillance data available to answer from.")
            context = self._build_context(events)

        if not context.strip():
            return self._no_answer("No context available.")

        # Guard before calling the model so a blank question yields a
        # regular response rather than an exception from the pipeline.
        if not question or not question.strip():
            return self._no_answer("No question provided.")

        # Keep only the head of the context (see MAX_CONTEXT_CHARS).
        context = context[: self.MAX_CONTEXT_CHARS]

        t0 = time.perf_counter()
        result = self.qa_pipeline(
            question=question,
            context=context,
            top_k=top_k,
            handle_impossible_answer=True,
        )
        latency_ms = (time.perf_counter() - t0) * 1000

        response = self._package_result(result, context, latency_ms)
        logger.debug(f"QA answered '{question[:50]}' in {latency_ms:.1f}ms | score={response['score']:.3f}")
        return response