File size: 3,860 Bytes
ce4fddb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
"""Image preprocessing utilities for medical image analysis."""

from __future__ import annotations
import numpy as np
import cv2
from PIL import Image
import io
import torch


def load_image_from_bytes(file_bytes: bytes) -> np.ndarray:
    """Decode raw upload bytes into an 8-bit, 3-channel BGR image.

    Formats OpenCV can decode are tried first. Anything OpenCV rejects is
    handed to :func:`load_dicom_from_bytes` as a DICOM fallback, which uses
    pydicom.

    Args:
        file_bytes: Encoded image or DICOM file contents.

    Returns:
        BGR ``uint8`` image array.

    Raises:
        ValueError: If ``file_bytes`` is empty, or if it can be decoded
            neither by OpenCV nor as DICOM.
    """
    if not file_bytes:
        # cv2.imdecode asserts on an empty buffer and raises cv2.error, which
        # would escape the ValueError the DICOM fallback uses for bad input.
        raise ValueError("Could not decode the uploaded image or DICOM file.")

    buffer = np.frombuffer(file_bytes, np.uint8)
    img = cv2.imdecode(buffer, cv2.IMREAD_COLOR)
    if img is not None:
        return img

    # OpenCV could not decode it; try DICOM (raises ValueError on failure).
    return load_dicom_from_bytes(file_bytes)


def load_dicom_from_bytes(file_bytes: bytes) -> np.ndarray:
    """Decode DICOM pixel data into an 8-bit BGR image.

    Processing steps:
    - Apply the linear ``RescaleSlope`` / ``RescaleIntercept`` transform.
    - Invert ``MONOCHROME1`` data so that higher values are brighter.
    - Min-max stretch the result to 0..255.

    Multi-frame datasets are reduced to their first frame. Three-sample
    (colour) data is stretched across all channels jointly and reordered
    from RGB to BGR.

    Args:
        file_bytes: Raw DICOM file contents.

    Returns:
        ``uint8`` array of shape ``(rows, cols, 3)`` in BGR channel order.

    Raises:
        ValueError: If pydicom is not installed, the bytes cannot be decoded,
            or the pixel data has a layout this function does not handle.
    """
    try:
        import pydicom
    except ImportError as exc:
        raise ValueError("DICOM support requires pydicom to be installed.") from exc

    try:
        ds = pydicom.dcmread(io.BytesIO(file_bytes), force=True)
        arr = ds.pixel_array.astype(np.float32)
    except Exception as exc:
        raise ValueError("Could not decode the uploaded image or DICOM file.") from exc

    samples = int(getattr(ds, "SamplesPerPixel", 1) or 1)
    if samples not in (1, 3):
        raise ValueError("Could not decode the uploaded image or DICOM file.")

    # A single frame is (rows, cols) for grayscale or (rows, cols, 3) for
    # colour; multi-frame data carries an extra leading frame axis.
    frame_ndim = 2 if samples == 1 else 3
    if arr.ndim == frame_ndim + 1:
        arr = arr[0]
    if arr.ndim != frame_ndim:
        raise ValueError("Could not decode the uploaded image or DICOM file.")

    slope = float(getattr(ds, "RescaleSlope", 1.0))
    intercept = float(getattr(ds, "RescaleIntercept", 0.0))
    arr = arr * slope + intercept

    photo = str(getattr(ds, "PhotometricInterpretation", "")).upper()
    if photo == "MONOCHROME1":
        # MONOCHROME1 stores low values as bright; flip so bright = high.
        arr = arr.max() - arr

    arr = arr - arr.min()
    max_val = arr.max()
    if max_val > 0:
        arr = arr / max_val
    img8 = (arr * 255).clip(0, 255).astype(np.uint8)

    if samples == 3:
        # NOTE(review): assumes pydicom delivered RGB samples; YBR_* data it
        # has not already converted would need a colour-space conversion.
        return cv2.cvtColor(img8, cv2.COLOR_RGB2BGR)
    return cv2.cvtColor(img8, cv2.COLOR_GRAY2BGR)


