File size: 17,724 Bytes
5196d55
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
"""
Popescu–Farid CFA Consistency Analyzer for the agent.

This tool analyzes Color Filter Array (CFA) demosaicing artifacts to detect
inconsistencies within an image. It is designed for SPLICE DETECTION and
SOURCE CONSISTENCY analysis, NOT for whole-image authenticity classification.

Scientific basis:
- Real camera images have CFA interpolation artifacts from Bayer demosaicing
- Spliced regions from different sources (AI, screenshots, different cameras)
  may have different or absent CFA patterns
- By analyzing the DISTRIBUTION of CFA metrics across windows, we can identify
  regions that are inconsistent with the rest of the image

What this tool DOES:
- Detects CFA pattern consistency across image regions
- Identifies outlier windows that differ from the image baseline
- Provides distribution analysis (unimodal vs bimodal)

What this tool does NOT do:
- Classify whole images as "authentic" or "fake"
- Work reliably on heavily compressed images
- Detect AI-generated images (use TruFor for that)

Supports two modes:
- analyze: run CFA consistency analysis on a single image
- calibrate: optional; build reference thresholds from a set of camera images
"""

from __future__ import annotations

import json
import sys
from pathlib import Path
from typing import Any, Dict, List, Sequence, Tuple

import numpy as np

# Ensure repo root (which contains example_tools) is on sys.path so we can load
# example_tools/cfa.py as a namespace package.
# NOTE(review): parents[0] is this file's own directory, so parents[3] assumes the
# repository root is three directories above it; revisit if the file is moved.
ROOT = Path(__file__).resolve().parents[3]
if str(ROOT) not in sys.path:
    # Appended rather than prepended, so entries already on sys.path keep precedence.
    sys.path.append(str(ROOT))

# This import must run after the sys.path adjustment above. Any failure is
# re-raised as ImportError with the original exception chained as the cause.
try:
    from example_tools import cfa  # type: ignore
except Exception as exc:  # pragma: no cover - defensive import guard
    raise ImportError(
        "Unable to import example_tools.cfa. Ensure repository root is on sys.path."
    ) from exc


# Defaults applied when a request omits the corresponding key.
DEFAULT_PATTERN = "RGGB"  # Bayer pattern string passed to cfa.analyze_image_windows
DEFAULT_WINDOW = 256  # `window=` argument for cfa.analyze_image_windows (presumably pixels; confirm in cfa)
DEFAULT_TOP_K = 5  # number of strongest/weakest windows reported by analyze mode
DEFAULT_OUTLIER_ZSCORE = 2.0  # Windows beyond this z-score are outliers


def _parse_request(input_str: str) -> Dict[str, Any]:
    """Parse JSON or treat input_str as image_path for analyze mode."""
    try:
        data = json.loads(input_str)
        if isinstance(data, dict):
            return data
        if isinstance(data, str):
            return {"mode": "analyze", "image_path": data}
    except Exception:
        pass
    return {"mode": "analyze", "image_path": input_str}


def _compute_stats(values: Sequence[float]) -> Dict[str, float]:
    """Compute basic statistics for a list of values."""
    arr = np.asarray(values, dtype=np.float64)
    if arr.size == 0:
        return {"min": 0.0, "max": 0.0, "mean": 0.0, "median": 0.0, "std": 0.0}
    return {
        "min": float(np.min(arr)),
        "max": float(np.max(arr)),
        "mean": float(np.mean(arr)),
        "median": float(np.median(arr)),
        "std": float(np.std(arr)),
    }


def _detect_bimodality(values: Sequence[float]) -> Dict[str, Any]:
    """
    Detect if the distribution of values is bimodal using Hartigan's dip test
    approximation and coefficient of bimodality.

    Returns:
        Dictionary with bimodality analysis results
    """
    arr = np.asarray(values, dtype=np.float64)
    if arr.size < 10:
        return {
            "is_bimodal": False,
            "bimodality_coefficient": 0.0,
            "distribution_type": "insufficient_data",
            "note": "Need at least 10 windows for distribution analysis",
        }

    # Compute bimodality coefficient: BC = (skewness^2 + 1) / kurtosis
    # BC > 0.555 suggests bimodality (Pfister et al., 2013)
    mean = np.mean(arr)
    std = np.std(arr)
    if std < 1e-10:
        return {
            "is_bimodal": False,
            "bimodality_coefficient": 0.0,
            "distribution_type": "constant",
            "note": "All values are nearly identical",
        }

    normalized = (arr - mean) / std
    skewness = float(np.mean(normalized ** 3))
    kurtosis = float(np.mean(normalized ** 4))

    # Excess kurtosis adjustment (Fisher's definition)
    excess_kurtosis = kurtosis - 3.0

    # Bimodality coefficient
    # For a uniform distribution: BC ≈ 0.555
    # For a bimodal distribution: BC > 0.555
    bc = (skewness ** 2 + 1) / (kurtosis + 3 * ((arr.size - 1) ** 2) / ((arr.size - 2) * (arr.size - 3)))
    bc = float(bc)

    # Also check coefficient of variation (CV) - high CV suggests mixed sources
    cv = std / mean if mean > 0 else 0.0

    # Determine distribution type
    if bc > 0.6:
        dist_type = "bimodal"
        is_bimodal = True
    elif bc > 0.5:
        dist_type = "possibly_bimodal"
        is_bimodal = False
    elif cv > 0.3:
        dist_type = "high_variance"
        is_bimodal = False
    else:
        dist_type = "unimodal"
        is_bimodal = False

    return {
        "is_bimodal": is_bimodal,
        "bimodality_coefficient": bc,
        "coefficient_of_variation": float(cv),
        "skewness": skewness,
        "excess_kurtosis": excess_kurtosis,
        "distribution_type": dist_type,
    }


