File size: 3,233 Bytes
2f9e969
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
"""

processor.py

Converts uploaded files into PIL Images for the model.



preprocess_image(file)  →  PIL.Image (single image)

extract_frames(file)    →  list[PIL.Image]  (up to MAX_FRAMES from a video)



Both functions are async — called directly from FastAPI route handlers.

"""

import os
import tempfile

import cv2
import numpy as np
from fastapi import HTTPException, UploadFile
from PIL import Image

MAX_FRAMES = 20  # frames sampled per video — more = better accuracy, slower

# File extensions accepted by the two entry points (lower-case, leading dot).
ALLOWED_IMAGE = {".jpg", ".jpeg", ".png", ".webp", ".bmp"}
ALLOWED_VIDEO = {".mp4", ".avi", ".mov", ".mkv", ".webm"}


def _ext(filename: str) -> str:
    return os.path.splitext(filename)[-1].lower()


def _check_ext(filename: str, allowed: set):
    ext = _ext(filename)
    if ext not in allowed:
        raise HTTPException(
            status_code=400,
            detail=f"Unsupported file type '{ext}'. Allowed: {', '.join(sorted(allowed))}",
        )


def _bgr_to_pil(bgr: np.ndarray) -> Image.Image:
    """Convert an OpenCV BGR frame to an RGB PIL Image."""
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    return Image.fromarray(rgb)


async def preprocess_image(file: UploadFile) -> Image.Image:
    """Decode an uploaded image and return it as an RGB PIL Image.

    All resizing/normalisation is left to the HuggingFace ViTImageProcessor.

    Raises:
        HTTPException(400): unsupported extension or undecodable bytes.
    """
    _check_ext(file.filename, ALLOWED_IMAGE)

    data = await file.read()
    decoded = cv2.imdecode(np.frombuffer(data, np.uint8), cv2.IMREAD_COLOR)

    if decoded is None:
        raise HTTPException(status_code=400, detail="Cannot decode image — file may be corrupted.")

    return _bgr_to_pil(decoded)


async def extract_frames(file: UploadFile) -> list:
    """Sample up to MAX_FRAMES evenly-spaced frames from an uploaded video.

    OpenCV needs a real file path for video decoding, so the upload is
    spooled to a temp file which is always deleted afterwards.

    Returns:
        list of RGB PIL Images (at least one).

    Raises:
        HTTPException(400): unsupported extension, unopenable video,
            empty video, or no frame could be decoded.
    """
    _check_ext(file.filename, ALLOWED_VIDEO)

    raw    = await file.read()
    suffix = _ext(file.filename)

    # Write to temp file — OpenCV cannot decode video from a memory buffer
    tmp = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)
    frames = []
    try:
        tmp.write(raw)
        tmp.close()

        cap = cv2.VideoCapture(tmp.name)
        try:
            if not cap.isOpened():
                raise HTTPException(status_code=400, detail="Cannot open video file.")

            # Some containers report a negative frame count, hence <= 0.
            total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            if total <= 0:
                raise HTTPException(status_code=400, detail="Video has no frames.")

            # Evenly-spaced indices across the whole video
            n_samples = min(MAX_FRAMES, total)
            indices   = np.linspace(0, total - 1, n_samples, dtype=int)

            for idx in indices:
                cap.set(cv2.CAP_PROP_POS_FRAMES, int(idx))
                ok, frame = cap.read()
                if ok:  # frame count can overestimate; skip unreadable indices
                    frames.append(_bgr_to_pil(frame))
        finally:
            # Release on EVERY path — previously skipped when an error was
            # raised, leaking the handle and breaking unlink on Windows.
            cap.release()

    finally:
        tmp.close()   # no-op if already closed; covers a failed write()
        os.unlink(tmp.name)  # always clean up

    if not frames:
        raise HTTPException(status_code=400, detail="Failed to extract frames from video.")

    return frames