def preprocess_for_chest(file_bytes: bytes, *, size: int = 224) -> np.ndarray:
    """Preprocess an image for TorchXRayVision DenseNet121.

    Processing steps:
    1. Decode the image and convert it to grayscale.
    2. Apply CLAHE contrast enhancement.
    3. Centre-crop to a square.
    4. Resize to ``size`` x ``size``.
    5. Rescale intensities from 0..255 to the [-1024, 1024] range that
       TorchXRayVision expects.

    The centre crop mirrors TorchXRayVision's ``XRayCenterCrop``. Resizing a
    non-square radiograph straight to a square would distort the anatomy.

    Args:
        file_bytes: Encoded image or DICOM file contents.
        size: Output height and width in pixels (224 for DenseNet121).

    Returns:
        ``float32`` array of shape ``(1, 1, size, size)`` ready for model input.
    """
    img = load_image_from_bytes(file_bytes)

    # Convert to grayscale
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

    # Apply CLAHE for contrast enhancement (important for X-rays)
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    enhanced = clahe.apply(gray)

    # Centre-crop to the largest square, matching XRayCenterCrop.
    height, width = enhanced.shape
    side = min(height, width)
    top = height // 2 - side // 2
    left = width // 2 - side // 2
    square = enhanced[top:top + side, left:left + side]

    resized = cv2.resize(square, (size, size), interpolation=cv2.INTER_AREA)

    # Map [0, 255] -> [-1024, 1024]; equivalent to xrv.datasets.normalize(img, 255).
    normalized = (resized.astype(np.float32) / 255.0) * 2048.0 - 1024.0

    # Add batch and channel dimensions: (1, 1, size, size)
    return normalized[np.newaxis, np.newaxis, :, :]


def preprocess_for_yolo(file_bytes: bytes) -> np.ndarray:
    """Contrast-enhance a decoded image for YOLOv8 inference.

    CLAHE is applied to the lightness (L) channel of the LAB representation.
    This boosts local contrast while leaving the colour channels (a, b)
    untouched. The image keeps its original size because YOLO resizes
    internally.

    Returns:
        BGR ``numpy`` array with the same shape as the decoded input.
    """
    bgr = load_image_from_bytes(file_bytes)
    equalizer = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))

    # Equalize only the L (lightness) channel, then rebuild the LAB image.
    lightness, green_red, blue_yellow = cv2.split(cv2.cvtColor(bgr, cv2.COLOR_BGR2LAB))
    lab_equalized = cv2.merge([equalizer.apply(lightness), green_red, blue_yellow])

    return cv2.cvtColor(lab_equalized, cv2.COLOR_LAB2BGR)


def preprocess_for_vit(file_bytes: bytes) -> Image.Image:
    """Preprocess an image for ViT classification.

    The bytes are decoded with Pillow and rotated or flipped according to
    any EXIF ``Orientation`` tag. This keeps the ViT input consistent with
    the OpenCV-based loaders, since ``cv2.imdecode`` honours EXIF orientation
    by default. If Pillow cannot open the bytes, they are decoded as DICOM
    instead.

    Returns a PIL Image (RGB) — the ViT processor handles resizing/normalization.

    Raises:
        ValueError: If the bytes are neither a Pillow-readable image nor
            decodable DICOM (raised by ``load_dicom_from_bytes``).
    """
    from PIL import ImageOps

    try:
        with Image.open(io.BytesIO(file_bytes)) as opened:
            # exif_transpose returns a new, fully loaded image, so the result
            # remains valid after the source image is closed.
            upright = ImageOps.exif_transpose(opened)
            return upright.convert("RGB")
    except Exception:
        # Deliberately broad: any Pillow failure means "not an ordinary
        # image", so fall back to DICOM, which raises ValueError itself.
        bgr = load_dicom_from_bytes(file_bytes)
        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        return Image.fromarray(rgb)


def image_to_base64(file_bytes: bytes) -> str:
    """Return *file_bytes* as standard base64 text for multimodal API payloads.

    The bytes are encoded verbatim. The image is never decoded or
    re-encoded.
    """
    from base64 import b64encode

    encoded = b64encode(file_bytes)
    return encoded.decode("utf-8")