def _find_outliers(
    values: Sequence[float],
    positions: Sequence[Tuple[int, int, int, int]],
    z_threshold: float = DEFAULT_OUTLIER_ZSCORE,
) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
    """
    Find outlier windows using a robust, median-centred z-score.

    The spread is the median absolute deviation (MAD) scaled by 1.4826. That
    scale factor makes it match the standard deviation for normally
    distributed data. If more than half of the values equal the median,
    MAD is 0. That happens, for example, when most windows are flat with
    M ~ 0. In that case the mean absolute deviation from the median, scaled
    by 1.2533, is used instead. Without this fallback, every window that
    differs from the median would get an enormous z-score and be flagged.
    If both spreads are zero, all values are identical and nothing is an
    outlier.

    Args:
        values: M values for each window.
        positions: (y, x, h, w) for each window, parallel to ``values``.
        z_threshold: Absolute z-score beyond which a window is an outlier.

    Returns:
        Tuple of (low_outliers, high_outliers). These are the windows with
        unusually low or high M values, each sorted most-anomalous first.
        Both lists are empty when fewer than 5 values are given.
    """
    arr = np.asarray(values, dtype=np.float64)
    if arr.size < 5:
        return [], []

    median = float(np.median(arr))
    abs_dev = np.abs(arr - median)
    mad = float(np.median(abs_dev))
    if mad > 0:
        # 1.4826 makes MAD a consistent estimator of sigma for normal data.
        scale = mad * 1.4826
    else:
        # Over half of the values sit exactly on the median. Fall back to the
        # mean absolute deviation, scaled by sqrt(pi/2) ~ 1.2533 for the
        # same normal-consistency reason.
        mean_ad = float(np.mean(abs_dev))
        if not mean_ad > 0:  # also catches NaN input
            return [], []
        scale = mean_ad * 1.2533

    def _record(pos: Tuple[int, int, int, int], val: float, z: float, interp: str) -> Dict[str, Any]:
        # Shared layout for both outlier lists.
        return {
            "y": pos[0],
            "x": pos[1],
            "h": pos[2],
            "w": pos[3],
            "M_value": val,
            "z_score": z,
            "interpretation": interp,
        }

    low_outliers: List[Dict[str, Any]] = []
    high_outliers: List[Dict[str, Any]] = []

    for val, pos in zip(arr.tolist(), positions):
        z_score = (val - median) / scale
        if z_score < -z_threshold:
            # Distinguish between zero M (flat region) and low M (weak CFA)
            if val < 1e-6:
                interp = "Zero CFA signal - likely flat/uniform region (sky, wall) or synthetic"
            else:
                interp = "Weak CFA signal - possible splice, synthetic region, or heavy processing"
            low_outliers.append(_record(pos, val, z_score, interp))
        elif z_score > z_threshold:
            high_outliers.append(
                _record(pos, val, z_score, "Unusually strong CFA - possible different camera source")
            )

    # Most anomalous first: most negative low z-scores, largest high z-scores.
    low_outliers.sort(key=lambda x: x["z_score"])
    high_outliers.sort(key=lambda x: -x["z_score"])

    return low_outliers, high_outliers


