File size: 1,289 Bytes
e05eed1 98a67a0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 |
// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
#pragma once
#include <ostream>
#include <tuple>
#include <vector>

#include <torch/torch.h>
// Forward declaration of the std::tuple printer defined further down in this
// header. An operator call inside a template body only considers the
// global-namespace overloads that were declared before that template's
// definition (argument-dependent lookup on std:: types searches namespace std
// only). Without this declaration, printing a std::vector<std::tuple<...>>,
// or a tuple nested inside a tuple, would fail to compile.
template<typename... Args>
std::ostream &operator<<(std::ostream &os, const std::tuple<Args...> &t);

/// Stream a std::vector as "[e0, e1, ...]", writing each element with its own
/// operator<<. An empty vector prints as "[]".
/// @return the same stream, for chaining.
template<typename T>
inline
std::ostream &operator<<(std::ostream &os, const std::vector<T> &v) {
    os << "[";
    bool first = true;
    for (const auto &element : v) {
        // The separator goes before every element except the first, so there
        // is never a trailing ", ".
        if (!first) {
            os << ", ";
        }
        os << element;
        first = false;
    }
    os << "]";
    return os;
}
/// Recursive helper behind the std::tuple operator<<.
///
/// _inner_tuple_print<Counter, Args...>::print writes elements 0..Counter of
/// `t` to `os`, separated by ", ", with no surrounding brackets, and returns
/// `os`. The primary template first emits the prefix 0..Counter-1 and then
/// appends element Counter. The <0> specialization below ends the recursion.
/// Counter must lie in [0, sizeof...(Args) - 1].
template<int Counter, typename ...Args>
struct _inner_tuple_print
{
    inline
    static std::ostream &print(std::ostream &os, const std::tuple<Args...> &t) {
        using prefix_printer = _inner_tuple_print<Counter - 1, Args...>;
        prefix_printer::print(os, t);
        // auto&& keeps the exact reference type std::get yields, so the same
        // operator<< overload is selected as with a direct std::get call.
        auto &&current = std::get<Counter>(t);
        os << ", " << current;
        return os;
    }
};

/// Base case: element 0 is written on its own, with no leading separator.
template<typename ...Args>
struct _inner_tuple_print<0, Args...>
{
    inline
    static std::ostream &print(std::ostream &os, const std::tuple<Args...> &t) {
        auto &&head = std::get<0>(t);
        os << head;
        return os;
    }
};
/// Stream a std::tuple as "(e0, e1, ...)", writing each element with its own
/// operator<<. An empty tuple prints as "()".
/// @return the same stream, for chaining.
template<typename... Args>
inline
std::ostream &operator<<(std::ostream &os, const std::tuple<Args...> &t) {
    os << "(";
    // _inner_tuple_print needs a valid last index. For std::tuple<> the value
    // sizeof...(Args) - 1 wraps around and is not a valid int template
    // argument, so that instantiation must be skipped entirely. The discarded
    // if-constexpr branch is never instantiated.
    if constexpr (sizeof...(Args) > 0) {
        _inner_tuple_print<sizeof...(Args) - 1, Args...>::print(os, t);
    }
    os << ")";
    return os;
}
// Declaration only: the definition is not in this header (presumably it is in
// a matching .cpp file). Judging by the name, it writes a human-readable dump
// of `t`. NOTE(review): confirm the output destination and format against the
// definition.
void print_tensor(const torch::Tensor &t);
|