File size: 1,214 Bytes
712dbf0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
// Copyright © 2024 Apple Inc.
#pragma once

#include <optional>

#include <nanobind/nanobind.h>
#include <nanobind/ndarray.h>

#include "mlx/array.h"
#include "mlx/ops.h"

namespace mx = mlx::core;
namespace nb = nanobind;

namespace nanobind {
static constexpr dlpack::dtype bfloat16{4, 16, 1};
}; // namespace nanobind

struct ArrayLike {
  ArrayLike(nb::object obj) : obj(obj) {};
  nb::object obj;
};

using ArrayInitType = std::variant<
    nb::bool_,
    nb::int_,
    nb::float_,
    // Must be above ndarray
    mx::array,
    // Must be above complex
    nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu>,
    std::complex<float>,
    nb::list,
    nb::tuple,
    ArrayLike>;

mx::array nd_array_to_mlx(
    nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu> nd_array,
    std::optional<mx::Dtype> dtype);

nb::ndarray<nb::numpy> mlx_to_np_array(const mx::array& a);
nb::ndarray<> mlx_to_dlpack(const mx::array& a);

nb::object to_scalar(mx::array& a);

nb::object tolist(mx::array& a);

mx::array create_array(ArrayInitType v, std::optional<mx::Dtype> t);
mx::array array_from_list(nb::list pl, std::optional<mx::Dtype> dtype);
mx::array array_from_list(nb::tuple pl, std::optional<mx::Dtype> dtype);