def _classify_window_populations(
    values: Sequence[float],
) -> Dict[str, Any]:
    """
    Classify windows into populations based on M value magnitude.

    Real camera images typically show:
    - Low M (~0): Flat/uniform regions (sky, walls) - no texture to detect CFA
    - High M (>1e9): Textured regions with strong CFA signal

    This is content-dependent, not evidence of manipulation.
    Manipulation would show as textured regions WITHOUT CFA signal.
    """
    arr = np.asarray(values, dtype=np.float64)
    if arr.size == 0:
        return {"flat_regions": 0, "textured_regions": 0, "intermediate": 0}

    # Thresholds based on typical M value ranges
    # These are heuristic but based on observed values
    flat_threshold = 1e6  # Below this = flat region (no texture)
    textured_threshold = 1e9  # Above this = strong CFA signal

    flat_count = int(np.sum(arr < flat_threshold))
    textured_count = int(np.sum(arr >= textured_threshold))
    intermediate_count = len(arr) - flat_count - textured_count

    return {
        "flat_regions": flat_count,
        "textured_regions": textured_count,
        "intermediate": intermediate_count,
        "flat_pct": flat_count / len(arr) * 100,
        "textured_pct": textured_count / len(arr) * 100,
    }


def _window_brief(entry: Dict[str, Any], channel: str = "G") -> Dict[str, Any]:
    """Extract brief window info for a specific channel."""
    return {
        "y": entry["y"],
        "x": entry["x"],
        "h": entry["h"],
        "w": entry["w"],
        "M_value": float(entry[channel]["M"]),
    }


def _analyze(params: Dict[str, Any]) -> Dict[str, Any]:
    """Run CFA consistency analysis on a single image."""
    image_path = params.get("image_path")
    if not image_path:
        return {"error": "image_path is required for analyze mode."}

    window = int(params.get("window", DEFAULT_WINDOW))
    pattern = params.get("pattern", DEFAULT_PATTERN)
    em_kwargs = params.get("em") or params.get("em_kwargs") or {}
    top_k = int(params.get("top_k", DEFAULT_TOP_K))
    channel = params.get("channel", "G").upper()  # Green channel is most reliable

    if channel not in ("R", "G", "B"):
        channel = "G"

    try:
        img = cfa.load_rgb_image(str(image_path))
    except Exception as e:
        return {"error": f"Failed to load image: {e}"}

    try:
        window_results = cfa.analyze_image_windows(
            img, window=window, pattern=pattern, em_kwargs=em_kwargs
        )
    except Exception as e:
        return {"error": f"CFA analysis failed: {e}"}

    if not window_results:
        return {"error": "No windows analyzed (image may be too small)."}

    # Extract M values and positions for the selected channel
    m_values = [float(r[channel]["M"]) for r in window_results]
    positions = [(r["y"], r["x"], r["h"], r["w"]) for r in window_results]

    # Compute statistics
    stats = _compute_stats(m_values)

    # Analyze distribution
    bimodality = _detect_bimodality(m_values)

    # Classify windows into populations (flat vs textured)
    populations = _classify_window_populations(m_values)

    # Get top windows by M value (strongest CFA signal)
    sorted_indices = np.argsort(m_values)[::-1]
    top_windows = [
        _window_brief(window_results[i], channel)
        for i in sorted_indices[:top_k]
    ]

    # Get bottom windows by M value (weakest CFA signal)
    bottom_windows = [
        _window_brief(window_results[i], channel)
        for i in sorted_indices[-top_k:][::-1]
    ]

    # Determine if image has CFA signal at all
    has_cfa_signal = populations["textured_regions"] > 0
    textured_pct = populations["textured_pct"]

    # Generate interpretation based on content analysis
    if not has_cfa_signal:
        interpretation = (
            "No strong CFA signal detected in any region. "
            "This could indicate: (1) AI-generated image, (2) heavily processed image, "
            "(3) screenshot, or (4) image with only flat/uniform content."
        )
    elif textured_pct > 50:
        interpretation = (
            f"Strong CFA signal detected in {textured_pct:.0f}% of windows. "
            "Consistent with camera-captured image. Flat regions (sky, walls) "
            "naturally show weaker CFA signal due to lack of texture."
        )
    elif textured_pct > 20:
        interpretation = (
            f"CFA signal detected in {textured_pct:.0f}% of windows (textured regions). "
            "Remaining windows are flat/uniform regions where CFA cannot be detected. "
            "This distribution is normal for photos with sky or uniform backgrounds."
        )
    else:
        interpretation = (
            f"Weak CFA signal - only {textured_pct:.0f}% of windows show strong CFA. "
            "Image may be heavily processed, low-texture, or partially synthetic."
        )

    # Build result
    result: Dict[str, Any] = {
        "tool": "perform_cfa_detection",
        "status": "completed",
        "image_path": str(image_path),
        "analysis_channel": channel,
        "window_size": window,
        "window_count": len(window_results),
        "pattern": pattern,

        # Main output: population analysis
        "has_cfa_signal": has_cfa_signal,
        "interpretation": interpretation,

        # Window populations
        "window_populations": {
            "textured_with_cfa": populations["textured_regions"],
            "flat_no_texture": populations["flat_regions"],
            "intermediate": populations["intermediate"],
            "textured_pct": populations["textured_pct"],
            "flat_pct": populations["flat_pct"],
        },

        # Distribution analysis (for advanced users)
        "distribution": {
            "type": bimodality["distribution_type"],
            "is_bimodal": bimodality["is_bimodal"],
            "bimodality_coefficient": bimodality["bimodality_coefficient"],
            "note": "Bimodal distribution is NORMAL for photos with mixed content (sky + texture)",
        },

        # Statistics
        "m_value_stats": stats,

        # Reference windows
        "strongest_cfa_windows": top_windows,
        "weakest_cfa_windows": bottom_windows,

        "note": (
            "CFA analysis detects demosaicing artifacts from camera sensors. "
            "Flat regions (sky, walls) naturally have weak/no CFA signal. "
            "Look for TEXTURED regions with weak CFA - those may be spliced. "
            "This tool complements TruFor for localization, not whole-image classification."
        ),
    }

    return result


