File size: 1,408 Bytes
af83196
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
# EVOLVE-BLOCK-START
"""
Initial Grayscale submission with Triton kernel.
Y = 0.2989 R + 0.5870 G + 0.1140 B
"""

import torch
import triton
import triton.language as tl


@triton.jit
def grayscale_kernel(
    rgb_ptr, out_ptr,
    H, W,
    stride_h, stride_w, stride_c,
    BLOCK_SIZE: tl.constexpr,
):
    """Write the weighted sum 0.2989 R + 0.5870 G + 0.1140 B for each pixel.

    Each program instance covers BLOCK_SIZE consecutive flat pixel indices
    in the range [0, H * W); indices past the end are masked off.

    Args:
        rgb_ptr: Base pointer of the input image. Channel ``c`` of pixel
            ``(h, w)`` is read from
            ``rgb_ptr + h * stride_h + w * stride_w + c * stride_c``.
        out_ptr: Base pointer of the destination. The value for pixel
            ``(h, w)`` is stored at ``out_ptr + h * W + w``, so the
            destination is addressed as a dense row-major H x W buffer.
        H: Number of image rows.
        W: Number of image columns.
        stride_h: Input stride between rows, in elements.
        stride_w: Input stride between columns, in elements.
        stride_c: Input stride between channels, in elements.
        BLOCK_SIZE: Compile-time number of pixels handled per program.
    """
    # Flat pixel indices owned by this program instance.
    first_pixel = tl.program_id(0) * BLOCK_SIZE
    pixel_ids = first_pixel + tl.arange(0, BLOCK_SIZE)
    # The last block may run past the image; those lanes do no memory traffic.
    in_bounds = pixel_ids < H * W

    # Recover 2-D coordinates from the flat index.
    row = pixel_ids // W
    col = pixel_ids % W

    # Address of channel 0 for every pixel; the channels sit stride_c apart.
    pixel_ptr = rgb_ptr + row * stride_h + col * stride_w
    red = tl.load(pixel_ptr + 0 * stride_c, mask=in_bounds)
    green = tl.load(pixel_ptr + 1 * stride_c, mask=in_bounds)
    blue = tl.load(pixel_ptr + 2 * stride_c, mask=in_bounds)

    luma = 0.2989 * red + 0.5870 * green + 0.1140 * blue

    tl.store(out_ptr + (row * W + col), luma, mask=in_bounds)


def custom_kernel(data):
    """Convert an RGB image to grayscale by launching ``grayscale_kernel``.

    Args:
        data: A ``(rgb, output)`` pair. ``rgb`` is unpacked as a 3-D tensor
            of shape ``(H, W, 3)``. ``output`` is the preallocated
            destination that receives one value per pixel (presumably shaped
            ``(H, W)`` and on the same device as ``rgb`` -- confirm against
            the caller).

    Returns:
        The same ``output`` tensor that was passed in, filled in place.

    Raises:
        AssertionError: If the last dimension of ``rgb`` is not 3.
    """
    rgb, output = data
    H, W, C = rgb.shape
    assert C == 3

    n_pixels = H * W
    if n_pixels == 0:
        # Empty image: there is nothing to write, so skip the launch.
        return output

    # Read the strides after .contiguous() so they describe the buffer the
    # kernel is actually given.
    rgb = rgb.contiguous()
    stride_h, stride_w, stride_c = rgb.stride()

    # The kernel stores pixel (h, w) at out_ptr + h * W + w, i.e. it assumes
    # a dense row-major destination. .contiguous() returns `output` itself
    # when that already holds; otherwise it yields a dense copy that we write
    # into and copy back afterwards, so strided views get correct results.
    dense_out = output.contiguous()

    BLOCK_SIZE = 1024
    grid = (triton.cdiv(n_pixels, BLOCK_SIZE),)
    grayscale_kernel[grid](
        rgb, dense_out, H, W,
        stride_h, stride_w, stride_c,
        BLOCK_SIZE=BLOCK_SIZE,
    )

    if dense_out is not output:
        output.copy_(dense_out)
    return output
# EVOLVE-BLOCK-END