def _calibrate(params: Dict[str, Any]) -> Dict[str, Any]:
    """Calibrate reference statistics from a set of camera images."""
    neg_dir = params.get("neg_dir") or params.get("ref_dir")
    if not neg_dir:
        return {"error": "neg_dir (or ref_dir) is required for calibrate mode."}

    window = int(params.get("window", DEFAULT_WINDOW))
    pattern = params.get("pattern", DEFAULT_PATTERN)
    em_kwargs = params.get("em") or params.get("em_kwargs") or {}
    save_to = params.get("save_to") or params.get("output")

    neg_files = cfa.list_image_files(str(neg_dir))
    if not neg_files:
        return {"error": f"No images found in directory: {neg_dir}"}

    # Collect M values from all reference images
    all_m_values: Dict[str, List[float]] = {"R": [], "G": [], "B": []}

    for path in neg_files:
        try:
            img = cfa.load_rgb_image(str(path))
            window_results = cfa.analyze_image_windows(
                img, window=window, pattern=pattern, em_kwargs=em_kwargs
            )
            for r in window_results:
                all_m_values["R"].append(float(r["R"]["M"]))
                all_m_values["G"].append(float(r["G"]["M"]))
                all_m_values["B"].append(float(r["B"]["M"]))
        except Exception:
            continue  # Skip problematic images

    if not all_m_values["G"]:
        return {"error": "No valid windows collected from reference images."}

    # Compute reference statistics for each channel
    reference_stats = {
        c: _compute_stats(vals) for c, vals in all_m_values.items()
    }

    payload = {
        "reference_stats": reference_stats,
        "pattern": pattern,
        "window": window,
        "em_params": em_kwargs,
        "num_images": len(neg_files),
        "num_windows": len(all_m_values["G"]),
    }

    if save_to:
        Path(save_to).write_text(json.dumps(payload, indent=2), encoding="utf-8")
        payload["saved_to"] = str(save_to)

    return payload


def perform_cfa_detection(input_str: str) -> str:
    """
    LangChain tool entrypoint for CFA consistency analysis.

    This tool analyzes CFA (Color Filter Array) demosaicing patterns to detect
    INCONSISTENCIES within an image. It is designed for splice detection and
    source consistency analysis.

    Input (JSON):
      - mode: "analyze" (default) or "calibrate"
      - image_path: required for analyze
      - window: int (default 256)
      - pattern: Bayer pattern (default RGGB)
      - channel: which channel to analyze (default "G" - green is most reliable)
      - em / em_kwargs: dict for EM params (N, sigma0, p0, max_iter, tol, seed)
      - top_k: int (default 5) - number of top/outlier windows to return
      - outlier_zscore: float (default 2.0) - z-score threshold for outlier detection
      - neg_dir/ref_dir: required for calibrate mode
      - save_to: optional path to write reference stats JSON (calibrate)

    Output:
      - cfa_consistency_score: 0-1 score (higher = more consistent)
      - distribution: analysis of M value distribution (unimodal/bimodal)
      - outliers: windows with unusually low/high CFA patterns
      - interpretation: human-readable summary
    """
    params = _parse_request(input_str)
    mode = params.get("mode", "analyze").lower()

    # Support legacy "detect" mode name
    if mode == "detect":
        mode = "analyze"

    if mode == "calibrate":
        result = _calibrate(params)
    elif mode == "analyze":
        result = _analyze(params)
    else:
        result = {"error": "mode must be 'analyze' or 'calibrate'."}

    try:
        return json.dumps(result, indent=2)
    except Exception:
        # Fallback in case something is not JSON-serializable
        return json.dumps({"error": "Failed to serialize result."}, indent=2)


# Public API: only the LangChain tool entrypoint; the _-prefixed helpers stay private.
__all__ = ["perform_cfa_detection"]