diff --git a/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/dataloader.h b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/dataloader.h new file mode 100644 index 0000000000000000000000000000000000000000..158813043af61883f4df398360c2894a303ac0ad --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/dataloader.h @@ -0,0 +1,57 @@ +#pragma once + +#include +#include + +#include + +#include + +#include +#include +#include +#include + +namespace torch { +namespace data { + +/// Creates a `DataLoader` instance for a stateless `dataset`, a `sampler` and +/// some `options`. +template +std::enable_if_t< + !Dataset::is_stateful, + std::unique_ptr>> +make_data_loader(Dataset dataset, Sampler sampler, DataLoaderOptions options) { + return std::make_unique>( + std::move(dataset), std::move(sampler), std::move(options)); +} + +/// Creates a `DataLoader` instance for a stateless `dataset` and some +/// `options`. A sampler (by default a `RandomSampler`) will be constructed from +/// the size of the dataset. +template +std::enable_if_t< + !Dataset::is_stateful && std::is_constructible_v, + std::unique_ptr>> +make_data_loader( + Dataset dataset, + DataLoaderOptions options = DataLoaderOptions()) { + const std::optional size = dataset.size(); + TORCH_CHECK( + size.has_value(), + "Expected the dataset to be sized in " + "order to construct the Sampler"); + return make_data_loader( + std::move(dataset), Sampler(*size), std::move(options)); +} + +/// Creates a `DataLoader` for a stateful `dataset` and some `options`. 
+template > +std::unique_ptr> make_data_loader( + Dataset dataset, + DataLoaderOptions options = DataLoaderOptions()) { + return std::make_unique>( + std::move(dataset), std::move(options)); +} +} // namespace data +} // namespace torch diff --git a/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/dataloader/base.h b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/dataloader/base.h new file mode 100644 index 0000000000000000000000000000000000000000..cb17843ba0b33b7071e770e1e4e7b647b8443160 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/dataloader/base.h @@ -0,0 +1,255 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include + +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +namespace torch { +namespace data { +template +class DataLoaderBase { + public: + using BatchType = Batch; + using BatchRequestType = BatchRequest; + + /// Constructs a new DataLoader from a `dataset` to sample from, `options` + /// to configure the DataLoader with, and a `sampler` that specifies the + /// sampling strategy. + DataLoaderBase( + DataLoaderOptions options, + std::unique_ptr main_thread_dataset = nullptr) + : options_(std::move(options)), + main_thread_dataset_(std::move(main_thread_dataset)), + sequencer_(new_sequencer()) {} + + // NOLINTNEXTLINE(bugprone-exception-escape) + virtual ~DataLoaderBase() { + join(); + } + + /// Returns an iterator into the DataLoader. The lifetime of the iterator is + /// bound to the DataLoader. In C++ standards language, the category of the + /// iterator is `OutputIterator`. See + /// https://en.cppreference.com/w/cpp/named_req/OutputIterator for what this + /// means. In short: you may increment the iterator and dereference it, but + /// cannot go back, or step forward more than one position at a time. 
When the + /// DataLoader is exhausted, it will compare equal with the special + /// "sentinel" iterator returned by `DataLoader::end()`. Most of the time, you + /// should only use range-for loops to loop over the DataLoader, but + /// standard algorithms like `std::copy(dataloader.begin(), dataloader.end(), + /// output_iterator)` are supported too. + Iterator begin() { + TORCH_CHECK( + shuttle_.in_flight_jobs() == 0, + "Attempted to get a new DataLoader iterator " + "while another iterator is not yet exhausted"); + reset(); + return Iterator(std::make_unique>( + [this] { return this->next(); })); + } + + /// Returns a special "sentinel" iterator that compares equal with a + /// non-sentinel iterator once the DataLoader is exhausted. + Iterator end() { + return Iterator(std::make_unique>()); + } + + /// Joins the DataLoader's worker threads and drains internal queues. + /// This function may only be invoked from the main thread (in which the + /// DataLoader lives). + void join() { + if (joined_) { + return; + } + shuttle_.drain(); + // Send one 'quit' message per worker. Since a worker dies (exits its + // thread) after receiving this message, each `QuitWorker()` message will be + // read by exactly one worker. + for (const auto w : c10::irange(options_.workers)) { + (void)w; // Suppress unused variable warning + push_job(QuitWorker()); + } + for (auto& worker : workers_) { + worker.join(); + } + joined_ = true; + } + + /// Returns the options with which the DataLoader was configured. + const FullDataLoaderOptions& options() const noexcept { + return options_; + } + + protected: + /// Simple mix-in to give something a sequence number. + struct Sequenced { + Sequenced() = default; + Sequenced(size_t sqn) : sequence_number(sqn) {} + size_t sequence_number; + }; + + struct QuitWorker {}; + + /// A `Job` is either a `BatchRequest` (new indices to fetch data at) or a + /// `QuitWorker` object, to indicate the worker should shut down. 
+ struct Job : Sequenced { + Job() = default; + Job(QuitWorker q, size_t sqn) : Sequenced(sqn), quit(q) {} + Job(BatchRequest&& i, size_t sqn) + : Sequenced(sqn), batch_request(std::move(i)) {} + std::optional quit; + std::optional batch_request; + }; + + /// The finished result of a job. + struct Result : Sequenced { + Result() = default; + Result(std::optional&& b, size_t sqn) + : Sequenced(sqn), batch(std::move(b)) {} + Result(std::exception_ptr exception, size_t sqn) + : Sequenced(sqn), exception(std::move(exception)) {} + std::optional batch; + std::exception_ptr exception; + }; + + /// Subclass hook for getting the next batch request. The stateless case will + /// ask the sampler for a new batch request (e.g. a vector of indices), while + /// the stateful one will simply return the batch size. + virtual std::optional get_batch_request() = 0; + + /// Resets the internal state of the DataLoader, optionally pre-fetching + /// new jobs. + virtual void reset() { + shuttle_.drain(); + sequence_number_ = 0; + sequencer_ = new_sequencer(); + prefetch(); + } + + /// Schedules `requested_jobs` many new batches to be fetched. The actual + /// number of jobs scheduled may be less if the DataLoader exhausts. + void prefetch(size_t requested_jobs) { + for (const auto r : c10::irange(requested_jobs)) { + (void)r; // Suppress unused variable + if (auto batch_request = get_batch_request()) { + this->push_job(std::move(*batch_request)); + } else { + break; + } + } + } + + /// Schedules the maximum number of jobs (based on the `max_jobs` option). + void prefetch() { + prefetch(options_.max_jobs); + } + + /// Returns the next batch of data, or an empty `optional` if the DataLoader + /// is exhausted. This operation will block until a batch is available if one + /// is still expected. 
+ std::optional next() { + if (options_.workers > 0) { + while (std::optional result = this->pop_result()) { + if (result->exception) { + throw WorkerException(result->exception); + } else if (result->batch) { + prefetch(1); + return std::move(result->batch); + } + } + } else if (auto batch_request = get_batch_request()) { + return this->main_thread_dataset_->get_batch(std::move(*batch_request)); + } + return nullopt; + } + + /// The function that worker threads run. + void worker_thread(Dataset& dataset) { + while (true) { + auto job = shuttle_.pop_job(); + if (job.quit) { + break; + } + try { + auto batch = dataset.get_batch(std::move(*job.batch_request)); + shuttle_.push_result({std::move(batch), job.sequence_number}); + } catch (...) { + shuttle_.push_result({std::current_exception(), job.sequence_number}); + } + } + } + + /// Convenience method that calls `shuttle_.push_job()` with the next sequence + /// number. + template + void push_job(T value) { + shuttle_.push_job({std::move(value), sequence_number_++}); + } + + /// Convenience method that gets the next result from the sequencer. + std::optional pop_result() { + return sequencer_->next( + [this] { return this->shuttle_.pop_result(this->options_.timeout); }); + } + + /// Convenience method that creates a new sequencer based on the + /// `enforce_ordering` option. + std::unique_ptr> new_sequencer() { + if (options_.enforce_ordering) { + return std::make_unique>( + options_.max_jobs); + } + return std::make_unique>(); + } + + /// The options the DataLoader was configured with. + // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) + const FullDataLoaderOptions options_; + + /// The dataset for the main thread, only has a value if the number of + /// worker threads was configured as zero, meaning the main thread has to do + /// all the work (synchronously). NOTE: Really want this to be on the heap + /// when empty, therefore `unique_ptr` and not `optional`. 
+ // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) + std::unique_ptr main_thread_dataset_; + + /// The sequence number for the *next* batch to be retrieved from the + /// dataset. + // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) + size_t sequence_number_ = 0; + + /// The worker threads, running the `worker_thread()` method. + // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) + std::vector workers_; + + /// The `DataShuttle` which takes care of the life cycle of a job. + // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) + detail::DataShuttle shuttle_; + + /// The `Sequencer`, which handles optional ordering of batches. + // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) + std::unique_ptr> sequencer_; + + /// True if the DataLoader has joined its worker threads. + // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) + bool joined_ = false; +}; +} // namespace data +} // namespace torch diff --git a/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/dataloader/stateful.h b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/dataloader/stateful.h new file mode 100644 index 0000000000000000000000000000000000000000..6ae027119a0c9959870e886f28a9b13b44d532f0 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/dataloader/stateful.h @@ -0,0 +1,63 @@ +#pragma once + +#include +#include + +#include +#include +#include + +namespace torch { +namespace data { + +/// A dataloader for stateful datasets. +/// +/// A dataloader for stateful datatasets differs from one for stateless +/// datasets one in that the dataset is shared among worker threads, and that +/// this dataset is itself responsible for producing batches rather than +/// depending on a sampler. 
The statefulness here actually refers to the +/// dataset. The StatefulDataLoader simply alters the data loading algorithm to +/// accommodate the stateful, shared nature of the dataset. Note that the +/// dataset must be thread safe if more than one worker thread is used. +/// +/// A stateful dataloader is created by calling `make_data_loader` with a +/// stateful dataset. +template +class StatefulDataLoader : public DataLoaderBase< + Dataset, + typename Dataset::BatchType::value_type, + typename Dataset::BatchRequestType> { + public: + using super = DataLoaderBase< + Dataset, + typename Dataset::BatchType::value_type, + typename Dataset::BatchRequestType>; + using typename super::BatchRequestType; + + /// Constructs the `StatefulDataLoader` from a `dataset` and some `options`. + StatefulDataLoader(Dataset dataset, DataLoaderOptions options) + : super(options, std::make_unique(std::move(dataset))) { + for ([[maybe_unused]] const auto _ : c10::irange(this->options_.workers)) { + // As opposed to the stateless case, here all worker threads access the + // same underlying dataset. + this->workers_.emplace_back( + [this] { this->worker_thread(*this->main_thread_dataset_); }); + } + } + + private: + /// Resets the internal state of the dataloader and the dataset. + void reset() override { + this->main_thread_dataset_->reset(); + // Call the base class method last because it calls `prefetch()` + super::reset(); + } + + /// For stateful datasets, the batch request is always the batch size. The + /// dataset is responsible for determining what goes into the batch next. 
+ std::optional get_batch_request() override { + return this->options_.batch_size; + } +}; +} // namespace data +} // namespace torch diff --git a/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/dataloader/stateless.h b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/dataloader/stateless.h new file mode 100644 index 0000000000000000000000000000000000000000..422b1097ee71b4a21217063609de8983d5883572 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/dataloader/stateless.h @@ -0,0 +1,82 @@ +#pragma once + +#include +#include + +#include +#include + +#include +#include +#include + +namespace torch { +namespace data { + +/// A dataloader for stateless datasets. +/// +/// This dataloader follows the traditional PyTorch dataloader design, whereby a +/// (posssibly) stateful sampler produces *batch requests* for a stateless +/// dataset, which acts as a simple batch request to batch mapping. The batch +/// request will often be an array of indices, and if the dataset is a simple +/// image dataset, the dataset would produce the images at those indices. +template +class StatelessDataLoader : public DataLoaderBase< + Dataset, + typename Dataset::BatchType, + typename Sampler::BatchRequestType> { + public: + using super = DataLoaderBase< + Dataset, + typename Dataset::BatchType, + typename Sampler::BatchRequestType>; + using typename super::BatchRequestType; + + /// Constructs the `StatelessDataLoader` from a `dataset`, a `sampler` and + /// some `options`. + StatelessDataLoader( + Dataset dataset, + Sampler sampler, + DataLoaderOptions options) + : super(std::move(options)), sampler_(std::move(sampler)) { + for (const auto w : c10::irange(this->options_.workers)) { + // Here we copy the dataset into the worker thread closure. Each worker + // has its own copy of the dataset. 
This means the dataset must be + // trivially copiable, or else we don't expect more than one worker to + // be in use. + (void)w; // Suppress unused variable warning + this->workers_.emplace_back( + [this, dataset]() mutable { this->worker_thread(dataset); }); + } + if (this->options_.workers == 0) { + this->main_thread_dataset_ = + std::make_unique(std::move(dataset)); + } + } + + private: + /// Resets the internal state of the dataloader and the sampler. + void reset() override { + sampler_.reset(); + // Call the base class method last because it calls `prefetch()` + super::reset(); + } + + /// Queries the sampler for the next batch request (possibly progressing its + /// internal state). + std::optional get_batch_request() override { + auto indices = sampler_.next(this->options_.batch_size); + if (!indices || + (indices->size() < this->options_.batch_size && + this->options_.drop_last)) { + return nullopt; + } + AT_ASSERT(indices->size() > 0); + return indices; + } + + /// The `Sampler` used to produce batch requests. + Sampler sampler_; +}; +} // namespace data +} // namespace torch diff --git a/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/dataloader_options.h b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/dataloader_options.h new file mode 100644 index 0000000000000000000000000000000000000000..a0c96aee0771370b1233c363e38e2777061eb115 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/dataloader_options.h @@ -0,0 +1,65 @@ +#pragma once + +#include +#include + +#include +#include + +namespace torch { +namespace data { + +/// Options to configure a `DataLoader`. +struct DataLoaderOptions { + DataLoaderOptions() = default; + /* implicit */ DataLoaderOptions(size_t batch_size) + : batch_size_(batch_size) {} + + /// The size of each batch to fetch. + TORCH_ARG(size_t, batch_size) = 1; + + /// The number of worker threads to launch. 
If zero, the main thread will + /// synchronously perform the data loading. + TORCH_ARG(size_t, workers) = 0; + + /// The maximum number of jobs to enqueue for fetching by worker threads. + /// Defaults to two times the number of worker threads. + TORCH_ARG(std::optional, max_jobs); + + /// An optional limit on the time to wait for the next batch. + TORCH_ARG(std::optional, timeout); + + /// Whether to enforce ordering of batches when multiple are loaded + /// asynchronously by worker threads. Set to `false` for better performance if + /// you do not care about determinism. + TORCH_ARG(bool, enforce_ordering) = true; + + /// Whether to omit the last batch if it contains less than `batch_size` + /// examples. + TORCH_ARG(bool, drop_last) = false; +}; + +/// Like `DataLoaderOptions`, but without any unconfigured state. +/// `DataLoaderOptions` has some options that depend on other options +/// (`max_jobs` => `2 * workers`). In the spirit of properly using the C++ type +/// system, `DataLoaderOptions` allows only setting values. To access values, +/// you must create a `FullDataLoaderOptions` from a `DataLoaderOptions` +/// instance, which will do any necessary coalescing. 
+struct FullDataLoaderOptions { + explicit FullDataLoaderOptions(DataLoaderOptions options) + : batch_size(options.batch_size()), + workers(options.workers()), + max_jobs(options.max_jobs().value_or(2 * workers)), + timeout(options.timeout()), + enforce_ordering(options.enforce_ordering()), + drop_last(options.drop_last()) {} + + size_t batch_size; + size_t workers; + size_t max_jobs; + std::optional timeout; + bool enforce_ordering; + bool drop_last; +}; +} // namespace data +} // namespace torch diff --git a/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/datasets.h b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/datasets.h new file mode 100644 index 0000000000000000000000000000000000000000..df565e97235828e5c89c76f0373bc1cdaee01287 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/datasets.h @@ -0,0 +1,9 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include diff --git a/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/datasets/base.h b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/datasets/base.h new file mode 100644 index 0000000000000000000000000000000000000000..f17b3fe8af47549dbe1921b753a50bda529a8ebb --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/datasets/base.h @@ -0,0 +1,104 @@ +#pragma once + +#include +#include + +#include + +#include +#include +#include +#include +#include + +namespace torch { +namespace data { +namespace datasets { +template +class MapDataset; +template +MapDataset map(D, T); // NOLINT +} // namespace datasets +} // namespace data +} // namespace torch + +namespace torch { +namespace data { +namespace datasets { +namespace detail { +template +struct is_optional : std::false_type {}; +template +struct is_optional> : std::true_type {}; +} // namespace 
detail + +/// A dataset that can yield data only in batches. +template < + typename Self, + typename Batch = std::vector>, + typename BatchRequest = ArrayRef> +class BatchDataset { + public: + using SelfType = Self; + using BatchType = Batch; + using BatchRequestType = BatchRequest; + constexpr static bool is_stateful = detail::is_optional::value; + + virtual ~BatchDataset() = default; + + /// Returns a batch of data given an index. + virtual Batch get_batch(BatchRequest request) = 0; + + /// Returns the size of the dataset, or an empty std::optional if it is + /// unsized. + virtual std::optional size() const = 0; + + /// Creates a `MapDataset` that applies the given `transform` to this dataset. + template + MapDataset map(TransformType transform) & { + return datasets::map(static_cast(*this), std::move(transform)); + } + + /// Creates a `MapDataset` that applies the given `transform` to this dataset. + template + MapDataset map(TransformType transform) && { + return datasets::map( + std::move(static_cast(*this)), std::move(transform)); + } +}; + +/// A dataset that can yield data in batches, or as individual examples. +/// +/// A `Dataset` is a `BatchDataset`, because it supports random access and +/// therefore batched access is implemented (by default) by calling the random +/// access indexing function for each index in the requested batch of indices. +/// This can be customized. +template > +class Dataset : public BatchDataset> { + public: + using ExampleType = SingleExample; + + /// Returns the example at the given index. + virtual ExampleType get(size_t index) = 0; + + /// Returns a batch of data. + /// The default implementation calls `get()` for every requested index + /// in the batch. 
+ std::vector get_batch(ArrayRef indices) override { + std::vector batch; + batch.reserve(indices.size()); + for (const auto i : indices) { + batch.push_back(get(i)); + } + return batch; + } +}; + +/// A `StreamDataset` represents a dataset that is a potentially infinite +/// stream. It takes as batch index only a number, which is the batch size, and +/// yields that many elements from the stream. +template >> +using StreamDataset = BatchDataset; +} // namespace datasets +} // namespace data +} // namespace torch diff --git a/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/datasets/chunk.h b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/datasets/chunk.h new file mode 100644 index 0000000000000000000000000000000000000000..01d940aa3e4885e9c78d527dc56d950b2ecad23a --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/datasets/chunk.h @@ -0,0 +1,529 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +#include + +namespace torch { +namespace data { +namespace datasets { + +/// Interface for chunk reader, which performs data chunking and reading of +/// entire chunks. +/// +/// A chunk could be an entire file, such as an audio data file or an image, +/// or part of a file in the case of a large text-file split based on seek +/// positions. +template < + typename ExampleType_, + typename ChunkType_ = std::vector> +class ChunkDataReader { + public: + virtual ~ChunkDataReader() = default; + + using ChunkType = ChunkType_; + using ExampleType = ExampleType_; + + /// Read an entire chunk. + virtual ChunkType read_chunk(size_t chunk_index) = 0; + + /// Returns the number of chunks available in this reader. + virtual size_t chunk_count() = 0; + + /// This will clear any internal state associate with this reader. + virtual void reset() = 0; +}; + +namespace detail { +/// BatchDataBuffer manages a queue of UnwrappedBatchData. 
After a new chunk is +/// loaded, BatchDataBuffer splits it into small batches and push them into the +/// queue. When get_batch is called from data loader, it pops cached batches and +/// return. If the cache is empty, it either waits to load more chunks or return +/// null if all chunks are loaded. +template < + typename UnwrappedBatch, + typename ExampleSampler = samplers::RandomSampler> +class BatchDataBuffer { + public: + using UnwrappedBatchType = UnwrappedBatch; + using BatchType = torch::optional; + using BatchRequestType = typename ExampleSampler::BatchRequestType; + + BatchDataBuffer( + size_t batch_size, + ExampleSampler& example_sampler, + size_t queue_capacity) + : batch_size_(batch_size), + example_sampler_(example_sampler), + queue_capacity_(queue_capacity) {} + + /// Return batch data from the queue. Called from the ChunkDataset main + /// thread. + BatchType get_batch() { + std::unique_lock lock(queue_mutex_); + cv_read_.wait(lock, [this] { + // wait till there is available data in the queue or if all chunks are + // loaded (i.e. the dataset is exhausted for this epoch) + return ( + this->total_example_count_in_queue_ >= batch_size_ || this->stop_); + }); + if (batch_queue_.empty()) { + AT_ASSERT(stop_); + // All batches have been retrieved. Return an empty batch. + return nullopt; + } + + UnwrappedBatchData batch = std::move(batch_queue_.front()); + batch_queue_.pop(); + if (batch.exception) { + throw WorkerException(batch.exception); + } + + total_example_count_in_queue_ -= batch.batch_data.size(); + lock.unlock(); + cv_write_.notify_all(); + + return batch.batch_data; + } + + /// Push preloaded chunks to batch queue. Called from the ChunkDataset worker + /// threads. + void add_chunk_data(UnwrappedBatchType data) { + std::unique_lock lock(queue_mutex_); + cv_write_.wait(lock, [this] { + // stop loading if we have preloaded enough data. 
+ return this->total_example_count_in_queue_ < this->queue_capacity_ || + this->stop_; + }); + if (stop_) { + // When stop_ is true, it means no further chunk loading is necessary. + // Return without any further processing. + return; + } + + auto data_size = data.size(); + auto remaining_size = data_size; + example_sampler_.reset(data_size); + + auto fill_batch = [&](size_t example_count, UnwrappedBatchType& batch) { + auto batch_example_indices = this->example_sampler_.next(example_count); + AT_ASSERT( + batch_example_indices && + batch_example_indices.value().size() == example_count); + BatchRequestType& indices = batch_example_indices.value(); + for (size_t i : indices) { + TORCH_CHECK(i < data_size, "Index out of range"); + batch.emplace_back(std::move(data[i])); + } + remaining_size -= example_count; + }; + + if (!batch_queue_.empty()) { + // if the queue has existing data, and the last batch doesn't have enough + // examples to fill a batch_size batch, add more example to this batch + // first. + auto& batch = batch_queue_.back(); + size_t current_count = batch.batch_data.size(); + if (current_count < batch_size_) { + auto example_count = + std::min(remaining_size, batch_size_ - current_count); + fill_batch(example_count, batch.batch_data); + } + } + + // If we still have data remaining after filling the last pushed batch, add + // them to the queue too. + // NOLINTNEXTLINE(bugprone-infinite-loop) + while (remaining_size > 0) { + UnwrappedBatchType current_batch; + + // Allocate the batch memory ahead of time. + current_batch.reserve(batch_size_); + + auto example_count = std::min(remaining_size, batch_size_); + fill_batch(example_count, current_batch); + batch_queue_.emplace(std::move(current_batch)); + } + total_example_count_in_queue_ += data_size; + lock.unlock(); + cv_read_.notify_all(); + } + + /// Push exceptions thrown during preloading into batch queue. Called from + /// the ChunkDataset worker threads. 
+ void add_chunk_data(std::exception_ptr e_ptr) { + std::unique_lock lock(queue_mutex_); + cv_write_.wait(lock, [this] { + // stop loading if we have preloaded enough data. + return ( + this->total_example_count_in_queue_ < this->queue_capacity_ || + this->stop_); + }); + if (stop_) { + // When stop_ is true, it means this current thread needs to be tore down, + // the batch buffer will be discarded, so no need to enqueue any new + // exceptions. + return; + } + + batch_queue_.emplace(e_ptr); + lock.unlock(); + cv_read_.notify_all(); + } + + void stop() { + { + // Hold the lock before changing stop_ to prevent a race condition which + // can cause a deadlock. To be more specific, conditional variable + // cv_write_ waits on predicate stop_ in add_chunk_data(). The wait + // happens in two steps: 1) while still holding the lock, check if + // predicate is true; 2) if it is true, proceeds, otherwise, release the + // lock and wait until notified. Without holding a lock, cv_write_'s + // notification can happen in between step 1) and 2). In that case, as + // cv_write_ is not in waiting status yet, so the notification is lost and + // cv_write_ will sleep forever. By taking a lock before changing + // predicate stop_, it is ensured updating and evaluating stop_ always + // happen in a synchronized way + std::lock_guard lock(queue_mutex_); + stop_ = true; + } + + // notify all writers, wake them from wait to exit current method. + cv_write_.notify_all(); + // notify all readers too. + cv_read_.notify_all(); + } + /// The batch size is needed to create batches from the chunk data. Similar to + /// regular dataloader where the batches are created with prefetches, + /// BatchDataBuffer perform the batch creation using the provided batch size. + size_t batch_size_ = 0; + + /// count of total example stored in the queue + size_t total_example_count_in_queue_ = 0; + + /// struct that contains a raw unwrapped batch unit. 
An unwrapped batch unit + /// is the raw data without 'optional' wrapper. It can be a collection of + /// images, utterances, e.t.c. + struct UnwrappedBatchData { + explicit UnwrappedBatchData(UnwrappedBatchType data) + : batch_data(std::move(data)) {} + + // NOLINTNEXTLINE(modernize-pass-by-value) + explicit UnwrappedBatchData(std::exception_ptr e) : exception(e) {} + + /// batch data to return + UnwrappedBatchType batch_data; + + /// exception pointer which captures any abnormal exceptions while creating + /// the batch. + std::exception_ptr exception; + }; + + /// local cache to store example batches from loaded chunk + std::queue batch_queue_; + + // sync batch_queue_ update. + std::mutex queue_mutex_; + + std::condition_variable cv_read_; + std::condition_variable cv_write_; + + ExampleSampler& example_sampler_; + + // configurable maximun number of elements the queue can hold at one time. + size_t queue_capacity_; + + // When set to true, it wakes the writer threads from the wait and exit + // current function call. This is needed when ChunkDataSet.Reset is called + // while the previous epoch is not exhausted yet. When ChunkDataset is waiting + // its preloader to finish previous work before tearing down the thread, the + // preloader could be still waiting for the conditional variable, thus cause + // the program to hang. This boolean is used to break this waiting condition. + bool stop_ = false; +}; +} // namespace detail + +/// Options to configure a `ChunkDataset`. +struct ChunkDatasetOptions { + ChunkDatasetOptions() = delete; + ChunkDatasetOptions( + size_t preloader_count, + size_t batch_size, + size_t cache_size = 2048, + size_t cross_chunk_shuffle_count = 1) + : preloader_count_(preloader_count), + batch_size_(batch_size), + cache_size_(cache_size), + cross_chunk_shuffle_count_(cross_chunk_shuffle_count) { + TORCH_CHECK( + preloader_count_ > 0, + "Preloader count is 0. 
At least one preloader needs to be specified."); + TORCH_CHECK( + batch_size_ > 0, + "Batch size is 0. A positive batch size needs to be specified."); + TORCH_CHECK( + cache_size_ > 0, + "Cache size is 0. A positive cache size needs to be specified."); + TORCH_CHECK( + cache_size_ >= batch_size_, + "Cache size is less than batch size. Cache needs to be large enough to " + "hold at least one batch."); + TORCH_CHECK( + cross_chunk_shuffle_count_ > 0, + "cross_chunk_shuffle_count needs to be greater than 0."); + } + + /// The number of worker thread to preload chunk data. + TORCH_ARG(size_t, preloader_count); + + /// The size of each batch. + TORCH_ARG(size_t, batch_size); + + /// The capacity of the queue for batch caching. + TORCH_ARG(size_t, cache_size) = 2048; + + // The number of chunks to perfrom cross-chunk shuffling. Default to 1 meaning + // no cross-chunk shuffling. When it is equal to n (n > 1), n random + // chunks will be loaded at once and example shuffling will be performed + // across all those n chunks. + // Note: Usually the default config (1 chunk shuffle + example shuffle) is + // good enough to generate random distributed data. Use this parameter only if + // you know cross-shuffle is needed in your case. Also there is a performance + // penalty when this value is greater than 1, as we need to do extra merge + // between multiple chunks before performing example sampling. + TORCH_ARG(size_t, cross_chunk_shuffle_count) = 1; +}; + +/// A stateful dataset that support hierarchical sampling and prefetching of +/// entre chunks. +/// +/// Unlike regular dataset, chunk dataset require two samplers to operate and +/// keeps an internal state. `ChunkSampler` selects, which chunk to load next, +/// while the `ExampleSampler` determins the order of Examples that are returned +/// in each `get_batch` call. 
The hierarchical sampling approach used here is +/// inspired by this paper http://martin.zinkevich.org/publications/nips2010.pdf +template < + typename ChunkReader, + typename ChunkSampler = samplers::RandomSampler, + typename ExampleSampler = samplers::RandomSampler> +class ChunkDataset final + : public StatefulDataset< + ChunkDataset, + typename ChunkReader::BatchType, + size_t> { + public: + using BatchType = torch::optional; + using UnwrappedBatchType = typename ChunkReader::BatchType; + using BatchRequestType = size_t; + using ChunkSamplerType = ChunkSampler; + using ExampleSamplerType = ExampleSampler; + + ChunkDataset( + ChunkReader chunk_reader, + ChunkSampler chunk_sampler, + ExampleSampler example_sampler, + ChunkDatasetOptions options, + std::function preprocessing_policy = + std::function()) + : chunk_reader_(std::move(chunk_reader)), + chunk_sampler_(std::move(chunk_sampler)), + example_sampler_(std::move(example_sampler)), + options_(std::move(options)), + preprocessing_policy_(std::move(preprocessing_policy)), + quit_worker_(false), + running_preloaders_(0), + load_checkpoint_(false) {} + + ~ChunkDataset() override { + // stop batch buffer first. + if (batch_buffer_) { + batch_buffer_->stop(); + } + free_workers(); + } + + /// Default get_batch method of BatchDataset. This method returns + /// Example batches created from the preloaded chunks. The implemenation + /// is dataset agnostic and does not need overriding in different chunk + /// datasets. 
+ BatchType get_batch(size_t batch_size) override { + TORCH_CHECK( + batch_buffer_ != nullptr, + "Dataset needs to call reset() before calling get_batch()."); + + TORCH_CHECK( + batch_size == options_.batch_size(), + "The requested batch size does not match with the initialized batch size.\n" + " The requested batch size is ", + batch_size, + ", while the dataset is created with batch size equal to ", + options_.batch_size()); + return batch_buffer_->get_batch(); + } + + /// Helper method around get_batch as `batch_size` is not strictly necessary + BatchType get_batch() { + return get_batch(options_.batch_size()); + } + + /// This will clear any internal state and starts the internal prefetching + /// mechanism for the chunk dataset. + void reset() override { + // We need this to support partial data reads via dataloader iterator. + if (batch_buffer_) { + batch_buffer_->stop(); + } + // free workers from previous reset if there is any. + free_workers(); + preload_threads_.clear(); + + if (!load_checkpoint_) { + chunk_reader_.reset(); + chunk_sampler_.reset(chunk_reader_.chunk_count()); + load_checkpoint_ = false; + } + + // Throw out any existing cached batch in the buffer and re-creates a new + // chunk buffer. + batch_buffer_ = std::make_unique< + detail::BatchDataBuffer>( + options_.batch_size(), example_sampler_, options_.cache_size()); + + // create new workers for this new epoch. + quit_worker_ = false; + + AT_ASSERT(running_preloaders_ == 0); + running_preloaders_ = options_.preloader_count(); + for (const auto i : c10::irange(options_.preloader_count())) { + preload_threads_.emplace_back([this, i]() { this->preloader(i); }); + } + } + + /// size is not used for chunk dataset. + std::optional size() const override { + return torch::nullopt; + } + + // provide a references to chunk sampler. Used mainly in distributed data + // loading to set the epoch number for the sampler. 
+ ChunkSamplerType& chunk_sampler() { + return chunk_sampler_; + } + + void save(serialize::OutputArchive& archive) const override { + std::lock_guard lock(chunk_index_guard_); + chunk_sampler_.save(archive); + } + + void load(serialize::InputArchive& archive) override { + std::lock_guard lock(chunk_index_guard_); + chunk_sampler_.load(archive); + load_checkpoint_ = true; + } + + private: + /// running on worker thread to preload chunk data. + void preloader(size_t id) { + while (!quit_worker_.load()) { + try { + std::vector chunk_idx; + { + std::lock_guard lock(chunk_index_guard_); + if (auto chunk_sampler_result = chunk_sampler_.next( + this->options_.cross_chunk_shuffle_count())) { + chunk_idx = chunk_sampler_result.value(); + } else { + break; + } + } + UnwrappedBatchType data = chunk_reader_.read_chunk(chunk_idx[0]); + for (const auto i : c10::irange(1, chunk_idx.size())) { + auto chunk_data = chunk_reader_.read_chunk(chunk_idx[i]); + std::move( + chunk_data.begin(), chunk_data.end(), std::back_inserter(data)); + } + if (preprocessing_policy_) { + preprocessing_policy_(data); + } + if (!data.empty()) { // skip empty chunks. + batch_buffer_->add_chunk_data(std::move(data)); + } + } catch (...) { + batch_buffer_->add_chunk_data(std::current_exception()); + } + } + AT_ASSERT(running_preloaders_.load() > 0); + --running_preloaders_; + if (running_preloaders_.load() == 0) { + // all preloaders are completed, so we can notify the batch_buffer. + batch_buffer_->stop(); + } + } + + /// Block the current thread until the workers finish execution and exit. + void free_workers() { + if (!quit_worker_.load()) { + quit_worker_ = true; + for (auto& worker_thread : preload_threads_) { + worker_thread.join(); + } + } + } + + private: + // Templated class that defines what is a chunk and how to read chunk data. + // When a chunk is returned by chunk_reader_, ChunkDataset split it into + // batches and caches them in batch_buffer_. 
+ ChunkReader chunk_reader_; + + // chunk sampler to shuffle different chunks + ChunkSamplerType chunk_sampler_; + + // example sampler to shuffle examples in a specific chunk + ExampleSamplerType example_sampler_; + + // batch data buffer which holds chunk data from preloading thread. + std::shared_ptr< + detail::BatchDataBuffer> + batch_buffer_; + + // worker thread pool + std::vector preload_threads_; + + /// The options the Dataset was configured with. + const ChunkDatasetOptions options_; + + // function pointer wrapper to apply custom processing over chunk data. This + // is considered an advanced parameter for developers who want to apply a + // pre-process to the chunk data before sampling into minibatch. + // Different than the collate function, this policy is applied on the chunk + // level, instead of minibatch level. When a chunk of data is loaded (multiple + // chunks if cross_chunk_shuffle_count_ is greater than 1), this policy is + // applied to the full loaded data. It is useful if developers want to + // perform pre-processing (like bucketing) to the chunk data before + // example sampler samples the data. By default it's an empty pointer and no + // action will be taken. + std::function preprocessing_policy_; + + // indicate whether the worker thread can be teared down + std::atomic quit_worker_; + + // keep track of running preloaders to notify batch buffer. A value 0 + // indicates that the chunk loading is completed. + std::atomic running_preloaders_; + + // mutex to synchronize chunk sampler next() call. + mutable std::mutex chunk_index_guard_; + + // boolean value to indicate whether we need to load the checkpoint for + // chunk_sampler_. 
+ bool load_checkpoint_; +}; +} // namespace datasets +} // namespace data +} // namespace torch diff --git a/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/datasets/map.h b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/datasets/map.h new file mode 100644 index 0000000000000000000000000000000000000000..ebd4374cca8f3fcb765b937f728b9245d5998bfd --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/datasets/map.h @@ -0,0 +1,118 @@ +#pragma once + +#include +#include + +#include + +#include +#include +#include + +namespace torch { +namespace data { +namespace datasets { +namespace detail { +template +using optional_if_t = typename std::conditional, T>::type; +} // namespace detail + +/// A `MapDataset` is a dataset that applies a transform to a source dataset. +template +class MapDataset : public BatchDataset< + MapDataset, + detail::optional_if_t< + SourceDataset::is_stateful, + typename AppliedTransform::OutputBatchType>, + typename SourceDataset::BatchRequestType> { + public: + using DatasetType = SourceDataset; + using TransformType = AppliedTransform; + using BatchRequestType = typename SourceDataset::BatchRequestType; + using OutputBatchType = detail::optional_if_t< + SourceDataset::is_stateful, + typename AppliedTransform::OutputBatchType>; + + MapDataset(DatasetType dataset, TransformType transform) + : dataset_(std::move(dataset)), transform_(std::move(transform)) {} + + /// Gets a batch from the source dataset and applies the transform to it, + /// returning the result. + OutputBatchType get_batch(BatchRequestType indices) override { + return get_batch_impl(std::move(indices)); + } + + /// Returns the size of the source dataset. + // NOLINTNEXTLINE(bugprone-exception-escape) + std::optional size() const noexcept override { + return dataset_.size(); + } + + /// Calls `reset()` on the underlying dataset. 
+ /// NOTE: Stateless datasets do not have a reset() method, so a call to this + /// method will only compile for stateful datasets (which have a reset() + /// method). + void reset() { + dataset_.reset(); + } + + /// Returns the underlying dataset. + const SourceDataset& dataset() noexcept { + return dataset_; + } + + /// Returns the transform being applied. + const AppliedTransform& transform() noexcept { + return transform_; + } + + private: + /// The implementation of `get_batch()` for the stateless case, which simply + /// applies the transform to the output of `get_batch()` from the dataset. + template < + typename D = SourceDataset, + typename = std::enable_if_t> + OutputBatchType get_batch_impl(BatchRequestType indices) { + return transform_.apply_batch(dataset_.get_batch(std::move(indices))); + } + + /// The implementation of `get_batch()` for the stateful case. Here, we follow + /// the semantics of `Optional.map()` in many functional languages, which + /// applies a transformation to the optional's content when the optional + /// contains a value, and returns a new optional (of a different type) if the + /// original optional returned by `get_batch()` was empty. + template + std::enable_if_t get_batch_impl( + BatchRequestType indices) { + if (auto batch = dataset_.get_batch(std::move(indices))) { + return transform_.apply_batch(std::move(*batch)); + } + return nullopt; + } + + /// The underlying dataset being transformed. + SourceDataset dataset_; + + // The transformation that is applied to batches received from the dataset. + AppliedTransform transform_; +}; + +/// Creates a `MapDataset` with the given dataset and transform. 
+template +MapDataset map( + DatasetType dataset, + TransformType transform) { + static_assert( + std::is_same< + typename std::conditional< + DatasetType::is_stateful, + typename DatasetType::BatchType::value_type, + typename DatasetType::BatchType>::type, + typename TransformType::InputBatchType>::value, + "BatchType type of dataset does not match input type of transform"); + return {std::move(dataset), std::move(transform)}; +} + +} // namespace datasets +} // namespace data +} // namespace torch diff --git a/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/datasets/mnist.h b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/datasets/mnist.h new file mode 100644 index 0000000000000000000000000000000000000000..5d9e352f36d07f38f688ff1c9d63a3d7505d137d --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/datasets/mnist.h @@ -0,0 +1,48 @@ +#pragma once + +#include +#include +#include + +#include + +#include +#include + +namespace torch { +namespace data { +namespace datasets { +/// The MNIST dataset. +class TORCH_API MNIST : public Dataset { + public: + /// The mode in which the dataset is loaded. + enum class Mode { kTrain, kTest }; + + /// Loads the MNIST dataset from the `root` path. + /// + /// The supplied `root` path should contain the *content* of the unzipped + /// MNIST dataset, available from http://yann.lecun.com/exdb/mnist. + explicit MNIST(const std::string& root, Mode mode = Mode::kTrain); + + /// Returns the `Example` at the given `index`. + Example<> get(size_t index) override; + + /// Returns the size of the dataset. + std::optional size() const override; + + /// Returns true if this is the training subset of MNIST. + // NOLINTNEXTLINE(bugprone-exception-escape) + bool is_train() const noexcept; + + /// Returns all images stacked into a single tensor. 
+ const Tensor& images() const; + + /// Returns all targets stacked into a single tensor. + const Tensor& targets() const; + + private: + Tensor images_, targets_; +}; +} // namespace datasets +} // namespace data +} // namespace torch diff --git a/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/datasets/shared.h b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/datasets/shared.h new file mode 100644 index 0000000000000000000000000000000000000000..aff84b586c89cec8870fc996f73009b8271b52e4 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/datasets/shared.h @@ -0,0 +1,83 @@ +#pragma once + +#include + +#include +#include + +namespace torch { +namespace data { +namespace datasets { + +/// A dataset that wraps another dataset in a shared pointer and implements the +/// `BatchDataset` API, delegating all calls to the shared instance. This is +/// useful when you want all worker threads in the dataloader to access the same +/// dataset instance. The dataset must take care of synchronization and +/// thread-safe access itself. +/// +/// Use `torch::data::datasets::make_shared_dataset()` to create a new +/// `SharedBatchDataset` like you would a `std::shared_ptr`. +template +class SharedBatchDataset : public BatchDataset< + SharedBatchDataset, + typename UnderlyingDataset::BatchType, + typename UnderlyingDataset::BatchRequestType> { + public: + using BatchType = typename UnderlyingDataset::BatchType; + using BatchRequestType = typename UnderlyingDataset::BatchRequestType; + + /// Constructs a new `SharedBatchDataset` from a `shared_ptr` to the + /// `UnderlyingDataset`. + /* implicit */ SharedBatchDataset( + std::shared_ptr shared_dataset) + : dataset_(std::move(shared_dataset)) {} + + /// Calls `get_batch` on the underlying dataset. 
+ BatchType get_batch(BatchRequestType request) override { + return dataset_->get_batch(std::move(request)); + } + + /// Returns the `size` from the underlying dataset. + std::optional size() const override { + return dataset_->size(); + } + + /// Accesses the underlying dataset. + UnderlyingDataset& operator*() { + return *dataset_; + } + + /// Accesses the underlying dataset. + const UnderlyingDataset& operator*() const { + return *dataset_; + } + + /// Accesses the underlying dataset. + UnderlyingDataset* operator->() { + return dataset_.get(); + } + + /// Accesses the underlying dataset. + const UnderlyingDataset* operator->() const { + return dataset_.get(); + } + + /// Calls `reset()` on the underlying dataset. + void reset() { + dataset_->reset(); + } + + private: + std::shared_ptr dataset_; +}; + +/// Constructs a new `SharedBatchDataset` by creating a +/// `shared_ptr`. All arguments are forwarded to +/// `make_shared`. +template +SharedBatchDataset make_shared_dataset(Args&&... args) { + return std::make_shared(std::forward(args)...); +} +} // namespace datasets +} // namespace data +} // namespace torch diff --git a/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/datasets/stateful.h b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/datasets/stateful.h new file mode 100644 index 0000000000000000000000000000000000000000..fb2379c673340fb9b836d1a60088cfc44f3103c5 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/datasets/stateful.h @@ -0,0 +1,70 @@ +#pragma once + +#include +#include + +#include +#include + +namespace torch { +namespace serialize { +class OutputArchive; +class InputArchive; +} // namespace serialize +} // namespace torch + +namespace torch { +namespace data { +namespace datasets { + +/// A stateful dataset is a dataset that maintains some internal state, which +/// will be `reset()` at the beginning of each epoch. 
Subclasses can override +/// the `reset()` method to configure this behavior. Further, the return type of +/// a stateful dataset's `get_batch()` method is always an `optional`. When the +/// stateful dataset wants to indicate to the dataloader that its epoch has +/// ended, it should return an empty optional. The dataloader knows to modify +/// its implementation based on whether the dataset is stateless or stateful. +/// +/// Note that when subclassing a from `StatefulDataset`, the return +/// type of `get_batch()`, which the subclass must override, will be +/// `optional` (i.e. the type specified in the `StatefulDataset` +/// specialization is automatically boxed into an `optional` for the dataset's +/// `BatchType`). +template < + typename Self, + typename Batch = std::vector>, + typename BatchRequest = size_t> +class StatefulDataset + : public BatchDataset, BatchRequest> { + public: + /// Resets internal state of the dataset. + virtual void reset() = 0; + + /// Saves the statefulDataset's state to OutputArchive. + virtual void save(serialize::OutputArchive& archive) const = 0; + + /// Deserializes the statefulDataset's state from the `archive`. + virtual void load(serialize::InputArchive& archive) = 0; +}; + +/// Serializes a statefulDataset to `OutputArchive`. +template +serialize::OutputArchive& operator<<( + serialize::OutputArchive& archive, + const StatefulDataset& statefulDataset) { + statefulDataset.save(archive); + return archive; +} + +/// Deserializes a statefulDataset from an `InputArchive`. 
+template +serialize::InputArchive& operator>>( + serialize::InputArchive& archive, + StatefulDataset& statefulDataset) { + statefulDataset.load(archive); + return archive; +} + +} // namespace datasets +} // namespace data +} // namespace torch diff --git a/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/datasets/tensor.h b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/datasets/tensor.h new file mode 100644 index 0000000000000000000000000000000000000000..4968e263009f3548a2f7ed41e47586f4290a093c --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/datasets/tensor.h @@ -0,0 +1,38 @@ +#pragma once + +#include +#include +#include + +#include +#include + +namespace torch { +namespace data { +namespace datasets { + +/// A dataset of tensors. +/// Stores a single tensor internally, which is then indexed inside `get()`. +struct TensorDataset : public Dataset { + /// Creates a `TensorDataset` from a vector of tensors. + explicit TensorDataset(const std::vector& tensors) + : TensorDataset(torch::stack(tensors)) {} + + explicit TensorDataset(torch::Tensor tensor) : tensor(std::move(tensor)) {} + + /// Returns a single `TensorExample`. + TensorExample get(size_t index) override { + return tensor[index]; + } + + /// Returns the number of tensors in the dataset. 
+ std::optional size() const override { + return tensor.size(0); + } + + Tensor tensor; +}; + +} // namespace datasets +} // namespace data +} // namespace torch diff --git a/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/detail/data_shuttle.h b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/detail/data_shuttle.h new file mode 100644 index 0000000000000000000000000000000000000000..9c3ef121160123ecc048cfb5bbd3eac933f565ae --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/detail/data_shuttle.h @@ -0,0 +1,87 @@ +#pragma once + +#include +#include + +#include +#include + +#include +#include + +namespace torch { +namespace data { +namespace detail { + +/// Encapsulates the full life cycle of DataLoader jobs. +/// +/// When a new job is enqueued to the `DataShuttle`, a counter for in-flight +/// jobs is bumped. This job is said to be "in-flight" until its result is +/// popped. Worker threads dequeue jobs as soon as they are available. When a +/// worker finishes a job, it enqueues the result. Only when the main thread +/// dequeues a result is the count of in-flight jobs decremented. When the main +/// thread attempts to dequeue a job but no jobs are in-flight, that means the +/// epoch is complete and `pop_result` returns an empty optional. +template +class DataShuttle { + public: + /// Pushes a new job. Called by the main thread. + void push_job(Job job) { + new_jobs_.push(std::move(job)); + ++in_flight_jobs_; + } + + /// Pushes the result of a job. Called by worker threads. + void push_result(Result result) { + results_.push(std::move(result)); + } + + /// Returns the next job, blocking until there is one available. Called by + /// worker threads. + Job pop_job() { + return new_jobs_.pop(); + } + + /// Returns the result of a job, or nullopt if all jobs were exhausted. Called + /// by the main thread. 
+ std::optional pop_result( + std::optional timeout = std::nullopt) { + if (in_flight_jobs_ > 0) { + auto result = results_.pop(timeout); + --in_flight_jobs_; + return result; + } + return nullopt; + } + + /// Discards any jobs that are not yet in flight, and waits for all in-flight + /// jobs to finish, discarding their result. + void drain() { + // Clear all inputs so that no further jobs are scheduled. + auto number_cleared = new_jobs_.clear(); + in_flight_jobs_ -= number_cleared; + // Remove any outstanding results. + while (in_flight_jobs_ > 0) { + pop_result(); + } + } + + /// Returns the number of jobs that are still in progress. + /// When this number is zero, an epoch is finished. + size_t in_flight_jobs() const noexcept { + return in_flight_jobs_; + } + + private: + /// The queue for jobs that are not yet in flight. + Queue new_jobs_; + /// The number of in-flight jobs. + /// NOTE: Not atomic because only manipulated by the main thread. + size_t in_flight_jobs_ = 0; + /// The queue for results of finished jobs. + Queue results_; +}; + +} // namespace detail +} // namespace data +} // namespace torch diff --git a/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/detail/queue.h b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/detail/queue.h new file mode 100644 index 0000000000000000000000000000000000000000..60236ab3f520c3499de30eb20d8962a50f5c0eef --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/detail/queue.h @@ -0,0 +1,84 @@ +#pragma once + +#include + +#include + +#include +#include +#include +#include +#include + +namespace torch { +namespace data { +namespace detail { + +/// A basic locked, blocking MPMC queue. +/// +/// Every `push` and `pop` is guarded by a mutex. 
A condition variable is used +/// to communicate insertion of new elements, such that waiting threads will be +/// woken up if they are currently waiting inside a call to `pop()`. +/// +/// Note that this data structure is written specifically for use with the +/// `DataLoader`. Its behavior is tailored to this use case and may not be +/// applicable to more general uses. +template +class Queue { + public: + /// Pushes a new value to the back of the `Queue` and notifies one thread on + /// the waiting side about this event. + void push(T value) { + { + std::lock_guard lock(mutex_); + queue_.push(std::move(value)); + } + cv_.notify_one(); + } + + /// Blocks until at least one element is ready to be popped from the front of + /// the queue. An optional `timeout` in seconds can be used to limit the time + /// spent waiting for an element. If the wait times out, an exception is + /// raised. + T pop(std::optional timeout = std::nullopt) { + std::unique_lock lock(mutex_); + if (timeout) { + if (!cv_.wait_for( + lock, *timeout, [this] { return !this->queue_.empty(); })) { + // clang-format off + AT_ERROR( + "Timeout in DataLoader queue while waiting for next batch" + " (timeout was ", timeout->count(), " ms)"); + // clang-format on + } + } else { + cv_.wait(lock, [this] { return !this->queue_.empty(); }); + } + AT_ASSERT(!queue_.empty()); + T value = queue_.front(); + queue_.pop(); + lock.unlock(); + return value; + } + + /// Empties the queue and returns the number of elements that were present at + /// the start of the function. No threads are notified about this event as it + /// is assumed to be used to drain the queue during shutdown of a + /// `DataLoader`. 
+ size_t clear() { + std::lock_guard lock(this->mutex_); + const auto size = queue_.size(); + while (!queue_.empty()) { + queue_.pop(); + } + return size; + } + + private: + std::queue queue_; + std::mutex mutex_; + std::condition_variable cv_; +}; +} // namespace detail +} // namespace data +} // namespace torch diff --git a/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/detail/sequencers.h b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/detail/sequencers.h new file mode 100644 index 0000000000000000000000000000000000000000..c59f4cd7e290df9835bb6f08ecb6e3f638c98ef4 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/detail/sequencers.h @@ -0,0 +1,113 @@ +#pragma once + +#include + +#include +#include +#include + +namespace torch { +namespace data { +namespace detail { +namespace sequencers { +namespace detail { +template +bool buffer_contains_result(const std::vector>& buffer) { + return std::any_of( + buffer.begin(), buffer.end(), [](const std::optional& result) { + return result.has_value(); + }); +} +} // namespace detail + +/// A `Sequencer` accepts a function that yields the next result of a +/// `DataLoader` and then has the opportunity to influence the order in which +/// these results are returned. The `NoSequencer` does not enforce any +/// sequencing and returns any result directly. The `OrderedSequencer` instead +/// buffers results internally to return them in order of their sequence number. +template +struct Sequencer { + using ResultProducer = std::function()>; + virtual ~Sequencer() = default; + virtual std::optional next(ResultProducer next_result) = 0; +}; + +/// A `Sequencer` that does not enforce any ordering. It is effectively the +/// identity function. 
+template +struct NoSequencer final : public Sequencer { + using typename Sequencer::ResultProducer; + std::optional next(ResultProducer next_result) override { + return next_result(); + } +}; + +/// A `Sequencer` that buffers results and returns them in order of their +/// sequence number. The `OrderedSequencer` maintains an internal, monotonically +/// incrementing counter for the next sequence number it expects. If it receives +/// a result with a higher sequence number, it will buffer it for later (when +/// the sequence number reaches that of this result). Otherwise, if the sequence +/// numbers match, the result is returned. +/// +/// Implementation note: The `OrderedSequencer` is implemented with a fixed-size +/// buffer. Let `m` be the maximum number of jobs in the data loader's queue and +/// `s` be the current sequence number. Assume `m` jobs are scheduled in the +/// `DataLoader`. Any new result is stored at index `job.sqn mod m` in the +/// `OrderedSequencer`. Why are we sure sequence numbers of new jobs will not +/// collide with sequence numbers of buffered jobs? The `OrderedSequencer` will +/// not return from `next()` until it receives the result with sqn `s`. This +/// means no new jobs can be scheduled in the `DataLoader` in the meantime, +/// which enforces that as long as sqn `s` has not been received, `s + m` (which +/// would cause a collision in the fixed-size buffer) will not yet be scheduled. +template +struct OrderedSequencer : public Sequencer { + using typename Sequencer::ResultProducer; + + /// Constructs the `OrderedSequencer` with the maximum number of results it + /// will ever hold at one point in time. + explicit OrderedSequencer(size_t max_jobs) : buffer_(max_jobs) {} + + /// Buffers results until the next one in the expected order is received. + std::optional next(ResultProducer next_result) override { + // If we already have the result for the next sqn, return it. 
+ if (auto& maybe_result = buffer(next_sequence_number_)) { + auto result = std::move(*maybe_result); + buffer(next_sequence_number_++).reset(); + return result; + } + // Otherwise wait for the next result. + while (true) { + auto result = next_result(); + if (!result) { + AT_ASSERT(!detail::buffer_contains_result(buffer_)); + break; + } + // If it was not nullopt and the sequence numbers match, return it + // directly and bump the sequence number. + if (result->sequence_number == next_sequence_number_) { + ++next_sequence_number_; + return result; + } + // Stash the result for later. + AT_ASSERT(!buffer(result->sequence_number).has_value()); + buffer(result->sequence_number) = std::move(result); + } + // The result was an empty optional, so we are done with this epoch. + return nullopt; + } + + /// Accesses the buffer at the `index` modulo the buffer size. + std::optional& buffer(size_t index) { + return buffer_.at(index % buffer_.size()); + } + + /// The monotonically increasing sequence number we expect. + size_t next_sequence_number_ = 0; + + /// A fixed-size buffer (after construction). + std::vector> buffer_; +}; +} // namespace sequencers +} // namespace detail +} // namespace data +} // namespace torch diff --git a/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/example.h b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/example.h new file mode 100644 index 0000000000000000000000000000000000000000..57219a24cd0b08d65f8f5e46b80c0f8906a0ab03 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/example.h @@ -0,0 +1,55 @@ +#pragma once + +#include + +namespace torch { +namespace data { + +/// An `Example` from a dataset. +/// +/// A dataset consists of data and an associated target (label). 
+template +struct Example { + using DataType = Data; + using TargetType = Target; + + Example() = default; + Example(Data data, Target target) + : data(std::move(data)), target(std::move(target)) {} + + Data data; + Target target; +}; + +namespace example { +using NoTarget = void; +} // namespace example + +/// A specialization for `Example` that does not have a target. +/// +/// This class exists so that code can be written for a templated `Example` +/// type, and work both for labeled and unlabeled datasets. +template +struct Example { + using DataType = Data; + using TargetType = example::NoTarget; + + Example() = default; + /* implicit */ Example(Data data) : data(std::move(data)) {} + + // When a DataLoader returns an Example like this, that example should be + // implicitly convertible to the underlying data type. + + operator Data&() { + return data; + } + operator const Data&() const { + return data; + } + + Data data; +}; + +using TensorExample = Example; +} // namespace data +} // namespace torch diff --git a/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/iterator.h b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/iterator.h new file mode 100644 index 0000000000000000000000000000000000000000..94293c452d53c44489fdfbbe6f0ee4d375aba9d8 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/iterator.h @@ -0,0 +1,178 @@ +#pragma once + +#include +#include + +#include + +#include +#include +#include +#include +#include + +namespace torch { +namespace data { +namespace detail { +// For increased safety and more separated logic, this implementation of +// `Iterator` consists of a `ValidIterator` and a `SentinelIterator`. A +// `ValidIterator` yields new batches until the `DataLoader` is exhausted. While +// the `DataLoader` is not exhausted, `ValidIterator`s compare equal if they are +// the same object. 
When the `ValidIterator` becomes exhausted, it compares +// equal to the `SentinelIterator`, but not before. Half the code here is to +// implement double dispatch for the comparison. Got damnit, C++. + +template +struct ValidIterator; + +template +struct SentinelIterator; + +/// Base class for the `ValidIterator` and `SentinelIterator` +template +struct IteratorImpl { + virtual ~IteratorImpl() = default; + virtual void next() = 0; + virtual Batch& get() = 0; + virtual bool operator==(const IteratorImpl& other) const = 0; + virtual bool operator==(const ValidIterator& other) const = 0; + virtual bool operator==(const SentinelIterator& other) const = 0; +}; + +template +struct ValidIterator : public IteratorImpl { + using BatchProducer = std::function()>; + + explicit ValidIterator(BatchProducer next_batch) + : next_batch_(std::move(next_batch)) {} + + /// Fetches the next batch. + void next() override { + // If we didn't get the very first batch yet, get it now. + lazy_initialize(); + TORCH_CHECK( + batch_.has_value(), "Attempted to increment iterator past the end"); + // Increment to the next batch. + batch_ = next_batch_(); + } + + /// Returns the current batch. The precondition for this operation to not + /// throw an exception is that it has been compared to the `SentinelIterator` + /// and did not compare equal. + Batch& get() override { + // If we didn't get the very first batch yet, get it now. + lazy_initialize(); + TORCH_CHECK( + batch_.has_value(), + "Attempted to dereference iterator that was past the end"); + return batch_.value(); + } + + /// Does double dispatch. + bool operator==(const IteratorImpl& other) const override { + return other == *this; + } + + /// A `ValidIterator` is equal to the `SentinelIterator` iff. the + /// `ValidIterator` has reached the end of the dataloader. 
+ bool operator==(const SentinelIterator& /* unused */) const override { + lazy_initialize(); + return !batch_; + } + + /// Returns true if the memory address of `other` equals that of `this`. + bool operator==(const ValidIterator& other) const override { + return &other == this; + } + + /// Gets the very first batch if it has not yet been fetched. + void lazy_initialize() const { + if (!initialized_) { + batch_ = next_batch_(); + initialized_ = true; + } + } + + BatchProducer next_batch_; + mutable std::optional batch_; + mutable bool initialized_ = false; +}; + +template +struct SentinelIterator : public IteratorImpl { + void next() override { + AT_ERROR( + "Incrementing the DataLoader's past-the-end iterator is not allowed"); + } + + Batch& get() override { + AT_ERROR( + "Dereferencing the DataLoader's past-the-end iterator is not allowed"); + } + + /// Does double dispatch. + bool operator==(const IteratorImpl& other) const override { + return other == *this; + } + + /// Calls the comparison operator between `ValidIterator` and + /// `SentinelIterator`. + bool operator==(const ValidIterator& other) const override { + return other == *this; + } + + /// Sentinel iterators always compare equal. + bool operator==(const SentinelIterator& other) const override { + return true; + } +}; +} // namespace detail + +template +class Iterator { + public: + // Type aliases to make the class recognized as a proper iterator. + using difference_type = std::ptrdiff_t; + using value_type = Batch; + using pointer = Batch*; + using reference = Batch&; + using iterator_category = std::input_iterator_tag; + + explicit Iterator(std::unique_ptr> impl) + : impl_(std::move(impl)) {} + + /// Increments the iterator. + /// Only permitted for valid iterators (not past the end). + Iterator& operator++() { + impl_->next(); + return *this; + } + + /// Returns the current batch. + /// Only permitted for valid iterators (not past the end). 
+ Batch& operator*() { + return impl_->get(); + } + + /// Returns a pointer to the current batch. + /// Only permitted for valid iterators (not past the end). + Batch* operator->() { + return &impl_->get(); + } + + /// Compares two iterators for equality. + bool operator==(const Iterator& other) const { + return *impl_ == *other.impl_; + } + + /// Compares two iterators for inequality. + bool operator!=(const Iterator& other) const { + return !(*this == other); + } + + private: + /// Points either to a `ValidIterator` or to a `SentinelIterator`. + std::shared_ptr> impl_; +}; +} // namespace data +} // namespace torch diff --git a/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/samplers.h b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/samplers.h new file mode 100644 index 0000000000000000000000000000000000000000..928a2412aa76f8a22574b433a2f61152c45ae5c7 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/samplers.h @@ -0,0 +1,9 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include diff --git a/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/samplers/base.h b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/samplers/base.h new file mode 100644 index 0000000000000000000000000000000000000000..8ab48d9d5931f5b86dbb0674043c6c8b658adba3 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/samplers/base.h @@ -0,0 +1,47 @@ +#pragma once + +#include +#include + +#include +#include +#include + +namespace torch { +namespace serialize { +class OutputArchive; +class InputArchive; +} // namespace serialize +} // namespace torch + +namespace torch { +namespace data { +namespace samplers { +/// A `Sampler` is an object that yields an index with which to access a +/// dataset. 
+template > +class Sampler { + public: + using BatchRequestType = BatchRequest; + + virtual ~Sampler() = default; + + /// Resets the `Sampler`'s internal state. + /// Typically called before a new epoch. + /// Optionally, accepts a new size when reseting the sampler. + virtual void reset(std::optional new_size) = 0; + + /// Returns the next index if possible, or an empty optional if the + /// sampler is exhausted for this epoch. + virtual std::optional next(size_t batch_size) = 0; + + /// Serializes the `Sampler` to the `archive`. + virtual void save(serialize::OutputArchive& archive) const = 0; + + /// Deserializes the `Sampler` from the `archive`. + virtual void load(serialize::InputArchive& archive) = 0; +}; + +} // namespace samplers +} // namespace data +} // namespace torch diff --git a/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/samplers/custom_batch_request.h b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/samplers/custom_batch_request.h new file mode 100644 index 0000000000000000000000000000000000000000..a5247b008d75021c3627d1a3bc922072c96f7812 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/samplers/custom_batch_request.h @@ -0,0 +1,21 @@ +#pragma once + +#include +#include + +namespace torch { +namespace data { +namespace samplers { +/// A base class for custom index types. +struct TORCH_API CustomBatchRequest { + CustomBatchRequest() = default; + CustomBatchRequest(const CustomBatchRequest&) = default; + CustomBatchRequest(CustomBatchRequest&&) noexcept = default; + virtual ~CustomBatchRequest() = default; + + /// The number of elements accessed by this index. 
+ virtual size_t size() const = 0; +}; +} // namespace samplers +} // namespace data +} // namespace torch diff --git a/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/samplers/distributed.h b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/samplers/distributed.h new file mode 100644 index 0000000000000000000000000000000000000000..bce36aaa4df719be2747fc3e8e5981c9aca20f16 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/samplers/distributed.h @@ -0,0 +1,139 @@ +#pragma once + +#include +#include + +#include +#include + +namespace torch { +namespace serialize { +class OutputArchive; +class InputArchive; +} // namespace serialize +} // namespace torch + +namespace torch { +namespace data { +namespace samplers { + +/// A `Sampler` that selects a subset of indices to sample from and defines a +/// sampling behavior. In a distributed setting, this selects a subset of the +/// indices depending on the provided num_replicas and rank parameters. The +/// `Sampler` performs a rounding operation based on the `allow_duplicates` +/// parameter to decide the local sample count. +template > +class DistributedSampler : public Sampler { + public: + DistributedSampler( + size_t size, + size_t num_replicas = 1, + size_t rank = 0, + bool allow_duplicates = true) + : size_(size), + num_replicas_(num_replicas), + rank_(rank), + epoch_(0), + allow_duplicates_(allow_duplicates) {} + + /// Set the epoch for the current enumeration. This can be used to alter the + /// sample selection and shuffling behavior. 
+ void set_epoch(size_t epoch) { + epoch_ = epoch; + } + + size_t epoch() const { + return epoch_; + } + + protected: + size_t local_sample_count() { + if (allow_duplicates_) { + return (size_ + num_replicas_ - 1) / num_replicas_; + } else { + return size_ / num_replicas_; + } + } + + // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) + size_t size_; + // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) + size_t num_replicas_; + // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) + size_t rank_; + // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) + size_t epoch_; + // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) + bool allow_duplicates_; +}; + +/// Select samples randomly. The sampling order is shuffled at each `reset()` +/// call. +class TORCH_API DistributedRandomSampler : public DistributedSampler<> { + public: + DistributedRandomSampler( + size_t size, + size_t num_replicas = 1, + size_t rank = 0, + bool allow_duplicates = true); + + /// Resets the `DistributedRandomSampler` to a new set of indices. + void reset(std::optional new_size = std::nullopt) override; + + /// Returns the next batch of indices. + std::optional> next(size_t batch_size) override; + + /// Serializes the `DistributedRandomSampler` to the `archive`. + void save(serialize::OutputArchive& archive) const override; + + /// Deserializes the `DistributedRandomSampler` from the `archive`. + void load(serialize::InputArchive& archive) override; + + /// Returns the current index of the `DistributedRandomSampler`. + size_t index() const noexcept; + + private: + void populate_indices(); + + size_t begin_index_; + size_t end_index_; + size_t sample_index_; + std::vector all_indices_; +}; + +/// Select samples sequentially. 
+class TORCH_API DistributedSequentialSampler : public DistributedSampler<> { + public: + DistributedSequentialSampler( + size_t size, + size_t num_replicas = 1, + size_t rank = 0, + bool allow_duplicates = true); + + /// Resets the `DistributedSequentialSampler` to a new set of indices. + void reset(std::optional new_size = std::nullopt) override; + + /// Returns the next batch of indices. + std::optional> next(size_t batch_size) override; + + /// Serializes the `DistributedSequentialSampler` to the `archive`. + void save(serialize::OutputArchive& archive) const override; + + /// Deserializes the `DistributedSequentialSampler` from the `archive`. + void load(serialize::InputArchive& archive) override; + + /// Returns the current index of the `DistributedSequentialSampler`. + size_t index() const noexcept; + + private: + void populate_indices(); + + size_t begin_index_; + size_t end_index_; + size_t sample_index_; + std::vector all_indices_; +}; + +} // namespace samplers +} // namespace data +} // namespace torch diff --git a/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/samplers/random.h b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/samplers/random.h new file mode 100644 index 0000000000000000000000000000000000000000..4b023b6c703affedad7ccb3134456f437d97de1b --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/samplers/random.h @@ -0,0 +1,54 @@ +#pragma once + +#include +#include +#include + +#include +#include + +namespace torch { +namespace serialize { +class OutputArchive; +class InputArchive; +} // namespace serialize +} // namespace torch + +namespace torch { +namespace data { +namespace samplers { + +/// A `Sampler` that returns random indices. +class TORCH_API RandomSampler : public Sampler<> { + public: + /// Constructs a `RandomSampler` with a size and dtype for the stored indices. 
+ /// + /// The constructor will eagerly allocate all required indices, which is the + /// sequence `0 ... size - 1`. `index_dtype` is the data type of the stored + /// indices. You can change it to influence memory usage. + explicit RandomSampler(int64_t size, Dtype index_dtype = torch::kInt64); + + ~RandomSampler() override; + + /// Resets the `RandomSampler` to a new set of indices. + void reset(std::optional new_size = std::nullopt) override; + + /// Returns the next batch of indices. + std::optional> next(size_t batch_size) override; + + /// Serializes the `RandomSampler` to the `archive`. + void save(serialize::OutputArchive& archive) const override; + + /// Deserializes the `RandomSampler` from the `archive`. + void load(serialize::InputArchive& archive) override; + + /// Returns the current index of the `RandomSampler`. + size_t index() const noexcept; + + private: + at::Tensor indices_; + int64_t index_ = 0; +}; +} // namespace samplers +} // namespace data +} // namespace torch diff --git a/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/samplers/sequential.h b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/samplers/sequential.h new file mode 100644 index 0000000000000000000000000000000000000000..252ecc3ad3d7503ad37df2fed9d41edcd0235936 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/samplers/sequential.h @@ -0,0 +1,50 @@ +#pragma once + +#include +#include +#include + +#include +#include + +namespace torch { +namespace serialize { +class OutputArchive; +class InputArchive; +} // namespace serialize +} // namespace torch + +namespace torch { +namespace data { +namespace samplers { + +/// A `Sampler` that returns indices sequentially. +class TORCH_API SequentialSampler : public Sampler<> { + public: + /// Creates a `SequentialSampler` that will return indices in the range + /// `0...size - 1`. 
+ explicit SequentialSampler(size_t size); + + /// Resets the `SequentialSampler` to zero. + void reset(std::optional new_size = std::nullopt) override; + + /// Returns the next batch of indices. + std::optional> next(size_t batch_size) override; + + /// Serializes the `SequentialSampler` to the `archive`. + void save(serialize::OutputArchive& archive) const override; + + /// Deserializes the `SequentialSampler` from the `archive`. + void load(serialize::InputArchive& archive) override; + + /// Returns the current index of the `SequentialSampler`. + size_t index() const noexcept; + + private: + size_t size_; + size_t index_{0}; +}; + +} // namespace samplers +} // namespace data +} // namespace torch diff --git a/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/samplers/serialize.h b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/samplers/serialize.h new file mode 100644 index 0000000000000000000000000000000000000000..7585217a9cf260a67eb4d4fbda061a27a3fb23af --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/samplers/serialize.h @@ -0,0 +1,28 @@ +#pragma once + +#include +#include + +namespace torch { +namespace data { +namespace samplers { +/// Serializes a `Sampler` into an `OutputArchive`. +template +serialize::OutputArchive& operator<<( + serialize::OutputArchive& archive, + const Sampler& sampler) { + sampler.save(archive); + return archive; +} + +/// Deserializes a `Sampler` from an `InputArchive`. 
+template +serialize::InputArchive& operator>>( + serialize::InputArchive& archive, + Sampler& sampler) { + sampler.load(archive); + return archive; +} +} // namespace samplers +} // namespace data +} // namespace torch diff --git a/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/samplers/stream.h b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/samplers/stream.h new file mode 100644 index 0000000000000000000000000000000000000000..201c914e49e5cf2d34e0f273cc71fb6d02e4068c --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/samplers/stream.h @@ -0,0 +1,63 @@ +#pragma once + +#include +#include +#include +#include + +#include + +namespace torch { +namespace serialize { +class InputArchive; +class OutputArchive; +} // namespace serialize +} // namespace torch + +namespace torch { +namespace data { +namespace samplers { + +/// A wrapper around a batch size value, which implements the +/// `CustomBatchRequest` interface. +struct TORCH_API BatchSize : public CustomBatchRequest { + explicit BatchSize(size_t size); + size_t size() const noexcept override; + operator size_t() const noexcept; + size_t size_; +}; + +/// A sampler for (potentially infinite) streams of data. +/// +/// The major feature of the `StreamSampler` is that it does not return +/// particular indices, but instead only the number of elements to fetch from +/// the dataset. The dataset has to decide how to produce those elements. +class TORCH_API StreamSampler : public Sampler { + public: + /// Constructs the `StreamSampler` with the number of individual examples that + /// should be fetched until the sampler is exhausted. + explicit StreamSampler(size_t epoch_size); + + /// Resets the internal state of the sampler. + void reset(std::optional new_size = std::nullopt) override; + + /// Returns a `BatchSize` object with the number of elements to fetch in the + /// next batch. 
This number is the minimum of the supplied `batch_size` and + /// the difference between the `epoch_size` and the current index. If the + /// `epoch_size` has been reached, returns an empty optional. + std::optional next(size_t batch_size) override; + + /// Serializes the `StreamSampler` to the `archive`. + void save(serialize::OutputArchive& archive) const override; + + /// Deserializes the `StreamSampler` from the `archive`. + void load(serialize::InputArchive& archive) override; + + private: + size_t examples_retrieved_so_far_ = 0; + size_t epoch_size_; +}; + +} // namespace samplers +} // namespace data +} // namespace torch diff --git a/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/transforms.h b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/transforms.h new file mode 100644 index 0000000000000000000000000000000000000000..e5d92062e62d52dd2dac3ab39f76385f9bf1522f --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/transforms.h @@ -0,0 +1,7 @@ +#pragma once + +#include +#include +#include +#include +#include diff --git a/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/transforms/base.h b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/transforms/base.h new file mode 100644 index 0000000000000000000000000000000000000000..0bc1f2ea7b141a270a7b5f826442d96d9841e72c --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/transforms/base.h @@ -0,0 +1,53 @@ +#pragma once + +#include + +#include +#include + +namespace torch { +namespace data { +namespace transforms { + +/// A transformation of a batch to a new batch. 
+template +class BatchTransform { + public: + using InputBatchType = InputBatch; + using OutputBatchType = OutputBatch; + + virtual ~BatchTransform() = default; + + /// Applies the transformation to the given `input_batch`. + virtual OutputBatch apply_batch(InputBatch input_batch) = 0; +}; + +/// A transformation of individual input examples to individual output examples. +/// +/// Just like a `Dataset` is a `BatchDataset`, a `Transform` is a +/// `BatchTransform` that can operate on the level of individual examples rather +/// than entire batches. The batch-level transform is implemented (by default) +/// in terms of the example-level transform, though this can be customized. +template +class Transform + : public BatchTransform, std::vector> { + public: + using InputType = Input; + using OutputType = Output; + + /// Applies the transformation to the given `input`. + virtual OutputType apply(InputType input) = 0; + + /// Applies the `transformation` over the entire `input_batch`. + std::vector apply_batch(std::vector input_batch) override { + std::vector output_batch; + output_batch.reserve(input_batch.size()); + for (auto&& input : input_batch) { + output_batch.push_back(apply(std::move(input))); + } + return output_batch; + } +}; +} // namespace transforms +} // namespace data +} // namespace torch diff --git a/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/transforms/collate.h b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/transforms/collate.h new file mode 100644 index 0000000000000000000000000000000000000000..181bcae0031b6f7f0f77b5eebeb7f3de5436c691 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/transforms/collate.h @@ -0,0 +1,35 @@ +#pragma once + +#include +#include + +#include + +namespace torch { +namespace data { +namespace transforms { + +/// A `Collation` is a transform that reduces a batch into a single value. 
+/// The result is a `BatchDataset` that has the type of the single value as its +/// `BatchType`. +template > +using Collation = BatchTransform; + +/// A `Collate` allows passing a custom function to reduce/collate a batch +/// into a single value. It's effectively the lambda version of `Collation`, +/// which you could subclass and override `operator()` to achieve the same. +/// +/// \rst +/// .. code-block:: cpp +/// using namespace torch::data; +/// +/// auto dataset = datasets::MNIST("path/to/mnist") +/// .map(transforms::Collate>([](std::vector> e) { +/// return std::move(e.front()); +/// })); +/// \endrst +template > +using Collate = BatchLambda; +} // namespace transforms +} // namespace data +} // namespace torch diff --git a/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/transforms/lambda.h b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/transforms/lambda.h new file mode 100644 index 0000000000000000000000000000000000000000..252b29807a8efe5e16fa11b9f9337c9f8a4ffa98 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/transforms/lambda.h @@ -0,0 +1,56 @@ +#pragma once + +#include + +#include +#include +#include + +namespace torch { +namespace data { +namespace transforms { + +/// A `BatchTransform` that applies a user-provided functor to a batch. +template +class BatchLambda : public BatchTransform { + public: + using typename BatchTransform::InputBatchType; + using typename BatchTransform::OutputBatchType; + using FunctionType = std::function; + + /// Constructs the `BatchLambda` from the given `function` object. + explicit BatchLambda(FunctionType function) + : function_(std::move(function)) {} + + /// Applies the user-provided function object to the `input_batch`. 
+ OutputBatchType apply_batch(InputBatchType input_batch) override { + return function_(std::move(input_batch)); + } + + private: + FunctionType function_; +}; + +// A `Transform` that applies a user-provided functor to individual examples. +template +class Lambda : public Transform { + public: + using typename Transform::InputType; + using typename Transform::OutputType; + using FunctionType = std::function; + + /// Constructs the `Lambda` from the given `function` object. + explicit Lambda(FunctionType function) : function_(std::move(function)) {} + + /// Applies the user-provided function object to the `input`. + OutputType apply(InputType input) override { + return function_(std::move(input)); + } + + private: + FunctionType function_; +}; + +} // namespace transforms +} // namespace data +} // namespace torch diff --git a/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/transforms/stack.h b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/transforms/stack.h new file mode 100644 index 0000000000000000000000000000000000000000..4be1bd920b71596b7a91dcae28acc2713c7af782 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/transforms/stack.h @@ -0,0 +1,49 @@ +#pragma once + +#include +#include +#include + +#include +#include + +namespace torch { +namespace data { +namespace transforms { + +template > +struct Stack; + +/// A `Collation` for `Example` types that stacks all data +/// tensors into one tensor, and all target (label) tensors into one tensor. 
+template <> +struct Stack> : public Collation> { + Example<> apply_batch(std::vector> examples) override { + std::vector data, targets; + data.reserve(examples.size()); + targets.reserve(examples.size()); + for (auto& example : examples) { + data.push_back(std::move(example.data)); + targets.push_back(std::move(example.target)); + } + return {torch::stack(data), torch::stack(targets)}; + } +}; + +/// A `Collation` for `Example` types that stacks all data +/// tensors into one tensor. +template <> +struct Stack + : public Collation> { + TensorExample apply_batch(std::vector examples) override { + std::vector data; + data.reserve(examples.size()); + for (auto& example : examples) { + data.push_back(std::move(example.data)); + } + return torch::stack(data); + } +}; +} // namespace transforms +} // namespace data +} // namespace torch diff --git a/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/transforms/tensor.h b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/transforms/tensor.h new file mode 100644 index 0000000000000000000000000000000000000000..2e135c528131506b0249db41f9de56e7cfc84e03 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/transforms/tensor.h @@ -0,0 +1,77 @@ +#pragma once + +#include +#include +#include + +#include +#include + +namespace torch { +namespace data { +namespace transforms { + +/// A `Transform` that is specialized for the typical `Example` +/// combination. It exposes a single `operator()` interface hook (for +/// subclasses), and calls this function on input `Example` objects. +template +class TensorTransform + : public Transform, Example> { + public: + using E = Example; + using typename Transform::InputType; + using typename Transform::OutputType; + + /// Transforms a single input tensor to an output tensor. 
+ virtual Tensor operator()(Tensor input) = 0; + + /// Implementation of `Transform::apply` that calls `operator()`. + OutputType apply(InputType input) override { + input.data = (*this)(std::move(input.data)); + return input; + } +}; + +/// A `Lambda` specialized for the typical `Example` input type. +template +class TensorLambda : public TensorTransform { + public: + using FunctionType = std::function; + + /// Creates a `TensorLambda` from the given `function`. + explicit TensorLambda(FunctionType function) + : function_(std::move(function)) {} + + /// Applies the user-provided functor to the input tensor. + Tensor operator()(Tensor input) override { + return function_(std::move(input)); + } + + private: + FunctionType function_; +}; + +/// Normalizes input tensors by subtracting the supplied mean and dividing by +/// the given standard deviation. +template +struct Normalize : public TensorTransform { + /// Constructs a `Normalize` transform. The mean and standard deviation can be + /// anything that is broadcastable over the input tensors (like single + /// scalars). 
+ Normalize(ArrayRef mean, ArrayRef stddev) + : mean(torch::tensor(mean, torch::kFloat32) + .unsqueeze(/*dim=*/1) + .unsqueeze(/*dim=*/2)), + stddev(torch::tensor(stddev, torch::kFloat32) + .unsqueeze(/*dim=*/1) + .unsqueeze(/*dim=*/2)) {} + + torch::Tensor operator()(Tensor input) override { + return input.sub(mean).div(stddev); + } + + torch::Tensor mean, stddev; +}; +} // namespace transforms +} // namespace data +} // namespace torch diff --git a/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/worker_exception.h b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/worker_exception.h new file mode 100644 index 0000000000000000000000000000000000000000..40680b8330c456669826e8957abcbd3ae15c130c --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/data/worker_exception.h @@ -0,0 +1,38 @@ +#pragma once + +#include +#include +#include + +namespace torch { +namespace data { + +/// An exception thrown when a DataLoader's worker thread throws an exception, +/// which is caught. A `WorkerException` stores an `exception_ptr` to the +/// original exception thrown in the worker thread. +struct WorkerException : public std::exception { + /// Constructs a `WorkerException` from an `exception_ptr`. + explicit WorkerException(std::exception_ptr original) + : original_exception(std::move(original)), + message("Caught exception in DataLoader worker thread.") { + try { + std::rethrow_exception(original_exception); + } catch (std::exception& e) { + message += " Original message: "; + message += e.what(); + } + } + + const char* what() const noexcept override { + return message.c_str(); + } + + /// The original exception thrown in the worker thread. + std::exception_ptr original_exception; + + /// This exception's message (not the original exception's message). 
+ std::string message; +}; + +} // namespace data +} // namespace torch diff --git a/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/detail/TensorDataContainer.h b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/detail/TensorDataContainer.h new file mode 100644 index 0000000000000000000000000000000000000000..4da7cb1f4460f1a95d9c614610abe3bdf63874c8 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/detail/TensorDataContainer.h @@ -0,0 +1,363 @@ +#pragma once + +#include +#include +#include +#include + +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +#include +#include +#endif + +#include + +namespace torch { + +namespace detail { + +enum class TensorDataContainerType { Scalar, InitList, Tensor }; + +struct TensorDataContainer; + +inline std::ostream& operator<<( + std::ostream& stream, + const TensorDataContainer& tensor_data_container); + +inline c10::ScalarType compute_desired_dtype(c10::ScalarType scalar_type) { + if (scalar_type == at::kInt || scalar_type == at::kLong) { + // C++ `torch::tensor` with an integer type or an `at::ArrayRef` / + // `std::vector` / (nested) braced-init-list of integer types always + // produces a tensor of dtype `at::kLong` (aka. int64_t), matching Python + // `torch.tensor` behavior. + return at::kLong; + } else if (scalar_type == at::kFloat || scalar_type == at::kDouble) { + // C++ `torch::tensor` with a floating-point type or an `at::ArrayRef` / + // `std::vector` / (nested) braced-init-list of floating-point types always + // produces a tensor of dtype `torch::get_default_dtype()`, matching Python + // `torch.tensor` behavior. + return at::typeMetaToScalarType(at::get_default_dtype()); + } else { + return scalar_type; + } +} + +// We use `TensorDataContainer` to support converting the following data +// container types into the equivalent Tensor: +// +// 1. Arbitrarily nested braced-init-list (e.g. 
`{{1, 2}, {3, 4}}`). +// 2. `at::ArrayRef` of supported tensor data types. +// 3. `std::vector` of supported tensor data types. +// +// At any time, a `TensorDataContainer` object represents one of the following: +// +// 1. A scalar with value `scalar()` and type `scalar_type()`. +// 2. A Tensor represented in `std::initializer_list` form, +// with value `init_list()`, Tensor scalar type `scalar_type()`, and Tensor +// sizes `sizes()`. +// 3. A Tensor represented in `at::Tensor` form, with value `tensor()`, scalar +// type `scalar_type()`, +// and Tensor sizes `sizes()`. +// +// All the infrastructure here is mostly to support converting an arbitrarily +// nested braced-init-list to the equivalent Tensor successfully. Consider the +// following example: +// +// `torch::tensor({{1}, {2}})` +// +// this will call into the `torch::tensor` function: +// +// `at::Tensor tensor(detail::TensorDataContainer tensor_data_container, const +// at::TensorOptions& options = {})` +// +// the compiler will first try to convert `{{1}, {2}}` to `TensorDataContainer` +// type: +// +// `TensorDataContainer({{1}, {2}})` +// +// which matches to the +// `TensorDataContainer(std::initializer_list)` +// constructor, and in an attempt to convert `{1}` and `{2}` to +// `TensorDataContainer`, it calls the following: +// +// `TensorDataContainer({1})` (same call path happens for `{2}`, and we'll just +// focus on `{1}` here) +// +// At this point, theoretically there are two plausible ways for `{1}` to be +// matched to one of the constructors of `TensorDataContainer`: +// +// 1. It can be a list-initialization of a scalar value, thus matching +// `TensorDataContainer(int value)`. +// 2. It can be converted to `std::initializer_list`, thus +// matching +// `TensorDataContainer(std::initializer_list)`. +// +// How does the compiler decide which one to choose? 
According to +// `https://en.cppreference.com/w/cpp/language/list_initialization`, +// braced-init-list always prefers the constructor that takes +// `std::initializer_list`. Hence we happily move forward with constructor #2, +// and it calls the following: +// +// `TensorDataContainer(1)` +// +// Now it matches `TensorDataContainer(int value)`, which stores `1` as a scalar +// value. All is good. +struct TensorDataContainer { + // NOTE: For tensors with zero-size dimensions (e.g. `torch::tensor({{}, + // {}})`), the innermost empty braced-init-list `{}` matches the default + // constructor of the innermost `TensorDataContainer`. + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) + TensorDataContainer() + : sizes_({0}), + // NOTE: In Python, the dtype of tensors with zero-size dimensions (e.g. + // `torch.tensor([[], []])`) depends on the value of + // `torch.get_default_dtype()`, and we should do the same for the C++ + // equivalent. + scalar_type_(at::typeMetaToScalarType(at::get_default_dtype())), + type_(TensorDataContainerType::InitList) {} +#define TENSOR(T, S) \ + TensorDataContainer(T value) \ + : sizes_(), \ + scalar_type_(at::k##S), \ + type_(TensorDataContainerType::Scalar), \ + scalar_(value) {} + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) + AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TENSOR) + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) + AT_FORALL_COMPLEX_TYPES(TENSOR) +#undef TENSOR + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) + TensorDataContainer(std::initializer_list init_list) + : sizes_(), + scalar_type_(init_list.begin()->scalar_type()), + type_(TensorDataContainerType::InitList), + init_list_(init_list) { + const TensorDataContainer& first_elem = *(init_list.begin()); + for (const auto& elem : init_list) { + TORCH_CHECK( + elem.sizes() == first_elem.sizes(), + "Expected all sub-lists to have sizes: ", + first_elem.sizes(), + " (e.g. 
", + first_elem, + "), ", + "but got sub-list ", + elem, + " with sizes: ", + elem.sizes()); + TORCH_CHECK( + elem.scalar_type() == first_elem.scalar_type(), + "Expected all elements of the tensor to have the same scalar type: ", + first_elem.scalar_type(), + ", but got element of scalar type: ", + elem.scalar_type()); + } + sizes_.reserve(first_elem.sizes().size() + 1); + sizes_.push_back(init_list.size()); + sizes_.insert( + sizes_.end(), first_elem.sizes().begin(), first_elem.sizes().end()); + } + +#define TENSOR(T, S) \ + TensorDataContainer(at::ArrayRef values) \ + : sizes_({(int64_t)values.size()}), \ + scalar_type_(at::k##S), \ + type_(TensorDataContainerType::Tensor) { \ + at::AutoDispatchBelowAutograd mode; \ + if (scalar_type_ == at::kBool) { \ + tensor_ = at::tensor(values, at::TensorOptions().device(at::kCPU)); \ + } else { \ + tensor_ = at::tensor(values, at::dtype(scalar_type_).device(at::kCPU)); \ + } \ + } + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) + AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TENSOR) + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) + AT_FORALL_COMPLEX_TYPES(TENSOR) +#undef TENSOR + + // NOTE: We need to handle `std::vector` explicitly instead of relying on an + // implicit conversion to `at::ArrayRef`, otherwise the following error can be + // thrown when calling `torch::tensor(std::vector({1, 2}))`: + // ``` + // error: no matching function for call to 'tensor(const std::vector&)' + // no known conversion for argument 1 from 'const std::vector' to + // 'torch::detail::TensorDataContainer' + // ``` + // + // NOTE: `torch::tensor(std::vector)` is not supported for now, because + // ArrayRef cannot be constructed from a std::vector bitfield. 
+#define TENSOR(T, S) \ + TensorDataContainer(const std::vector& values) \ + : TensorDataContainer(at::ArrayRef(values)) {} + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) + AT_FORALL_SCALAR_TYPES_AND2(Half, BFloat16, TENSOR) + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) + AT_FORALL_COMPLEX_TYPES(TENSOR) +#undef TENSOR + + bool is_scalar() const { + return type_ == TensorDataContainerType::Scalar; + } + + const c10::Scalar& scalar() const { + TORCH_CHECK( + is_scalar(), + "Can only call `scalar()` on a TensorDataContainer that has `is_scalar() == true`"); + return scalar_; + } + + bool is_init_list() const { + return type_ == TensorDataContainerType::InitList; + } + + const std::initializer_list& init_list() const { + TORCH_CHECK( + is_init_list(), + "Can only call `init_list()` on a TensorDataContainer that has `is_init_list() == true`"); + return init_list_; + } + + bool is_tensor() const { + return type_ == TensorDataContainerType::Tensor; + } + + const at::Tensor& tensor() const { + TORCH_CHECK( + is_tensor(), + "Can only call `tensor()` on a TensorDataContainer that has `is_tensor() == true`"); + return tensor_; + } + + const std::vector& sizes() const { + return sizes_; + } + + const c10::ScalarType& scalar_type() const { + return scalar_type_; + } + + at::Tensor convert_to_tensor(at::TensorOptions options) const { + if (!options.has_dtype()) { + options = options.dtype(compute_desired_dtype(scalar_type_)); + } + + if (is_scalar()) { + at::AutoDispatchBelowAutograd mode; + return at::scalar_tensor(scalar_, options); + } else if (is_init_list()) { + // NOTE: Here we explicitly choose to initialize the tensor on CPU first, + // fill each element of the tensor, and then move the tensor to the + // desired device. 
For CUDA device, this approach only involves 1 CUDA + // kernel launch, and is much faster than initializing the tensor on CUDA + // first and then filling each element of it (which involves `N` CUDA + // kernel launches where `N` is the number of the elements in the tensor). + at::Tensor tensor = ([&]() { + at::AutoDispatchBelowAutograd mode; + return at::empty(sizes_, options.device(at::kCPU)); + })(); + fill_tensor(tensor); + return tensor.to(options.device()); + } else if (is_tensor()) { + auto output = tensor_.to(options); + TORCH_CHECK( + !tensor_.is_complex() || output.is_complex(), + "can not do torch::tensor(complex, dtype=non-complex) because complex can not be casted to real number without loss of information"); + return output; + } else { + TORCH_INTERNAL_ASSERT(false, "Invalid TensorDataContainer type"); + } + } + + void pretty_print_recursive(std::ostream& stream) const { + if (is_scalar()) { + AT_DISPATCH_ALL_TYPES_AND3( + at::kBool, + at::kHalf, + at::kBFloat16, + scalar_type_, + "TensorDataContainer_pretty_print_scalar", + [&] { stream << scalar_.to(); }); + } else if (is_init_list()) { + stream << "{"; + for (const TensorDataContainer* it = init_list_.begin(); + it != init_list_.end(); + it++) { + stream << *it; + if (std::next(it) != init_list_.end()) + stream << ", "; + } + stream << "}"; + } else if (is_tensor()) { + stream << "{"; + for (const auto i : c10::irange(tensor_.sizes()[0])) { + AT_DISPATCH_ALL_TYPES_AND3( + at::kBool, + at::kHalf, + at::kBFloat16, + scalar_type_, + "TensorDataContainer_pretty_print_tensor_item", + [&] { stream << tensor_[i].item(); }); + if (i != tensor_.sizes()[0] - 1) + stream << ", "; + } + stream << "}"; + } else { + TORCH_INTERNAL_ASSERT(false, "Invalid TensorDataContainer type"); + } + } + + private: + void fill_tensor(at::Tensor& tensor) const { + if (is_scalar()) { + TORCH_INTERNAL_ASSERT( + tensor.dim() == 0, + "Expected a 0-dim Tensor, but got Tensor with dimensions: ", + tensor.dim()); + at::NoGradGuard 
guard; + tensor.fill_(scalar_); + } else if (is_init_list()) { + TORCH_INTERNAL_ASSERT( + tensor.sizes()[0] == (int64_t)init_list_.size(), + "Expected a Tensor with size ", + init_list_.size(), + " in its first dimension, but got Tensor with size ", + tensor.sizes()[0], + " in its first dimension"); + size_t index = 0; + for (const auto& elem : init_list_) { + at::Tensor slice = tensor[index]; + elem.fill_tensor(slice); + index++; + } + } else if (is_tensor()) { + TORCH_INTERNAL_ASSERT( + false, + "TensorDataContainer is already a Tensor type, `fill_tensor` should not be called"); + } else { + TORCH_INTERNAL_ASSERT(false, "Invalid TensorDataContainer type"); + } + } + + std::vector sizes_; + c10::ScalarType scalar_type_; + TensorDataContainerType type_; + c10::Scalar scalar_; + std::initializer_list init_list_; + at::Tensor tensor_; +}; + +inline std::ostream& operator<<( + std::ostream& stream, + const TensorDataContainer& tensor_data_container) { + tensor_data_container.pretty_print_recursive(stream); + return stream; +} + +} // namespace detail + +} // namespace torch diff --git a/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/detail/static.h b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/detail/static.h new file mode 100644 index 0000000000000000000000000000000000000000..c85fc7fff4b4d56171c6add8f82ea99ba74242bb --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/detail/static.h @@ -0,0 +1,65 @@ +#pragma once + +#include +#include + +#include +#include + +namespace torch { +namespace nn { +class Module; +} // namespace nn +} // namespace torch + +namespace torch { +namespace detail { +/// Detects if a type T has a forward() method. +template +struct has_forward { + // Declare two types with differing size. + using yes = int8_t; + using no = int16_t; + + // Here we declare two functions. 
The first is only enabled if `&U::forward` + // is well-formed and returns the `yes` type. In C++, the ellipsis parameter + // type (`...`) always puts the function at the bottom of overload resolution. + // This is specified in the standard as: 1) A standard conversion sequence is + // always better than a user-defined conversion sequence or an ellipsis + // conversion sequence. 2) A user-defined conversion sequence is always better + // than an ellipsis conversion sequence This means that if the first overload + // is viable, it will be preferred over the second as long as we pass any + // convertible type. The type of `&U::forward` is a pointer type, so we can + // pass e.g. 0. + template + static yes test(decltype(&U::forward)); + template + static no test(...); + + // Finally we test statically whether the size of the type returned by the + // selected overload is the size of the `yes` type. + static constexpr bool value = (sizeof(test(nullptr)) == sizeof(yes)); +}; + +template +constexpr bool check_not_lvalue_references() { + return (!std::is_lvalue_reference::value || + std::is_const::type>::value) && + check_not_lvalue_references(); +} + +template <> +inline constexpr bool check_not_lvalue_references() { + return true; +} + +/// A type trait whose `value` member is true if `M` derives from `Module`. 
+template +using is_module = + std::is_base_of::type>; + +template +using enable_if_module_t = + typename std::enable_if::value, T>::type; +} // namespace detail +} // namespace torch diff --git a/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/container/any.h b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/container/any.h new file mode 100644 index 0000000000000000000000000000000000000000..ab4a589aeded124cb6e34fd5dd191165f4aea43c --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/container/any.h @@ -0,0 +1,372 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +#include +#include + +#include + +#include +#include +#include +#include +#include + +namespace torch { +namespace nn { + +/// Stores a type erased `Module`. +/// +/// The PyTorch C++ API does not impose an interface on the signature of +/// `forward()` in `Module` subclasses. This gives you complete freedom to +/// design your `forward()` methods to your liking. However, this also means +/// there is no unified base type you could store in order to call `forward()` +/// polymorphically for any module. This is where the `AnyModule` comes in. +/// Instead of inheritance, it relies on type erasure for polymorphism. +/// +/// An `AnyModule` can store any `nn::Module` subclass that provides a +/// `forward()` method. This `forward()` may accept any types and return any +/// type. Once stored in an `AnyModule`, you can invoke the underlying module's +/// `forward()` by calling `AnyModule::forward()` with the arguments you would +/// supply to the stored module (though see one important limitation below). +/// Example: +/// +/// \rst +/// .. 
code-block:: cpp +/// +/// struct GenericTrainer { +/// torch::nn::AnyModule module; +/// +/// void train(torch::Tensor input) { +/// module.forward(input); +/// } +/// }; +/// +/// GenericTrainer trainer1{torch::nn::Linear(3, 4)}; +/// GenericTrainer trainer2{torch::nn::Conv2d(3, 4, 2)}; +/// \endrst +/// +/// As `AnyModule` erases the static type of the stored module (and its +/// `forward()` method) to achieve polymorphism, type checking of arguments is +/// moved to runtime. That is, passing an argument with an incorrect type to an +/// `AnyModule` will compile, but throw an exception at runtime: +/// +/// \rst +/// .. code-block:: cpp +/// +/// torch::nn::AnyModule module(torch::nn::Linear(3, 4)); +/// // Linear takes a tensor as input, but we are passing an integer. +/// // This will compile, but throw a `torch::Error` exception at runtime. +/// module.forward(123); +/// \endrst +/// +/// \rst +/// .. attention:: +/// One noteworthy limitation of `AnyModule` is that its `forward()` method +/// does not support implicit conversion of argument types. For example, if +/// the stored module's `forward()` method accepts a `float` and you call +/// `any_module.forward(3.4)` (where `3.4` is a `double`), this will throw +/// an exception. +/// \endrst +/// +/// The return type of the `AnyModule`'s `forward()` method is controlled via +/// the first template argument to `AnyModule::forward()`. It defaults to +/// `torch::Tensor`. To change it, you can write `any_module.forward()`, +/// for example. +/// +/// \rst +/// .. code-block:: cpp +/// +/// torch::nn::AnyModule module(torch::nn::Linear(3, 4)); +/// auto output = module.forward(torch::ones({2, 3})); +/// +/// struct IntModule { +/// int forward(int x) { return x; } +/// }; +/// torch::nn::AnyModule module(IntModule{}); +/// int output = module.forward(5); +/// \endrst +/// +/// The only other method an `AnyModule` provides access to on the stored +/// module is `clone()`. 
However, you may acquire a handle on the module via +/// `.ptr()`, which returns a `shared_ptr`. Further, if you know +/// the concrete type of the stored module, you can get a concrete handle to it +/// using `.get()` where `T` is the concrete module type. +/// +/// \rst +/// .. code-block:: cpp +/// +/// torch::nn::AnyModule module(torch::nn::Linear(3, 4)); +/// std::shared_ptr ptr = module.ptr(); +/// torch::nn::Linear linear(module.get()); +/// \endrst +class AnyModule { + public: + /// A default-constructed `AnyModule` is in an empty state. + AnyModule() = default; + + /// Constructs an `AnyModule` from a `shared_ptr` to concrete module object. + template + explicit AnyModule(std::shared_ptr module); + + /// Constructs an `AnyModule` from a concrete module object. + template < + typename ModuleType, + typename = torch::detail::enable_if_module_t> + explicit AnyModule(ModuleType&& module); + + /// Constructs an `AnyModule` from a module holder. + template + explicit AnyModule(const ModuleHolder& module_holder); + + /// Move construction and assignment is allowed, and follows the default + /// behavior of move for `std::unique_ptr`. + AnyModule(AnyModule&&) = default; + AnyModule& operator=(AnyModule&&) = default; + + /// Creates a shallow copy of an `AnyModule`. + AnyModule(const AnyModule& other); + AnyModule& operator=(const AnyModule& other); + + /// Creates a deep copy of an `AnyModule` if it contains a module, else an + /// empty `AnyModule` if it is empty. + AnyModule clone(std::optional device = std::nullopt) const; + + /// Assigns a module to the `AnyModule` (to circumvent the explicit + /// constructor). + template + AnyModule& operator=(std::shared_ptr module); + + /// Invokes `forward()` on the contained module with the given arguments, and + /// returns the return value as an `AnyValue`. Use this method when chaining + /// `AnyModule`s in a loop. + template + AnyValue any_forward(ArgumentTypes&&... 
arguments); + + /// Invokes `forward()` on the contained module with the given arguments, and + /// casts the returned `AnyValue` to the supplied `ReturnType` (which defaults + /// to `torch::Tensor`). + template + ReturnType forward(ArgumentTypes&&... arguments); + + /// Attempts to cast the underlying module to the given module type. Throws an + /// exception if the types do not match. + template > + T& get(); + + /// Attempts to cast the underlying module to the given module type. Throws an + /// exception if the types do not match. + template > + const T& get() const; + + /// Returns the contained module in a `nn::ModuleHolder` subclass if possible + /// (i.e. if `T` has a constructor for the underlying module type). + template + T get() const; + + /// Returns a `std::shared_ptr` whose dynamic type is that of the underlying + /// module. + std::shared_ptr ptr() const; + + /// Like `ptr()`, but casts the pointer to the given type. + template > + std::shared_ptr ptr() const; + + /// Returns the `type_info` object of the contained value. + const std::type_info& type_info() const; + + /// Returns true if the `AnyModule` does not contain a module. + bool is_empty() const noexcept; + + private: + /// Creates a `unique_ptr` pointing to a + /// `AnyModuleHolder` of the correct type. This method is used to deduce the + /// arguments of the module's `forward()` method. + template < + typename ModuleType, + typename Class, + typename ReturnType, + typename... ArgumentTypes> + std::unique_ptr make_holder( + std::shared_ptr&& module, + ReturnType (Class::*)(ArgumentTypes...)); + + /// Helper method invoked by const and non-const `get()`. + template + ModuleType& get_(ReturnType (ModuleType::*)(ArgumentTypes...)) const; + + /// Helper method invoked by const and non-const `get()`. + template + ModuleType& get_() const; + + /// The type erased module. 
+ std::unique_ptr content_; +}; + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ AnyModule ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +template +AnyModule::AnyModule(std::shared_ptr module) + : content_(make_holder( + std::move(module), + &std::remove_reference::type::forward)) { + // `AnyModule` can only store an `nn::Module` subclass object that provides + // a `forward()` method that has a non-templatized return type. + // (e.g. `AnyModule` cannot store `nn::Sequential`, because `nn::Sequential`'s + // `forward()` method has a templatized return type.) + static_assert( + torch::detail::is_module::value, + "Can only store object derived from nn::Module into AnyModule"); + static_assert( + torch::detail::has_forward::value, + "Can only store module with a forward() method that has a non-templatized" + " argument type and return type into AnyModule (e.g. we cannot store nn::Sequential" + "into AnyModule, because its forward() method's argument type and return type are templatized." + " If you need to use nn::Sequentials inside each other you can subclass " + "nn::Sequential and write a non-templatized forward function for it. You can checkout " + "https://github.com/pytorch/vision/blob/2f46070f3cb1ea894d82578f3dc5677f82f34958/torchvision/csrc/models/mnasnet.cpp#L59 " + "for an example on how to do this.)."); +} + +template +AnyModule::AnyModule(ModuleType&& module) + : AnyModule( + std::make_shared(std::forward(module))) {} + +template +AnyModule::AnyModule(const ModuleHolder& module_holder) + : AnyModule(module_holder.ptr()) {} + +inline AnyModule::AnyModule(const AnyModule& other) + : content_(other.content_ ? other.content_->copy() : nullptr) {} + +inline AnyModule& AnyModule::operator=(const AnyModule& other) { + if (this != &other) { + content_ = other.content_ ? other.content_->copy() : nullptr; + } + return *this; +} + +inline AnyModule AnyModule::clone(std::optional device) const { + AnyModule clone; + clone.content_ = content_ ? 
content_->clone_module(device) : nullptr; + return clone; +} + +template +AnyModule& AnyModule::operator=(std::shared_ptr module) { + // NOLINTNEXTLINE(cppcoreguidelines-c-copy-assignment-signature) + return (*this = AnyModule(std::move(module))); +} + +template +AnyValue AnyModule::any_forward(ArgumentTypes&&... arguments) { + TORCH_CHECK(!is_empty(), "Cannot call forward() on an empty AnyModule"); + std::vector values; + values.reserve(sizeof...(ArgumentTypes)); + torch::apply( + [&values](AnyValue&& value) { values.push_back(std::move(value)); }, + AnyValue(std::forward(arguments))...); + return content_->forward(std::move(values)); +} + +template +ReturnType AnyModule::forward(ArgumentTypes&&... arguments) { + return any_forward(std::forward(arguments)...) + .template get(); +} + +template +T& AnyModule::get() { + TORCH_CHECK(!is_empty(), "Cannot call get() on an empty AnyModule"); + return get_(); +} + +template +const T& AnyModule::get() const { + TORCH_CHECK(!is_empty(), "Cannot call get() on an empty AnyModule"); + return get_(); +} + +template +T AnyModule::get() const { + return T(ptr()); +} + +inline std::shared_ptr AnyModule::ptr() const { + TORCH_CHECK(!is_empty(), "Cannot call ptr() on an empty AnyModule"); + return content_->ptr(); +} + +template +std::shared_ptr AnyModule::ptr() const { + TORCH_CHECK(!is_empty(), "Cannot call ptr() on an empty AnyModule"); + // Call get() but discard the value, just to do the type checking. + get_(); + return std::dynamic_pointer_cast(ptr()); +} + +inline const std::type_info& AnyModule::type_info() const { + TORCH_CHECK(!is_empty(), "Cannot call type_info() on an empty AnyModule"); + return content_->type_info; +} + +inline bool AnyModule::is_empty() const noexcept { + return content_ == nullptr; +} + +// Private Methods + +template < + typename ModuleType, + typename Class, + typename ReturnType, + typename... 
ArgumentTypes> +std::unique_ptr AnyModule::make_holder( + std::shared_ptr&& module, + ReturnType (Class::*)(ArgumentTypes...)) { + static_assert( + torch::detail::check_not_lvalue_references(), + "Modules stored inside AnyModule must not take references. " + "Use pointers instead."); + static_assert( + !std::is_void::value, + "AnyModule cannot store modules that return void " + "(you can return a dummy value)."); + return std::make_unique< + AnyModuleHolder, ArgumentTypes...>>( + std::move(module)); +} + +template +ModuleType& AnyModule::get_() const { + using M = typename std::remove_reference::type; + static_assert( + torch::detail::has_forward::value, + "Can only call AnyModule::get with a type T that has a forward method"); + return get_(&M::forward); +} + +template +ModuleType& AnyModule::get_( + ReturnType (ModuleType::*)(ArgumentTypes...)) const { + if (typeid(ModuleType).hash_code() == type_info().hash_code()) { + return *static_cast&>( + *content_) + .module; + } + AT_ERROR( + "Attempted to cast module of type ", + c10::demangle(type_info().name()), + " to type ", + c10::demangle(typeid(ModuleType).name())); +} + +} // namespace nn +} // namespace torch diff --git a/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/container/any_module_holder.h b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/container/any_module_holder.h new file mode 100644 index 0000000000000000000000000000000000000000..edeb8e6b764c516a6fee5445a704905ac7df7c1b --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/container/any_module_holder.h @@ -0,0 +1,133 @@ +#pragma once + +#include + +namespace torch { +namespace nn { + +class Module; + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~ AnyModulePlaceholder ~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// The static type of the object we store in the `AnyModule`, which erases +/// the actual type, but allows us to call `forward()` on 
the underlying +/// module. +struct AnyModulePlaceholder : public AnyValue::Placeholder { + using AnyValue::Placeholder::Placeholder; + + /// The "erased" `forward()` method. + virtual AnyValue forward(std::vector&& arguments) = 0; + + /// Returns std::shared_ptr pointing to the erased module. + virtual std::shared_ptr ptr() = 0; + + /// Returns a `AnyModulePlaceholder` with a shallow copy of this `AnyModule`. + virtual std::unique_ptr copy() const = 0; + + /// Returns a `AnyModulePlaceholder` with a deep copy of this `AnyModule`. + virtual std::unique_ptr clone_module( + std::optional device) const = 0; +}; + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ AnyModuleHolder ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// The dynamic type of the object stored in the `AnyModule`. It contains the +/// concrete instance to which all calls are forwarded. It is parameterized +/// over the concrete type of the module, and the types of the arguments the +/// module takes in its `forward()` method. +template +struct AnyModuleHolder : public AnyModulePlaceholder { + /// \internal + struct CheckedGetter { + template + std::decay_t&& operator()(size_t index) { + AT_ASSERT(index < arguments_.size()); + auto& value = arguments_[index]; + if (auto* maybe_value = value.template try_get>()) { + return std::move(*maybe_value); + } + AT_ERROR( + "Expected argument #", + index, + " to be of type ", + c10::demangle(typeid(T).name()), + ", but received value of type ", + c10::demangle(value.type_info().name())); + } + std::vector& arguments_; + }; + + /// \internal + struct InvokeForward { + template + AnyValue operator()(Ts&&... ts) { + return AnyValue(module_->forward(std::forward(ts)...)); + } + std::shared_ptr& module_; + }; + + /// Constructs the `AnyModuleHolder` from a concrete module. 
+ explicit AnyModuleHolder(std::shared_ptr&& module_) + : AnyModulePlaceholder(typeid(ModuleType)), module(std::move(module_)) {} + + /// Calls `forward()` on the underlying module, casting each `AnyValue` in the + /// argument vector to a concrete value. + AnyValue forward(std::vector&& arguments) override { + if (module->_forward_has_default_args()) { + TORCH_CHECK( + arguments.size() >= module->_forward_num_required_args() && + arguments.size() <= sizeof...(ArgumentTypes), + c10::demangle(type_info.name()), + "'s forward() method expects at least ", + module->_forward_num_required_args(), + " argument(s) and at most ", + sizeof...(ArgumentTypes), + " argument(s), but received ", + arguments.size(), + "."); + arguments = std::move( + module->_forward_populate_default_args(std::move(arguments))); + } else { + std::string use_default_args_macro_prompt = " If " + + c10::demangle(type_info.name()) + + "'s forward() method has default arguments, " + + "please make sure the forward() method is declared with a corresponding `FORWARD_HAS_DEFAULT_ARGS` macro."; + TORCH_CHECK( + arguments.size() == sizeof...(ArgumentTypes), + c10::demangle(type_info.name()), + "'s forward() method expects ", + sizeof...(ArgumentTypes), + " argument(s), but received ", + arguments.size(), + ".", + (arguments.size() < sizeof...(ArgumentTypes)) + ? use_default_args_macro_prompt + : ""); + } + + // FYI: During invocation of a module's `forward()` method, the values live + // in the `arguments` vector inside this function. + return torch::unpack( + InvokeForward{module}, CheckedGetter{arguments}); + } + + std::shared_ptr ptr() override { + return module; + } + + std::unique_ptr copy() const override { + return std::make_unique(*this); + } + + std::unique_ptr clone_module( + std::optional device) const override { + return std::make_unique( + std::dynamic_pointer_cast(module->clone(device))); + } + + /// The actual concrete module instance. 
+ std::shared_ptr module; +}; + +} // namespace nn +} // namespace torch diff --git a/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/container/any_value.h b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/container/any_value.h new file mode 100644 index 0000000000000000000000000000000000000000..d154130618f2dcaaf4724c32bd40fb59ed3bd465 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/container/any_value.h @@ -0,0 +1,125 @@ +#pragma once + +#include +#include +#include +#include + +#include +#include + +#include +#include +#include +#include + +namespace torch { +namespace nn { + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ AnyValue ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// An implementation of `std::any` which stores +/// a type erased object, whose concrete value can be retrieved at runtime by +/// checking if the `typeid()` of a requested type matches the `typeid()` of +/// the object stored. +class AnyValue { + public: + /// Move construction and assignment is allowed, and follows the default + /// behavior of move for `std::unique_ptr`. + AnyValue(AnyValue&&) = default; + AnyValue& operator=(AnyValue&&) = default; + + /// Copy construction and assignment is allowed. + AnyValue(const AnyValue& other) : content_(other.content_->clone()) {} + AnyValue& operator=(const AnyValue& other) { + content_ = other.content_->clone(); + return *this; + } + + /// Constructs the `AnyValue` from value type. + template + // NOLINTNEXTLINE(bugprone-forwarding-reference-overload) + explicit AnyValue(T&& value) + : content_( + std::make_unique>>(std::forward(value))) { + } + + /// Returns a pointer to the value contained in the `AnyValue` if the type + /// passed as template parameter matches the type of the value stored, and + /// returns a null pointer otherwise. 
+ template + T* try_get() { + static_assert( + !std::is_reference::value, + "AnyValue stores decayed types, you cannot cast it to a reference type"); + static_assert( + !std::is_array::value, + "AnyValue stores decayed types, you must cast it to T* instead of T[]"); + if (typeid(T).hash_code() == type_info().hash_code()) { + return &static_cast&>(*content_).value; + } + return nullptr; + } + + /// Returns the value contained in the `AnyValue` if the type passed as + /// template parameter matches the type of the value stored, and throws an + /// exception otherwise. + template + T get() { + if (auto* maybe_value = try_get()) { + return *maybe_value; + } + AT_ERROR( + "Attempted to cast AnyValue to ", + c10::demangle(typeid(T).name()), + ", but its actual type is ", + c10::demangle(type_info().name())); + } + + /// Returns the `type_info` object of the contained value. + const std::type_info& type_info() const noexcept { + return content_->type_info; + } + + private: + friend struct AnyModulePlaceholder; + friend struct TestAnyValue; + + /// \internal + /// The static type of the object we store in the `AnyValue`, which erases the + /// actual object's type, allowing us only to check the `type_info` of the + /// type stored in the dynamic type. + struct Placeholder { + explicit Placeholder(const std::type_info& type_info_) noexcept + : type_info(type_info_) {} + Placeholder(const Placeholder&) = default; + Placeholder(Placeholder&&) = default; + virtual ~Placeholder() = default; + virtual std::unique_ptr clone() const { + TORCH_CHECK(false, "clone() should only be called on `AnyValue::Holder`"); + } + const std::type_info& type_info; + }; + + /// \internal + /// The dynamic type of the object we store in the `AnyValue`, which hides the + /// actual object we have erased in this `AnyValue`. + template + struct Holder : public Placeholder { + /// A template because T&& would not be universal reference here. 
+ template + // NOLINTNEXTLINE(bugprone-forwarding-reference-overload) + explicit Holder(U&& value_) noexcept + : Placeholder(typeid(T)), value(std::forward(value_)) {} + std::unique_ptr clone() const override { + return std::make_unique>(value); + } + T value; + }; + + /// The type erased object. + std::unique_ptr content_; +}; + +} // namespace nn +} // namespace torch diff --git a/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/container/functional.h b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/container/functional.h new file mode 100644 index 0000000000000000000000000000000000000000..3f381a63944f580a7c787e8142ac65b2be150729 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/container/functional.h @@ -0,0 +1,105 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include +#include + +namespace torch { +namespace nn { + +/// Wraps a function in a `Module`. +/// +/// The `Functional` module allows wrapping an arbitrary function or function +/// object in an `nn::Module`. This is primarily handy for usage in +/// `Sequential`. +/// +/// \rst +/// .. code-block:: cpp +/// +/// Sequential sequential( +/// Linear(3, 4), +/// Functional(torch::relu), +/// BatchNorm1d(3), +/// Functional(torch::elu, /*alpha=*/1)); +/// \endrst +/// +/// While a `Functional` module only accepts a single `Tensor` as input, it is +/// possible for the wrapped function to accept further arguments. However, +/// these have to be bound *at construction time*. For example, if +/// you want to wrap `torch::leaky_relu`, which accepts a `slope` scalar as its +/// second argument, with a particular value for its `slope` in a `Functional` +/// module, you could write +/// +/// \rst +/// .. 
code-block:: cpp +/// +/// Functional(torch::leaky_relu, /*slope=*/0.5) +/// \endrst +/// +/// The value of `0.5` is then stored within the `Functional` object and +/// supplied to the function call at invocation time. Note that such bound +/// values are evaluated eagerly and stored a single time. See the documentation +/// of [std::bind](https://en.cppreference.com/w/cpp/utility/functional/bind) +/// for more information on the semantics of argument binding. +/// +/// \rst +/// .. attention:: +/// After passing any bound arguments, the function must accept a single +/// tensor and return a single tensor. +/// \endrst +/// +/// Note that `Functional` overloads the call operator (`operator()`) such that +/// you can invoke it with `my_func(...)`. +class TORCH_API FunctionalImpl : public torch::nn::Cloneable { + public: + using Function = std::function; + + /// Constructs a `Functional` from a function object. + explicit FunctionalImpl(Function function); + + template < + typename SomeFunction, + typename... Args, + typename = std::enable_if_t<(sizeof...(Args) > 0)>> + explicit FunctionalImpl(SomeFunction original_function, Args&&... args) + // NOLINTNEXTLINE(modernize-avoid-bind) + : function_(std::bind( + original_function, + /*input=*/std::placeholders::_1, + std::forward(args)...)) { + // std::bind is normally evil, but (1) gcc is broken w.r.t. handling + // parameter pack expansion in lambdas and (2) moving parameter packs into + // a lambda only works with C++14, so std::bind is the more move-aware + // solution here. + } + + void reset() override; + + /// Pretty prints the `Functional` module into the given `stream`. + void pretty_print(std::ostream& stream) const override; + + /// Forwards the `input` tensor to the underlying (bound) function object. + Tensor forward(Tensor input); + + /// Calls forward(input). 
+ Tensor operator()(Tensor input); + + bool is_serializable() const override; + + private: + Function function_; +}; + +/// A `ModuleHolder` subclass for `FunctionalImpl`. +/// See the documentation for `FunctionalImpl` class to learn what methods it +/// provides, or the documentation for `ModuleHolder` to learn about PyTorch's +/// module storage semantics. +TORCH_MODULE(Functional); + +} // namespace nn +} // namespace torch diff --git a/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/container/moduledict.h b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/container/moduledict.h new file mode 100644 index 0000000000000000000000000000000000000000..b96b7611936f16e9bdae934160af263a4710f0a8 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/container/moduledict.h @@ -0,0 +1,262 @@ +#pragma once + +#include +#include +#include +#include + +namespace torch { +namespace nn { + +/// An OrderedDict of `Module`s that registers its elements by their `key`s. +/// +/// \rst +/// .. code-block:: cpp +/// +/// torch::OrderedDict> ordereddict = { +/// {"linear", Linear(10, 3).ptr()}, +/// {"conv", Conv2d(1, 2, 3).ptr()}, +/// {"dropout", Dropout(0.5).ptr()}, +/// }; +/// torch::nn::ModuleDict dict1(ordereddict); +/// +/// for (const auto &module : *dict1) { +/// module->pretty_print(std::cout); +/// } +/// +/// std::vector>> list = { +/// {"linear", Linear(10, 3).ptr()}, +/// {"conv", Conv2d(1, 2, 3).ptr()}, +/// {"dropout", Dropout(0.5).ptr()}, +/// }; +/// torch::nn::ModuleDict dict2(list); +/// +/// for (const auto &module : *dict2) { +/// module->pretty_print(std::cout); +/// } +/// +/// \endrst +/// +/// Why should you use `ModuleDict` instead of a simple `map` or `OrderedDict`? 
+/// The value a `ModuleDict` provides over manually calling an ordered map of +/// modules is that it allows treating the whole container *as a single module*, +/// such that performing a transformation on the `ModuleDict` applies to each of +/// the modules it stores (which are each a registered submodule of the +/// `ModuleDict`). For example, calling `.to(torch::kCUDA)` on a `ModuleDict` +/// will move each module in the map to CUDA memory. For example: +/// +/// \rst +/// .. code-block:: cpp +/// +/// torch::OrderedDict> ordereddict = { +/// {"linear", Linear(10, 3).ptr()}, +/// {"conv", Conv2d(1, 2, 3).ptr()}, +/// {"dropout", Dropout(0.5).ptr()}, +/// }; +/// torch::nn::ModuleDict dict(ordereddict); +/// +/// // Convert all modules to CUDA. +/// dict->to(torch::kCUDA); +/// +/// \endrst +/// +/// Finally, `ModuleDict` provides a lightweight container API, such as allowing +/// iteration over submodules, positional access, adding new modules from a +/// vector of key-module pairs or an `OrderedDict` or another `ModuleDict` after +/// construction via `update`. +class ModuleDictImpl : public Cloneable { + public: + using Iterator = + torch::OrderedDict>::Iterator; + using ConstIterator = + torch::OrderedDict>::ConstIterator; + + ModuleDictImpl() = default; + + /// Constructs the `ModuleDict` from a list of string-Module pairs. + explicit ModuleDictImpl( + const std::vector>>& + modules) { + update(modules); + } + + /// Constructs the `ModuleDict` from an `OrderedDict`. + explicit ModuleDictImpl( + const torch::OrderedDict>& modules) { + update(modules); + } + + /// Return the items in the `ModuleDict`. + std::vector>> items() const { + return modules_.pairs(); + } + + /// Return the keys in the `ModuleDict`. + std::vector keys() const { + return modules_.keys(); + } + + /// Return the values in the `ModuleDict`. + std::vector> values() const { + return modules_.values(); + } + + /// Return an iterator to the start of `ModuleDict`. 
+ Iterator begin() { + return modules_.begin(); + } + + /// Return a const iterator to the start of `ModuleDict`. + ConstIterator begin() const { + return modules_.begin(); + } + + /// Return an iterator to the end of `ModuleDict`. + Iterator end() { + return modules_.end(); + } + + /// Return a const iterator to the end of `ModuleDict`. + ConstIterator end() const { + return modules_.end(); + } + + /// Return the number of items currently stored in the `ModuleDict`. + size_t size() const noexcept { + return modules_.size(); + } + + /// Return true if the `ModuleDict` is empty, otherwise return false. + bool empty() const noexcept { + return modules_.is_empty(); + } + + /// Check if the centain parameter with the key in the `ModuleDict`. + bool contains(const std::string& key) const noexcept { + return modules_.contains(key); + } + + /// Remove all items from the `ModuleDict`. + void clear() { + // Not remove the registration of modules to make it consistent with python + // version. + modules_.clear(); + } + + /// Special cloning function for `ModuleDict` because it does not use + /// `reset()`. + std::shared_ptr clone( + const std::optional& device = std::nullopt) const override { + auto clone = std::make_shared(); + for (const auto& module : modules_) { + clone->insert(module.key(), module.value()->clone(device)); + } + return clone; + } + + /// `reset()` is empty for `ModuleDict`, since it does not have parameters of + /// its own. + void reset() override {} + + /// Pretty prints the `ModuleDict` into the given `stream`. + void pretty_print(std::ostream& stream) const override { + stream << "torch::nn::ModuleDict"; + } + + /// Attempts to returns the `Module` associated with the given `key`. Throws + /// an exception if no such `key` is stored in the `ModuleDict`. Check + /// contains(key) before for a non-throwing way of access. 
+ std::shared_ptr operator[](const std::string& key) const { + return modules_[key]; + } + + /// Attempts to return the module at the given key as the requested type. + /// Throws an exception if no such `key` is stored in the `ModuleDict`. + /// Check contains(key) before for a non-throwing way of access. + template + T& at(const std::string& key) { + static_assert( + torch::detail::is_module::value, + "Can only call ModuleList::at with an nn::Module type"); + auto module = modules_[key]->as(); + TORCH_CHECK( + module, + "Unable to cast module[", + key, + "] to ", + c10::demangle(typeid(T).name())); + return *module; + } + + /// Attempts to return the module at the given key as the requested type. + /// Throws an exception if no such `key` is stored in the `ModuleDict`. + /// Check contains(key) before for a non-throwing way of access. + template + const T& at(const std::string& key) const { + static_assert( + torch::detail::is_module::value, + "Can only call ModuleList::at with an nn::Module type"); + const auto module = modules_[key]->as(); + TORCH_CHECK( + module, + "Unable to cast module[", + key, + "] to ", + c10::demangle(typeid(T).name())); + return *module; + } + + /// Removes and returns the `Module` associated with the given `key`. + /// Throws an exception if no such `key` is stored in the `ModuleDict`. + /// Check contains(key) before for a non-throwing way of access. + std::shared_ptr pop(const std::string& key) { + auto module = modules_[key]; + modules_.erase(key); + // Not remove the registration of the module to make it consistent with + // python version. + return module; + } + + /// Updated the `ModuleDict` with a vector of key-module pairs. + void update( + const std::vector>>& + modules) { + for (auto& item : modules) { + insert(item.first, item.second); + } + } + + /// Updated the `ModuleDict` with key-value pairs from `OrderedDict` or + /// `ModuleDict`. 
+ template + void update(const Container& container) { + for (auto& item : container) { + insert(item.key(), item.value()); + } + } + + private: + /// Private `OrderedDict` holding the key-Module pairs. + torch::OrderedDict> modules_; + + /// Insert a key-module pair by overwriting existing keys, + /// and register or replace the `Module`. + void insert(const std::string& key, std::shared_ptr module) { + if (contains(key)) { + modules_[key] = std::move(module); + replace_module(key, modules_[key]); + } else { + modules_.insert(key, std::move(module)); + register_module(key, modules_.back().value()); + } + } +}; + +/// A `ModuleHolder` subclass for `ModuleDictImpl`. +/// See the documentation for `ModuleDictImpl` class to learn what methods it +/// provides, or the documentation for `ModuleHolder` to learn about PyTorch's +/// module storage semantics. +TORCH_MODULE(ModuleDict); + +} // namespace nn +} // namespace torch diff --git a/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/container/modulelist.h b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/container/modulelist.h new file mode 100644 index 0000000000000000000000000000000000000000..b115abe1e9551852844b3da886d4b07f1d7e96c1 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/container/modulelist.h @@ -0,0 +1,274 @@ +#pragma once + +#include +#include +#include + +#include +#include + +namespace torch { +namespace nn { + +/// A list of `Module`s that registers its elements. +/// +/// \rst +/// .. code-block:: cpp +/// +/// torch::nn::ModuleList mlist( +/// torch::nn::Linear(3, 4), +/// torch::nn::BatchNorm1d(4), +/// torch::nn::Dropout(0.5) +/// ); +/// +/// for (const auto &module : *mlist) { +/// module->pretty_print(std::cout); +/// } +/// +/// \endrst +/// +/// Why should you use `ModuleList` instead of a simple `std::vector`? 
The value +/// a `ModuleList` provides over manually calling a sequence of modules is that +/// it allows treating the whole container *as a single module*, such that +/// performing a transformation on the `ModuleList` applies to each of the +/// modules it stores (which are each a registered submodule of the +/// `ModuleList`). For example, calling +/// `.to(torch::kCUDA)` on a `ModuleList` will move each module in the list to +/// CUDA memory. For example: +/// +/// \rst +/// .. code-block:: cpp +/// +/// torch::nn::ModuleList mlist( +/// torch::nn::Linear(3, 4), +/// torch::nn::BatchNorm1d(4), +/// torch::nn::Dropout(0.5) +/// ); +/// +/// // Convert all modules to CUDA. +/// mlist->to(torch::kCUDA); +/// +/// \endrst +/// +/// Finally, `ModuleList` provides a lightweight container API, such as allowing +/// iteration over submodules, positional access, adding a new module after +/// construction via `push_back`, as well as joining two `ModuleList`s via +/// `extend`. +class ModuleListImpl : public Cloneable { + public: + using Iterator = std::vector>::iterator; + using ConstIterator = std::vector>::const_iterator; + + ModuleListImpl() = default; + + /// Constructs the `ModuleList` from a variadic list of modules. + template + explicit ModuleListImpl(Modules&&... modules) { + modules_.reserve(sizeof...(Modules)); + push_back_var(std::forward(modules)...); + } + + /// Special cloning function for `ModuleList` because it does not use + /// `reset()`. + std::shared_ptr clone( + const std::optional& device = std::nullopt) const override { + auto clone = std::make_shared(); + for (const auto& module : modules_) { + clone->push_back(module->clone(device)); + } + return clone; + } + + /// `reset()` is empty for `ModuleList`, since it does not have parameters of + /// its own. + void reset() override {} + + /// Pretty prints the `ModuleList` module into the given `stream`. 
+ void pretty_print(std::ostream& stream) const override { + stream << "torch::nn::ModuleList"; + } + + void push_back(std::shared_ptr module) { + modules_.push_back(std::move(module)); + const auto index = modules_.size() - 1; + register_module(std::to_string(index), modules_[index]); + } + + /// Adds a new `Module` to the `ModuleList` container, moving or copying + /// it into a `shared_ptr` internally. This method allows passing value types, + /// and letting the container deal with the boxing. + template > + void push_back(M&& module) { + using Type = typename std::remove_reference::type; + push_back(std::make_shared(std::forward(module))); + } + + /// Unwraps the contained module of a `ModuleHolder` and adds it to the + /// `ModuleList`. + template + void push_back(const ModuleHolder& module_holder) { + push_back(module_holder.ptr()); + } + + /// Iterates over the container and calls `push_back()` on each value. + template + void extend(const Container& container) { + for (const auto& module : container) { + push_back(module); + } + } + + /// Returns an iterator to the start of the `ModuleList`. + Iterator begin() { + return modules_.begin(); + } + + /// Returns a const iterator to the start of the `ModuleList`. + ConstIterator begin() const { + return modules_.begin(); + } + + /// Returns an iterator to the end of the `ModuleList`. + Iterator end() { + return modules_.end(); + } + + /// Returns a const iterator to the end of the `ModuleList`. + ConstIterator end() const { + return modules_.end(); + } + + /// Attempts to return the module at the given index as the requested type. + /// Throws an exception if the index is out of bounds or the types do not + /// match. 
+ template + T& at(size_t index) { + static_assert( + torch::detail::is_module::value, + "Can only call ModuleList::at with an nn::Module type"); + TORCH_CHECK(index < size(), "Index out of range"); + auto module = modules_[index]->as(); + TORCH_CHECK( + module, + "Unable to cast module[", + index, + "] to ", + c10::demangle(typeid(T).name())); + return *module; + } + + /// Attempts to return the module at the given index as the requested type. + /// Throws an exception if the index is out of bounds or the types do not + /// match. + template + const T& at(size_t index) const { + static_assert( + torch::detail::is_module::value, + "Can only call ModuleList::at with an nn::Module type"); + TORCH_CHECK(index < size(), "Index out of range"); + const auto module = modules_[index]->as(); + TORCH_CHECK( + module, + "Unable to cast module[", + index, + "] to ", + c10::demangle(typeid(T).name())); + return *module; + } + + /// Attempts to return a `std::shared_ptr` whose dynamic type is that of the + /// underlying module at the given index. Throws an exception if the index is + /// out of bounds. + std::shared_ptr ptr(size_t index) const { + TORCH_CHECK(index < size(), "Index out of range"); + return modules_[index]; + } + + /// Attempts to return a `std::shared_ptr` whose type is the one provided. + /// Throws an exception if the index is out of bounds or the types do not + /// match. + template + std::shared_ptr ptr(size_t index) const { + static_assert( + torch::detail::is_module::value, + "Can only call ModuleList::ptr with an nn::Module type"); + TORCH_CHECK(index < size(), "Index out of range"); + return std::dynamic_pointer_cast(modules_[index]); + } + + /// Like `ptr(index)`. + std::shared_ptr operator[](size_t index) const { + // This is the only method we can call without a type. + return ptr(index); + } + + /// The current size of the `ModuleList` container. 
+ size_t size() const noexcept { + return modules_.size(); + } + + /// True if there are no modules in the `ModuleList`. + bool is_empty() const noexcept { + return size() == 0; + } + + void insert(size_t index, std::shared_ptr module) { + TORCH_CHECK(index <= size(), "Index out of range"); + + if (index == size()) + push_back(std::move(module)); + else { + modules_.insert( + modules_.begin() + Iterator::difference_type(index), + std::move(module)); + + for (const auto i : c10::irange(index, size() - 1)) { + (void)i; // Suppress unused variable warning + replace_module(std::to_string(index), modules_[index]); + } + register_module(std::to_string(size() - 1), modules_.back()); + } + } + + /// Unwraps the contained module of a `ModuleHolder` and inserts it in the + /// `ModuleList`. + template + void insert(size_t index, const ModuleHolder& module_holder) { + insert(index, module_holder.ptr()); + } + + /// inserts a new `Module` to the `ModuleList` container, moving or copying + /// it into a `shared_ptr` internally. This method allows passing value types, + /// and letting the container deal with the boxing. + template > + void insert(size_t index, M&& module) { + using Type = typename std::remove_reference::type; + insert(index, std::make_shared(std::forward(module))); + } + + private: + template + void push_back_var(Head&& head, Tail&&... tail) { + push_back(std::forward(head)); + // Recursively calls this method, until the parameter pack only thas this + // entry left. Then calls `push_back()` a final time (above). + push_back_var(std::forward(tail)...); + } + + /// The base case, when the list of modules is empty. + void push_back_var() {} + + // Box the AnyModules to give ModuleList reference semantics, like the rest of + // the API. Note that this is not required otherwise, this could just be a + // `vector`. + std::vector> modules_; +}; + +/// A `ModuleHolder` subclass for `ModuleListImpl`. 
+/// See the documentation for `ModuleListImpl` class to learn what methods it +/// provides, or the documentation for `ModuleHolder` to learn about PyTorch's +/// module storage semantics. +TORCH_MODULE(ModuleList); + +} // namespace nn +} // namespace torch diff --git a/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/container/named_any.h b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/container/named_any.h new file mode 100644 index 0000000000000000000000000000000000000000..00d39de17f4012cbfb9aa4e56327d26c66f33bc2 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/container/named_any.h @@ -0,0 +1,94 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include +#include + +#include + +#include +#include +#include +#include +#include +#include + +namespace torch { +namespace nn { + +/// Stores a type erased `Module` with name. +/// +/// The `NamedAnyModule` class enables the following API for constructing +/// `nn::Sequential` with named submodules: +/// \rst +/// .. code-block:: cpp +/// +/// struct M : torch::nn::Module { +/// explicit M(int value_) : value(value_) {} +/// int value; +/// int forward() { +/// return value; +/// } +/// }; +/// +/// Sequential sequential({ +/// {"m1", std::make_shared(1)}, // shared pointer to `Module` is +/// supported {std::string("m2"), M(2)}, // `Module` is supported +/// {"linear1", Linear(10, 3)} // `ModuleHolder` is supported +/// }); +/// \endrst +class NamedAnyModule { + public: + /// Creates a `NamedAnyModule` from a (boxed) `Module`. + template + NamedAnyModule(std::string name, std::shared_ptr module_ptr) + : NamedAnyModule(std::move(name), AnyModule(std::move(module_ptr))) {} + + /// Creates a `NamedAnyModule` from a `Module`, moving or copying it + /// into a `shared_ptr` internally. 
+ // NOTE: We need to use `std::remove_reference::type` to get rid of + // any reference components for make_unique. + template > + NamedAnyModule(std::string name, M&& module) + : NamedAnyModule( + std::move(name), + std::make_shared::type>( + std::forward(module))) {} + + /// Creates a `NamedAnyModule` from a `Module` that is unwrapped from + /// a `ModuleHolder`. + template + NamedAnyModule(std::string name, const ModuleHolder& module_holder) + : NamedAnyModule(std::move(name), module_holder.ptr()) {} + + /// Creates a `NamedAnyModule` from a type-erased `AnyModule`. + NamedAnyModule(std::string name, AnyModule any_module) + : name_(std::move(name)), module_(std::move(any_module)) {} + + /// Returns a reference to the name. + const std::string& name() const noexcept { + return name_; + } + + /// Returns a reference to the module. + AnyModule& module() noexcept { + return module_; + } + + /// Returns a const reference to the module. + const AnyModule& module() const noexcept { + return module_; + } + + private: + std::string name_; + AnyModule module_; +}; + +} // namespace nn +} // namespace torch diff --git a/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/container/parameterdict.h b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/container/parameterdict.h new file mode 100644 index 0000000000000000000000000000000000000000..f201825deb5bad0bc8640b6d977e156d10a74435 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/container/parameterdict.h @@ -0,0 +1,148 @@ +#pragma once + +#include +#include +#include +#include +#include + +namespace torch { +namespace nn { + +class ParameterDictImpl : public Cloneable { + public: + using Iterator = OrderedDict::Iterator; + using ConstIterator = OrderedDict::ConstIterator; + + ParameterDictImpl() = default; + + explicit ParameterDictImpl( + const torch::OrderedDict& params) { + parameters_ 
= params; + } + + /// `reset()` is empty for `ParameterDict`, since it does not have + /// parameters of its own. + void reset() override {} + + /// Pretty prints the `ParameterDict` module into the given `stream`. + void pretty_print(std::ostream& stream) const override { + stream << "torch::nn::ParameterDict(" << std::endl; + for (const auto& pair : parameters_) { + stream << "(" << pair.key() << ")" + << ": Parameter containing: [" << pair.value().scalar_type() + << " of size " << pair.value().sizes() << "]"; + ; + stream << std::endl; + } + stream << ")"; + } + + /// Insert the parameter along with the key into ParameterDict + /// The parameter is set to be require grad by default + Tensor& insert(std::string key, Tensor param) { + bool requires_grad = param.requires_grad(); + return register_parameter(std::move(key), std::move(param), requires_grad); + } + + /// Remove key from the ParameterDict and return its value, throw exception + /// if the key is not contained. Please check contains(key) before for a + /// non-throwing access. 
+ Tensor pop(const std::string& key) { + torch::Tensor v = parameters_[key]; + parameters_.erase(key); + return v; + } + + /// Return the keys in the dict + ::std::vector keys() const { + return parameters_.keys(); + } + + /// Return the Values in the dict + ::std::vector values() const { + return parameters_.values(); + } + + /// Return an iterator to the start of ParameterDict + Iterator begin() { + return parameters_.begin(); + } + + /// Return a const iterator to the start of ParameterDict + ConstIterator begin() const { + return parameters_.begin(); + } + + /// Return an iterator to the end of ParameterDict + Iterator end() { + return parameters_.end(); + } + + /// Return a const iterator to the end of ParameterDict + ConstIterator end() const { + return parameters_.end(); + } + + /// Return the number of items currently stored in the ParameterDict + size_t size() const noexcept { + return parameters_.size(); + } + + /// Return true if the ParameterDict is empty, otherwise return false + bool empty() const noexcept { + return parameters_.is_empty(); + } + + /// Update the ParameterDict with the key-value pairs from + /// another ParameterDict, overwriting existing key + template + void update(const Container& container) { + for (auto& item : container) { + parameters_[item.key()] = item.value(); + } + } + + /// Remove all parameters in the ParameterDict + void clear() { + parameters_.clear(); + } + + /// Check if the centain parameter with the key in the ParameterDict + bool contains(const std::string& key) const noexcept { + return parameters_.contains(key); + } + + /// Returns the value associated with the given `key`. Throws an exception if + /// no such key is stored in the `ParameterDict`. Check contains(key) before + /// for a non-throwing way of access + const Tensor& get(const std::string& key) const { + return parameters_[key]; + } + + /// Returns the value associated with the given `key`. 
Throws an exception if + /// no such key is stored in the `ParameterDict`. Check contains(key) before + /// for a non-throwing way of access + Tensor& get(const std::string& key) { + return parameters_[key]; + } + + /// Returns the value associated with the given `key`. Throws an exception if + /// no such key is stored in the `ParameterDict`. Check contains(key) before + /// for a non-throwing way of access + Tensor& operator[](const std::string& key) { + return parameters_[key]; + } + + /// Returns the value associated with the given `key`. Throws an exception if + /// no such key is stored in the `ParameterDict`. Check contains(key) before + /// for a non-throwing way of access + const Tensor& operator[](const std::string& key) const { + return parameters_[key]; + } +}; + +TORCH_MODULE(ParameterDict); + +} // namespace nn +} // namespace torch diff --git a/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/container/parameterlist.h b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/container/parameterlist.h new file mode 100644 index 0000000000000000000000000000000000000000..cb816d1bb2a1e6493665a906036602eaac03170f --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/container/parameterlist.h @@ -0,0 +1,169 @@ +#pragma once + +#include +#include + +#include + +namespace torch { +namespace nn { +class ParameterListImpl : public Cloneable { + public: + using Iterator = typename std::vector< + OrderedDict::Item>::iterator; + using ConstIterator = typename std::vector< + OrderedDict::Item>::const_iterator; + + ParameterListImpl() = default; + + /// Constructs the `ParameterList` from a variadic list of ParameterList. + template + explicit ParameterListImpl(Tensors&&... params) { + parameters_.reserve(sizeof...(Tensors)); + push_back_var(std::forward(params)...); + } + + template + explicit ParameterListImpl(const Tensors&... 
params) { + parameters_.reserve(sizeof...(Tensors)); + push_back_var(std::forward(params)...); + } + + /// `reset()` is empty for `ParameterList`, since it does not have parameters + /// of its own. + void reset() override {} + + /// Pretty prints the `ParameterList` module into the given `stream`. + void pretty_print(std::ostream& stream) const override { + stream << "torch::nn::ParameterList(" << std::endl; + for (const auto& pair : parameters_) { + stream << "(" << pair.key() << ")" + << ": Parameter containing: [" << pair.value().scalar_type() + << " of size " << pair.value().sizes() << "]"; + ; + stream << std::endl; + } + stream << ")"; + } + + /// push the a given parameter at the end of the list + void append(torch::Tensor&& param) { + bool requires_grad = param.requires_grad(); + register_parameter( + std::to_string(parameters_.size()), std::move(param), requires_grad); + } + + /// push the a given parameter at the end of the list + void append(const torch::Tensor& param) { + bool requires_grad = param.requires_grad(); + register_parameter( + std::to_string(parameters_.size()), param, requires_grad); + } + + /// push the a given parameter at the end of the list + /// And the key of the pair will be discarded, only the value + /// will be added into the `ParameterList` + void append(const OrderedDict::Item& pair) { + register_parameter( + std::to_string(parameters_.size()), + pair.value(), + pair.value().requires_grad()); + } + + /// extend parameters from a container to the end of the list + template + void extend(const Container& container) { + for (const auto& param : container) { + append(param); + } + } + + /// Returns an iterator to the start of the ParameterList + /// the iterator returned will be type of `OrderedDict::Item` + Iterator begin() { + return parameters_.begin(); + } + + /// Returns a const iterator to the start of the ParameterList + /// the iterator returned will be type of `OrderedDict::Item` + ConstIterator begin() const { + return 
parameters_.begin(); + } + + /// Returns an iterator to the end of the ParameterList + /// the iterator returned will be type of `OrderedDict::Item` + Iterator end() { + return parameters_.end(); + } + + /// Returns a const iterator to the end of the ParameterList + /// the iterator returned will be type of `OrderedDict::Item` + ConstIterator end() const { + return parameters_.end(); + } + + /// Returns the value associated with the given `key`. Throws an exception if + /// no such key is stored in the `ParameterList`. Check contains(key) before + /// for a non-throwing way of access + at::Tensor& at(size_t idx) { + TORCH_CHECK(idx < size(), "Index out of range"); + return parameters_[std::to_string(idx)]; + } + + /// Returns the value associated with the given `key`. Throws an exception if + /// no such key is stored in the `ParameterList`. Check contains(key) before + /// for a non-throwing way of access + const at::Tensor& at(size_t idx) const { + TORCH_CHECK(idx < size(), "Index out of range"); + return parameters_[std::to_string(idx)]; + } + + /// Returns the value associated with the given `key`. Throws an exception if + /// no such key is stored in the `ParameterList`. Check contains(key) before + /// for a non-throwing way of access + at::Tensor& operator[](size_t idx) { + return at(idx); + } + + /// Returns the value associated with the given `key`. Throws an exception if + /// no such key is stored in the `ParameterList`. 
Check contains(key) before + /// for a non-throwing way of access + const at::Tensor& operator[](size_t idx) const { + return at(idx); + } + + /// Return the size of the ParameterList + size_t size() const noexcept { + return parameters_.size(); + } + /// True if the ParameterList is empty + bool is_empty() const noexcept { + return parameters_.is_empty(); + } + + /// Overload the +=, so that two ParameterList could be incrementally added + template + Container& operator+=(const Container& other) { + extend(other); + return *this; + } + + private: + template + void push_back_var(Head&& head, Tail&&... tail) { + append(std::forward(head)); + // Recursively calls this method, until the parameter pack only thas this + // entry left. Then calls `push_back()` a final time (above). + push_back_var(std::forward(tail)...); + } + + /// The base case, when the list of modules is empty. + void push_back_var() {} +}; +TORCH_MODULE(ParameterList); +} // namespace nn +} // namespace torch diff --git a/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/container/sequential.h b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/container/sequential.h new file mode 100644 index 0000000000000000000000000000000000000000..6ee12bc477d8293ba6095c3e3b38778c99ce4f43 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/container/sequential.h @@ -0,0 +1,388 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include + +#include +#include +#include +#include +#include +#include +#include + +namespace torch { +namespace nn { + +/// A list of `Module`s that acts as a `Module` itself. +/// +/// A `Sequential` is fundamentally a list of `Module`s, each with a `forward()` +/// method. `Sequential` provides a `forward()` method of its own, which accepts +/// any input and forwards it to the first module it stores. 
It then "chains" +/// outputs to inputs sequentially for each subsequent module, finally returning +/// the output of the last module. For example: +/// +/// \rst +/// .. code-block:: cpp +/// +/// torch::nn::Sequential seq( +/// torch::nn::Linear(3, 4), +/// torch::nn::BatchNorm1d(4), +/// torch::nn::Dropout(0.5) +/// ); +/// +/// auto output = seq->forward(torch::ones(3)); +/// +/// \endrst +/// +/// This can conceptually be thought of as the following loop (using Python as +/// pseudocode): +/// +/// \rst +/// .. code-block:: python +/// +/// def forward(sequential, input): +/// for module in sequential: +/// input = module(input) +/// return input +/// +/// \endrst +/// +/// Why should you use `Sequential` instead of a simple `std::vector`? The value +/// a `Sequential` provides over manually calling a sequence of modules is that +/// it allows treating the whole container *as a single module*, such that +/// performing a transformation on the `Sequential` applies to each of the +/// modules it stores (which are each a registered submodule of the +/// `Sequential`). For example, calling +/// `.to(torch::kCUDA)` on a `Sequential` will move each module in the list to +/// CUDA memory. For example: +/// +/// \rst +/// .. code-block:: cpp +/// +/// torch::nn::Sequential seq( +/// torch::nn::Linear(3, 4), +/// torch::nn::BatchNorm1d(4), +/// torch::nn::Dropout(0.5) +/// ); +/// +/// // Convert all modules to CUDA. +/// seq->to(torch::kCUDA); +/// +/// \endrst +/// +/// Finally, `Sequential` provides a lightweight container API, such as allowing +/// iteration over submodules, positional access, adding a new module after +/// construction via `push_back`, as well as joining two `Sequential`s via +/// `extend`. +/// +/// \rst +/// .. attention:: +/// One current limitation of `Sequential` is that all except the first module +/// must accept a single argument. If your modules need to take multiple +/// arguments, you should define them to take and return tuples. 
+/// \endrst +class SequentialImpl : public Cloneable { + public: + using Iterator = std::vector::iterator; + using ConstIterator = std::vector::const_iterator; + + SequentialImpl() = default; + + /// Constructs the `Sequential` from a variadic list of modules. + template + explicit SequentialImpl(Modules&&... modules) { + modules_.reserve(sizeof...(Modules)); + push_back(std::forward(modules)...); + } + + /// Constructs the `Sequential` from an `OrderedDict` of named `AnyModule`s. + explicit SequentialImpl( + torch::OrderedDict&& ordered_dict) { + modules_.reserve(ordered_dict.size()); + for (auto& item : ordered_dict) { + push_back(item.key(), std::move(item.value())); + } + } + + /// Constructs the `Sequential` from a braced-init-list of named `AnyModule`s. + /// It enables the following use case: + /// `Sequential sequential({{"m1", M(1)}, {"m2", M(2)}})` + explicit SequentialImpl(std::initializer_list named_modules) { + modules_.reserve(named_modules.size()); + for (const auto& named_module : named_modules) { + push_back(named_module.name(), named_module.module()); + } + } + + /// Special cloning function for `Sequential` because it does not use + /// `reset()`. + std::shared_ptr clone( + const std::optional& device = std::nullopt) const override { + auto clone = std::make_shared(); + for (const auto& module : modules_) { + clone->push_back(module.clone(device)); + } + return clone; + } + + /// `reset()` is empty for `Sequential`, since it does not have parameters of + /// its own. + void reset() override {} + + /// Pretty prints the `Sequential` module into the given `stream`. + void pretty_print(std::ostream& stream) const override { + stream << "torch::nn::Sequential"; + } + + /// Feeds `inputs` to the first module and then chains outputs to inputs, + /// returning the last output. + /// + /// Conceptually the following loop in Python: + /// + /// \rst + /// .. 
code-block:: python + /// + /// def forward(sequential, input): + /// for module in sequential: + /// input = module(input) + /// return input + /// + /// \endrst + /// + /// The return type is taken as the first template parameter. It defaults to + /// `Tensor`. If the last module in the `Sequential` returns another type `T`, + /// you should call `forward(inputs)` instead of just `forward(inputs)`: + /// + /// \rst + /// .. code-block:: cpp + /// + /// torch::Tensor tensor = sequential1->forward(inputs); + /// int integer = sequential2->forward(inputs); + /// float value = sequential3->forward(inputs); + /// + /// \endrst + template + ReturnType forward(InputTypes&&... inputs) { + TORCH_CHECK(!is_empty(), "Cannot call forward() on an empty Sequential"); + + auto iterator = modules_.begin(); + auto input = iterator->any_forward(std::forward(inputs)...); + + for (++iterator; iterator != modules_.end(); ++iterator) { + input = iterator->any_forward(std::move(input)); + } + + // Check the return value and give a nice error message if the requested + // return type was incorrect. + if (auto* return_value = input.template try_get()) { + return std::move(*return_value); + } + AT_ERROR( + "The type of the return value is ", + c10::demangle(input.type_info().name()), + ", but you asked for type ", + c10::demangle(typeid(ReturnType).name())); + } + + /// Adds a new (boxed) `Module` to the `Sequential` container. + template + void push_back(std::shared_ptr module_ptr) { + push_back(std::to_string(modules_.size()), std::move(module_ptr)); + } + + /// Adds a new named (boxed) `Module` to the `Sequential` container. + template + void push_back(std::string name, std::shared_ptr module_ptr) { + push_back(std::move(name), AnyModule(std::move(module_ptr))); + } + + /// Adds a new `Module` to the `Sequential` container, moving or copying it + /// into a `shared_ptr` internally. This method allows passing value types, + /// and letting the container deal with the boxing. 
This means you can write + /// `Sequential(Module(3, 4))` instead of + /// `Sequential(std::make_shared(3, 4))`. + template > + void push_back(M&& module) { + push_back(std::to_string(modules_.size()), std::forward(module)); + } + + /// Adds a new named `Module` to the `Sequential` container, moving or copying + /// it into a `shared_ptr` internally. This method allows passing value types, + /// and letting the container deal with the boxing. + template > + void push_back(std::string name, M&& module) { + using Type = typename std::remove_reference_t; + push_back(std::move(name), std::make_shared(std::forward(module))); + } + + /// Unwraps the contained module of a `ModuleHolder` and adds it to the + /// `Sequential`. + template + void push_back(const ModuleHolder& module_holder) { + push_back(std::to_string(modules_.size()), module_holder); + } + + /// Unwraps the contained named module of a `ModuleHolder` and adds it to the + /// `Sequential`. + template + void push_back(std::string name, const ModuleHolder& module_holder) { + push_back(std::move(name), module_holder.ptr()); + } + + /// Iterates over the container and calls `push_back()` on each value. + template + void extend(const Container& container) { + for (const auto& module : container) { + push_back(module); + } + } + + /// Adds a type-erased `AnyModule` to the `Sequential`. + void push_back(AnyModule any_module) { + push_back(std::to_string(modules_.size()), std::move(any_module)); + } + + void push_back(std::string name, AnyModule any_module) { + modules_.push_back(std::move(any_module)); + const auto index = modules_.size() - 1; + register_module(std::move(name), modules_[index].ptr()); + } + + /// Returns an iterator to the start of the `Sequential`. + Iterator begin() { + return modules_.begin(); + } + + /// Returns a const iterator to the start of the `Sequential`. + ConstIterator begin() const { + return modules_.begin(); + } + + /// Returns an iterator to the end of the `Sequential`. 
+ Iterator end() { + return modules_.end(); + } + + /// Returns a const iterator to the end of the `Sequential`. + ConstIterator end() const { + return modules_.end(); + } + + /// Attempts to return the module at the given index as the requested type. + /// Throws an exception if the index is out of bounds or the types do not + /// match. + template + T& at(size_t index) { + static_assert( + torch::detail::is_module::value, + "Can only call Sequential::at with an nn::Module type"); + TORCH_CHECK(index < size(), "Index out of range"); + return modules_[index].get(); + } + + /// Attempts to return the module at the given index as the requested type. + /// Throws an exception if the index is out of bounds or the types do not + /// match. + template + const T& at(size_t index) const { + static_assert( + torch::detail::is_module::value, + "Can only call Sequential::at with an nn::Module type"); + TORCH_CHECK(index < size(), "Index out of range"); + return modules_[index].get(); + } + + /// Attempts to return a `std::shared_ptr` whose dynamic type is that of the + /// underlying module at the given index. Throws an exception if the index is + /// out of bounds. + std::shared_ptr ptr(size_t index) const { + TORCH_CHECK(index < size(), "Index out of range"); + return modules_[index].ptr(); + } + + /// Attempts to return a `std::shared_ptr` whose type is the one provided. + /// Throws an exception if the index is out of bounds or the types do not + /// match. + template + std::shared_ptr ptr(size_t index) const { + static_assert( + torch::detail::is_module::value, + "Can only call Sequential::ptr with an nn::Module type"); + TORCH_CHECK(index < size(), "Index out of range"); + return modules_[index].ptr(); + } + + /// Like `ptr(index)`. + std::shared_ptr operator[](size_t index) const { + // This is the only method we can call without a type. + return ptr(index); + } + + /// The current size of the `Sequential` container. 
+ size_t size() const noexcept { + return modules_.size(); + } + + /// True if there are no modules in the `Sequential`. + bool is_empty() const noexcept { + return size() == 0; + } + + private: + /// Takes a First *and* Second parameter, to avoid ambiguity when a parameter + /// pack has only one type, in which case the template would be preferred, + /// even if the other `push_back` functions are better fits (e.g. `unique_ptr` + /// -> `shared_ptr` overload). + /// NOTE: We explicitly avoid matching this template with + /// `push_back(std::string("name"), module)` or `push_back("name", module)`, + /// since they should be handled by their respective `push_back` functions. + template < + typename First, + typename Second, + typename... Rest, + typename = std::enable_if_t< + !std::is_same_v && + // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays) + !std::is_same_v, std::decay_t>>> + void push_back(First&& first, Second&& second, Rest&&... rest) { + push_back(std::forward(first)); + // Recursively calls this method, until the parameter pack only thas this + // entry left. Then calls `push_back()` a final time (above). + push_back(std::forward(second), std::forward(rest)...); + } + + /// The base case, when the list of modules is empty. + void push_back() {} + + // Box the AnyModules to give Sequential reference semantics, like the rest of + // the API. Note that this is not required otherwise, this could just be a + // `vector`. + std::vector modules_; +}; + +/// A `ModuleHolder` subclass for `SequentialImpl`. +/// See the documentation for `SequentialImpl` class to learn what methods it +/// provides, or the documentation for `ModuleHolder` to learn about PyTorch's +/// module storage semantics. +class Sequential : public torch::nn::ModuleHolder { + public: + using torch::nn::ModuleHolder::ModuleHolder; + + Sequential() : ModuleHolder() {} + + /// Constructs the `Sequential` from a braced-init-list of named `AnyModule`s. 
+ /// It enables the following use case: + /// `Sequential sequential({{"m1", M(1)}, {"m2", M(2)}})` + Sequential(std::initializer_list named_modules) + : ModuleHolder(std::make_shared(named_modules)) {} +}; +} // namespace nn +} // namespace torch diff --git a/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/dropout.h b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/dropout.h new file mode 100644 index 0000000000000000000000000000000000000000..a2ebabded6fabba8a105eacbdbb61e0377f98e26 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/dropout.h @@ -0,0 +1,190 @@ +#pragma once + +#include +#include +#include +#include + +#include + +#include +#include + +namespace torch { +namespace nn { + +namespace detail { + +template +class _DropoutNd : public torch::nn::Cloneable { + public: + _DropoutNd(double p) : _DropoutNd(DropoutOptions().p(p)){}; + + explicit _DropoutNd(const DropoutOptions& options_ = {}) : options(options_) { + // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall) + reset(); + } + + void reset() override { + TORCH_CHECK( + options.p() >= 0. && options.p() <= 1., + "dropout probability has to be between 0 and 1, but got ", + options.p()); + } + + /// The options with which this `Module` was constructed. + DropoutOptions options; +}; + +} // namespace detail + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Dropout ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Applies dropout over a 1-D input. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.Dropout to learn +/// about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::DropoutOptions` class to learn what +/// constructor arguments are supported for this module. 
+/// +/// Example: +/// ``` +/// Dropout model(DropoutOptions().p(0.42).inplace(true)); +/// ``` +class TORCH_API DropoutImpl : public detail::_DropoutNd { + public: + using detail::_DropoutNd::_DropoutNd; + + Tensor forward(Tensor input); + + /// Pretty prints the `Dropout` module into the given `stream`. + void pretty_print(std::ostream& stream) const override; +}; + +/// A `ModuleHolder` subclass for `DropoutImpl`. +/// See the documentation for `DropoutImpl` class to learn what methods it +/// provides, and examples of how to use `Dropout` with +/// `torch::nn::DropoutOptions`. See the documentation for `ModuleHolder` to +/// learn about PyTorch's module storage semantics. +TORCH_MODULE(Dropout); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Dropout2d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Applies dropout over a 2-D input. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.Dropout2d to learn +/// about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::Dropout2dOptions` class to learn what +/// constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// Dropout2d model(Dropout2dOptions().p(0.42).inplace(true)); +/// ``` +class TORCH_API Dropout2dImpl : public detail::_DropoutNd { + public: + using detail::_DropoutNd::_DropoutNd; + + Tensor forward(Tensor input); + + /// Pretty prints the `Dropout2d` module into the given `stream`. + void pretty_print(std::ostream& stream) const override; +}; + +/// A `ModuleHolder` subclass for `Dropout2dImpl`. +/// See the documentation for `Dropout2dImpl` class to learn what methods it +/// provides, and examples of how to use `Dropout2d` with +/// `torch::nn::Dropout2dOptions`. See the documentation for `ModuleHolder` to +/// learn about PyTorch's module storage semantics. +TORCH_MODULE(Dropout2d); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Dropout3d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Applies dropout over a 3-D input. 
+/// See https://pytorch.org/docs/main/nn.html#torch.nn.Dropout3d to learn +/// about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::Dropout3dOptions` class to learn what +/// constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// Dropout3d model(Dropout3dOptions().p(0.42).inplace(true)); +/// ``` +class TORCH_API Dropout3dImpl : public detail::_DropoutNd { + public: + using detail::_DropoutNd::_DropoutNd; + + Tensor forward(Tensor input); + + /// Pretty prints the `Dropout3d` module into the given `stream`. + void pretty_print(std::ostream& stream) const override; +}; + +/// A `ModuleHolder` subclass for `Dropout3dImpl`. +/// See the documentation for `Dropout3dImpl` class to learn what methods it +/// provides, and examples of how to use `Dropout3d` with +/// `torch::nn::Dropout3dOptions`. See the documentation for `ModuleHolder` to +/// learn about PyTorch's module storage semantics. +TORCH_MODULE(Dropout3d); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ AlphaDropout ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Applies Alpha Dropout over the input. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.AlphaDropout to learn +/// about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::AlphaDropoutOptions` class to learn +/// what constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// AlphaDropout model(AlphaDropoutOptions(0.2).inplace(true)); +/// ``` +class TORCH_API AlphaDropoutImpl : public detail::_DropoutNd { + public: + using detail::_DropoutNd::_DropoutNd; + + Tensor forward(const Tensor& input); + + /// Pretty prints the `AlphaDropout` module into the given `stream`. + void pretty_print(std::ostream& stream) const override; +}; + +/// A `ModuleHolder` subclass for `AlphaDropoutImpl`. 
+/// See the documentation for `AlphaDropoutImpl` class to learn what methods it +/// provides, and examples of how to use `AlphaDropout` with +/// `torch::nn::AlphaDropoutOptions`. See the documentation for `ModuleHolder` +/// to learn about PyTorch's module storage semantics. +TORCH_MODULE(AlphaDropout); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ FeatureAlphaDropout +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// See the documentation for `torch::nn::FeatureAlphaDropoutOptions` class to +/// learn what constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// FeatureAlphaDropout model(FeatureAlphaDropoutOptions(0.2).inplace(true)); +/// ``` +class TORCH_API FeatureAlphaDropoutImpl + : public detail::_DropoutNd { + public: + using detail::_DropoutNd::_DropoutNd; + + Tensor forward(const Tensor& input); + + /// Pretty prints the `FeatureAlphaDropout` module into the given `stream`. + void pretty_print(std::ostream& stream) const override; +}; + +/// A `ModuleHolder` subclass for `FeatureAlphaDropoutImpl`. +/// See the documentation for `FeatureAlphaDropoutImpl` class to learn what +/// methods it provides, and examples of how to use `FeatureAlphaDropout` with +/// `torch::nn::FeatureAlphaDropoutOptions`. See the documentation for +/// `ModuleHolder` to learn about PyTorch's module storage semantics. 
+TORCH_MODULE(FeatureAlphaDropout); + +} // namespace nn +} // namespace torch diff --git a/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/fold.h b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/fold.h new file mode 100644 index 0000000000000000000000000000000000000000..6b415a99b5ea8ec94d3ea24a84d21ffedd8f05c6 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/fold.h @@ -0,0 +1,87 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace torch { +namespace nn { + +/// Applies fold over a 3-D input. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.Fold to learn about +/// the exact behavior of this module. +/// +/// See the documentation for `torch::nn::FoldOptions` class to learn what +/// constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// Fold model(FoldOptions({8, 8}, {3, 3}).dilation(2).padding({2, +/// 1}).stride(2)); +/// ``` +class TORCH_API FoldImpl : public torch::nn::Cloneable { + public: + FoldImpl(ExpandingArray<2> output_size, ExpandingArray<2> kernel_size) + : FoldImpl(FoldOptions(output_size, kernel_size)) {} + explicit FoldImpl(const FoldOptions& options_); + + void reset() override; + + /// Pretty prints the `Fold` module into the given `stream`. + void pretty_print(std::ostream& stream) const override; + + Tensor forward(const Tensor& input); + + /// The options with which this `Module` was constructed. + FoldOptions options; +}; + +/// A `ModuleHolder` subclass for `FoldImpl`. +/// See the documentation for `FoldImpl` class to learn what methods it +/// provides, and examples of how to use `Fold` with `torch::nn::FoldOptions`. +/// See the documentation for `ModuleHolder` to learn about PyTorch's +/// module storage semantics. 
+TORCH_MODULE(Fold); + +// ============================================================================ + +/// Applies unfold over a 4-D input. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.Unfold to learn about +/// the exact behavior of this module. +/// +/// See the documentation for `torch::nn::UnfoldOptions` class to learn what +/// constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// Unfold model(UnfoldOptions({2, 4}).dilation(2).padding({2, 1}).stride(2)); +/// ``` +class TORCH_API UnfoldImpl : public Cloneable { + public: + UnfoldImpl(ExpandingArray<2> kernel_size) + : UnfoldImpl(UnfoldOptions(kernel_size)) {} + explicit UnfoldImpl(const UnfoldOptions& options_); + + void reset() override; + + /// Pretty prints the `Unfold` module into the given `stream`. + void pretty_print(std::ostream& stream) const override; + + Tensor forward(const Tensor& input); + + /// The options with which this `Module` was constructed. + UnfoldOptions options; +}; + +/// A `ModuleHolder` subclass for `UnfoldImpl`. +/// See the documentation for `UnfoldImpl` class to learn what methods it +/// provides, and examples of how to use `Unfold` with +/// `torch::nn::UnfoldOptions`. See the documentation for `ModuleHolder` to +/// learn about PyTorch's module storage semantics. 
+TORCH_MODULE(Unfold); + +} // namespace nn +} // namespace torch diff --git a/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/instancenorm.h b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/instancenorm.h new file mode 100644 index 0000000000000000000000000000000000000000..66ebb6e7390a958de242d5ddd25e1acce307bd8c --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/instancenorm.h @@ -0,0 +1,153 @@ +#pragma once + +#include +#include + +namespace torch { +namespace nn { + +/// Base class for all (dimension-specialized) instance norm modules +template +class InstanceNormImpl + : public torch::nn::NormImplBase { + private: + inline Tensor apply_instance_norm(const Tensor& input) { + return torch::nn::functional::detail::instance_norm( + input, + this->running_mean, + this->running_var, + this->weight, + this->bias, + this->is_training() || !this->options.track_running_stats(), + this->options.momentum(), + this->options.eps()); + } + + inline Tensor handle_no_batch_input(const Tensor& input) { + return this->apply_instance_norm(input.unsqueeze(0)).squeeze(0); + } + + public: + using torch::nn::NormImplBase::NormImplBase; + + Tensor forward(const Tensor& input) { + this->_check_input_dim(input); + + // For InstanceNorm1D, 2D is unbatched and 3D is batched + // For InstanceNorm2D, 3D is unbatched and 4D is batched + // For InstanceNorm3D, 4D is unbatched and 5D is batched + // check if input does not have a batch-dim + if (input.dim() == D + 1) { + return this->handle_no_batch_input(input); + } + + return this->apply_instance_norm(input); + } + + /// Pretty prints the `InstanceNorm{1,2,3}d` module into the given `stream`. 
+ void pretty_print(std::ostream& stream) const override { + stream << std::boolalpha << "torch::nn::InstanceNorm" << D << "d(" + << this->options.num_features() << ", " + << "eps=" << this->options.eps() << ", " + << "momentum=" << this->options.momentum() << ", " + << "affine=" << this->options.affine() << ", " + << "track_running_stats=" << this->options.track_running_stats() + << ")"; + } +}; + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ InstanceNorm1d +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Applies the InstanceNorm1d function. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.InstanceNorm1d to learn +/// about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::InstanceNorm1dOptions` class to learn +/// what constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// InstanceNorm1d +/// model(InstanceNorm1dOptions(4).eps(0.5).momentum(0.1).affine(false).track_running_stats(true)); +/// ``` +class TORCH_API InstanceNorm1dImpl + : public InstanceNormImpl<1, InstanceNorm1dImpl> { + protected: + void _check_input_dim(const Tensor& input) override; + + public: + using InstanceNormImpl<1, InstanceNorm1dImpl>::InstanceNormImpl; +}; + +/// A `ModuleHolder` subclass for `InstanceNorm1dImpl`. +/// See the documentation for `InstanceNorm1dImpl` class to learn what methods +/// it provides, and examples of how to use `InstanceNorm1d` with +/// `torch::nn::InstanceNorm1dOptions`. See the documentation for `ModuleHolder` +/// to learn about PyTorch's module storage semantics. +TORCH_MODULE(InstanceNorm1d); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ InstanceNorm2d +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Applies the InstanceNorm2d function. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.InstanceNorm2d to learn +/// about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::InstanceNorm2dOptions` class to learn +/// what constructor arguments are supported for this module. 
+/// +/// Example: +/// ``` +/// InstanceNorm2d +/// model(InstanceNorm2dOptions(4).eps(0.5).momentum(0.1).affine(false).track_running_stats(true)); +/// ``` +class TORCH_API InstanceNorm2dImpl + : public InstanceNormImpl<2, InstanceNorm2dImpl> { + protected: + void _check_input_dim(const Tensor& input) override; + + public: + using InstanceNormImpl<2, InstanceNorm2dImpl>::InstanceNormImpl; +}; + +/// A `ModuleHolder` subclass for `InstanceNorm2dImpl`. +/// See the documentation for `InstanceNorm2dImpl` class to learn what methods +/// it provides, and examples of how to use `InstanceNorm2d` with +/// `torch::nn::InstanceNorm2dOptions`. See the documentation for `ModuleHolder` +/// to learn about PyTorch's module storage semantics. +TORCH_MODULE(InstanceNorm2d); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ InstanceNorm3d +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Applies the InstanceNorm3d function. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.InstanceNorm3d to learn +/// about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::InstanceNorm3dOptions` class to learn +/// what constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// InstanceNorm3d +/// model(InstanceNorm3dOptions(4).eps(0.5).momentum(0.1).affine(false).track_running_stats(true)); +/// ``` +class TORCH_API InstanceNorm3dImpl + : public InstanceNormImpl<3, InstanceNorm3dImpl> { + protected: + void _check_input_dim(const Tensor& input) override; + + public: + using InstanceNormImpl<3, InstanceNorm3dImpl>::InstanceNormImpl; +}; + +/// A `ModuleHolder` subclass for `InstanceNorm3dImpl`. +/// See the documentation for `InstanceNorm3dImpl` class to learn what methods +/// it provides, and examples of how to use `InstanceNorm3d` with +/// `torch::nn::InstanceNorm3dOptions`. See the documentation for `ModuleHolder` +/// to learn about PyTorch's module storage semantics. 
+TORCH_MODULE(InstanceNorm3d); + +} // namespace nn +} // namespace torch diff --git a/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/loss.h b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/loss.h new file mode 100644 index 0000000000000000000000000000000000000000..747b548b758441ec7a86ae7a322e38c11f23f9ea --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/loss.h @@ -0,0 +1,805 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +#include + +#include +#include + +namespace torch { +namespace nn { + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ L1Loss ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Creates a criterion that measures the mean absolute error (MAE) between each +/// element in the input : math :`x` and target : `y`. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.L1Loss to learn +/// about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::L1LossOptions` class to learn what +/// constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// L1Loss model(L1LossOptions(torch::kNone)); +/// ``` +struct TORCH_API L1LossImpl : Cloneable { + explicit L1LossImpl(L1LossOptions options_ = {}); + + void reset() override; + + /// Pretty prints the `L1Loss` module into the given `stream`. + void pretty_print(std::ostream& stream) const override; + + Tensor forward(const Tensor& input, const Tensor& target); + + /// The options with which this `Module` was constructed. + L1LossOptions options; +}; + +/// A `ModuleHolder` subclass for `L1LossImpl`. +/// See the documentation for `L1LossImpl` class to learn what methods it +/// provides, and examples of how to use `L1Loss` with +/// `torch::nn::L1LossOptions`. See the documentation for `ModuleHolder` to +/// learn about PyTorch's module storage semantics. 
+TORCH_MODULE(L1Loss); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ KLDivLoss +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// The Kullback-Leibler divergence loss measure +/// See https://pytorch.org/docs/main/nn.html#torch.nn.KLDivLoss to learn +/// about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::KLDivLossOptions` class to learn what +/// constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// KLDivLoss model(KLDivLossOptions().reduction(torch::kNone)); +/// ``` +struct TORCH_API KLDivLossImpl : Cloneable { + explicit KLDivLossImpl(KLDivLossOptions options_ = {}); + + void reset() override; + + /// Pretty prints the `KLDivLoss` module into the given `stream`. + void pretty_print(std::ostream& stream) const override; + + Tensor forward(const Tensor& input, const Tensor& target); + + /// The options with which this `Module` was constructed. + KLDivLossOptions options; +}; + +/// A `ModuleHolder` subclass for `KLDivLossImpl`. +/// See the documentation for `KLDivLossImpl` class to learn what methods it +/// provides, and examples of how to use `KLDivLoss` with +/// `torch::nn::KLDivLossOptions`. See the documentation for `ModuleHolder` to +/// learn about PyTorch's module storage semantics. +TORCH_MODULE(KLDivLoss); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MSELoss ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Creates a criterion that measures the mean squared error (squared L2 norm) +/// between each element in the input :math:`x` and target :math:`y`. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.MSELoss to learn +/// about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::MSELossOptions` class to learn what +/// constructor arguments are supported for this module. 
+/// +/// Example: +/// ``` +/// MSELoss model(MSELossOptions(torch::kNone)); +/// ``` +struct TORCH_API MSELossImpl : Cloneable { + explicit MSELossImpl(MSELossOptions options_ = {}); + + void reset() override; + + /// Pretty prints the `MSELoss` module into the given `stream`. + void pretty_print(std::ostream& stream) const override; + + Tensor forward(const Tensor& input, const Tensor& target); + + /// The options with which this `Module` was constructed. + MSELossOptions options; +}; + +/// A `ModuleHolder` subclass for `MSELossImpl`. +/// See the documentation for `MSELossImpl` class to learn what methods it +/// provides, and examples of how to use `MSELoss` with +/// `torch::nn::MSELossOptions`. See the documentation for `ModuleHolder` to +/// learn about PyTorch's module storage semantics. +TORCH_MODULE(MSELoss); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ BCELoss ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Creates a criterion that measures the Binary Cross Entropy +/// between the target and the output. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.BCELoss to learn +/// about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::BCELossOptions` class to learn what +/// constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// BCELoss model(BCELossOptions().reduction(torch::kNone).weight(weight)); +/// ``` +struct TORCH_API BCELossImpl : Cloneable { + explicit BCELossImpl(BCELossOptions options_ = {}); + + void reset() override; + + /// Pretty prints the `BCELoss` module into the given `stream`. + void pretty_print(std::ostream& stream) const override; + + Tensor forward(const Tensor& input, const Tensor& target); + + /// The options with which this `Module` was constructed. + BCELossOptions options; +}; + +/// A `ModuleHolder` subclass for `BCELossImpl`. 
+/// See the documentation for `BCELossImpl` class to learn what methods it +/// provides, and examples of how to use `BCELoss` with +/// `torch::nn::BCELossOptions`. See the documentation for `ModuleHolder` to +/// learn about PyTorch's module storage semantics. +TORCH_MODULE(BCELoss); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ HingeEmbeddingLoss +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Creates a criterion that measures the loss given an input tensor :math:`x` +/// and a labels tensor :math:`y` (containing 1 or -1). +/// See https://pytorch.org/docs/main/nn.html#torch.nn.HingeEmbeddingLoss to +/// learn about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::HingeEmbeddingLossOptions` class to +/// learn what constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// HingeEmbeddingLoss +/// model(HingeEmbeddingLossOptions().margin(4).reduction(torch::kNone)); +/// ``` +struct TORCH_API HingeEmbeddingLossImpl : Cloneable { + explicit HingeEmbeddingLossImpl(HingeEmbeddingLossOptions options_ = {}); + + void reset() override; + + /// Pretty prints the `HingeEmbeddingLoss` module into the given `stream`. + void pretty_print(std::ostream& stream) const override; + + Tensor forward(const Tensor& input, const Tensor& target); + + /// The options with which this `Module` was constructed. + HingeEmbeddingLossOptions options; +}; + +/// A `ModuleHolder` subclass for `HingeEmbeddingLossImpl`. +/// See the documentation for `HingeEmbeddingLossImpl` class to learn what +/// methods it provides, and examples of how to use `HingeEmbeddingLoss` with +/// `torch::nn::HingeEmbeddingLossOptions`. See the documentation for +/// `ModuleHolder` to learn about PyTorch's module storage semantics. 
+TORCH_MODULE(HingeEmbeddingLoss); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MultiMarginLoss +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Creates a criterion that optimizes a multi-class classification hinge +/// loss (margin-based loss) between input :math:`x` (a 2D mini-batch `Tensor`) +/// and output :math:`y` (which is a 1D tensor of target class indices, :math:`0 +/// \leq y \leq \text{x.size}(1)-1`). See +/// https://pytorch.org/docs/main/nn.html#torch.nn.MultiMarginLoss to learn +/// about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::MultiMarginLossOptions` class to learn +/// what constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// MultiMarginLoss model(MultiMarginLossOptions().margin(2).weight(weight)); +/// ``` +struct TORCH_API MultiMarginLossImpl : public Cloneable { + explicit MultiMarginLossImpl(MultiMarginLossOptions options_ = {}); + + void reset() override; + + /// Pretty prints the `MultiMarginLoss` module into the given `stream`. + void pretty_print(std::ostream& stream) const override; + + Tensor forward(const Tensor& input, const Tensor& target); + + /// The options with which this `Module` was constructed. + MultiMarginLossOptions options; +}; + +/// A `ModuleHolder` subclass for `MultiMarginLossImpl`. +/// See the documentation for `MultiMarginLossImpl` class to learn what methods +/// it provides, and examples of how to use `MultiMarginLoss` with +/// `torch::nn::MultiMarginLossOptions`. See the documentation for +/// `ModuleHolder` to learn about PyTorch's module storage semantics. +TORCH_MODULE(MultiMarginLoss); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CosineEmbeddingLoss +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Creates a criterion that measures the loss given input tensors +/// `input1`, `input2`, and a `Tensor` label `target` with values 1 or +/// -1. 
This is used for measuring whether two inputs are similar or +/// dissimilar, using the cosine distance, and is typically used for learning +/// nonlinear embeddings or semi-supervised learning. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.CosineEmbeddingLoss to +/// learn about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::CosineEmbeddingLossOptions` class to +/// learn what constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// CosineEmbeddingLoss model(CosineEmbeddingLossOptions().margin(0.5)); +/// ``` +struct TORCH_API CosineEmbeddingLossImpl + : public Cloneable { + explicit CosineEmbeddingLossImpl(CosineEmbeddingLossOptions options_ = {}); + + void reset() override; + + /// Pretty prints the `CosineEmbeddingLoss` module into the given `stream`. + void pretty_print(std::ostream& stream) const override; + + Tensor forward( + const Tensor& input1, + const Tensor& input2, + const Tensor& target); + + /// The options with which this `Module` was constructed. + CosineEmbeddingLossOptions options; +}; + +/// A `ModuleHolder` subclass for `CosineEmbeddingLossImpl`. +/// See the documentation for `CosineEmbeddingLossImpl` class to learn what +/// methods it provides, and examples of how to use `CosineEmbeddingLoss` with +/// `torch::nn::CosineEmbeddingLossOptions`. See the documentation for +/// `ModuleHolder` to learn about PyTorch's module storage semantics. +TORCH_MODULE(CosineEmbeddingLoss); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ SmoothL1Loss +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Creates a criterion that uses a squared term if the absolute +/// element-wise error falls below beta and an L1 term otherwise. +/// It is less sensitive to outliers than the `MSELoss` and in some cases +/// prevents exploding gradients (e.g. see the paper `Fast R-CNN` by Ross +/// Girshick). 
See https://pytorch.org/docs/main/nn.html#torch.nn.SmoothL1Loss +/// to learn about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::SmoothL1LossOptions` class to learn +/// what constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// SmoothL1Loss model(SmoothL1LossOptions().reduction(torch::kNone).beta(0.5)); +/// ``` +struct TORCH_API SmoothL1LossImpl : public Cloneable { + explicit SmoothL1LossImpl(SmoothL1LossOptions options = {}); + + void reset() override; + + /// Pretty prints the `SmoothL1Loss` module into the given `stream`. + void pretty_print(std::ostream& stream) const override; + + Tensor forward(const Tensor& input, const Tensor& target); + + /// The options with which this `Module` was constructed. + SmoothL1LossOptions options; +}; + +/// A `ModuleHolder` subclass for `SmoothL1LossImpl`. +/// See the documentation for `SmoothL1LossImpl` class to learn what methods it +/// provides, and examples of how to use `SmoothL1Loss` with +/// `torch::nn::SmoothL1LossOptions`. See the documentation for `ModuleHolder` +/// to learn about PyTorch's module storage semantics. +TORCH_MODULE(SmoothL1Loss); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ HuberLoss +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Creates a criterion that uses a squared term if the absolute +/// element-wise error falls below delta and a delta-scaled L1 term otherwise. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.HuberLoss to learn +/// about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::HuberLossOptions` class to learn what +/// constructor arguments are supported for this module. 
+/// +/// Example: +/// ``` +/// HuberLoss model(HuberLossOptions().reduction(torch::kNone).delta(0.5)); +/// ``` +struct TORCH_API HuberLossImpl : public Cloneable { + explicit HuberLossImpl(HuberLossOptions options_ = {}); + + void reset() override; + + /// Pretty prints the `HuberLoss` module into the given `stream`. + void pretty_print(std::ostream& stream) const override; + + Tensor forward(const Tensor& input, const Tensor& target); + + /// The options with which this `Module` was constructed. + HuberLossOptions options; +}; + +/// A `ModuleHolder` subclass for `HuberLossImpl`. +/// See the documentation for `HuberLossImpl` class to learn what methods it +/// provides, and examples of how to use `HuberLoss` with +/// `torch::nn::HuberLossOptions`. See the documentation for `ModuleHolder` to +/// learn about PyTorch's module storage semantics. +TORCH_MODULE(HuberLoss); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MultiLabelMarginLoss +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Creates a criterion that optimizes a multi-class multi-classification +/// hinge loss (margin-based loss) between input :math:`x` (a 2D mini-batch +/// `Tensor`) and output :math:`y` (which is a 2D `Tensor` of target class +/// indices). See +/// https://pytorch.org/docs/main/nn.html#torch.nn.MultiLabelMarginLoss to +/// learn about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::MultiLabelMarginLossOptions` class to +/// learn what constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// MultiLabelMarginLoss model(MultiLabelMarginLossOptions(torch::kNone)); +/// ``` +struct TORCH_API MultiLabelMarginLossImpl + : public Cloneable { + explicit MultiLabelMarginLossImpl(MultiLabelMarginLossOptions options_ = {}); + + void reset() override; + + /// Pretty prints the `MultiLabelMarginLoss` module into the given `stream`. 
+ void pretty_print(std::ostream& stream) const override; + + Tensor forward(const Tensor& input, const Tensor& target); + + /// The options with which this `Module` was constructed. + MultiLabelMarginLossOptions options; +}; + +/// A `ModuleHolder` subclass for `MultiLabelMarginLossImpl`. +/// See the documentation for `MultiLabelMarginLossImpl` class to learn what +/// methods it provides, and examples of how to use `MultiLabelMarginLoss` with +/// `torch::nn::MultiLabelMarginLossOptions`. See the documentation for +/// `ModuleHolder` to learn about PyTorch's module storage semantics. +TORCH_MODULE(MultiLabelMarginLoss); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ SoftMarginLoss +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Creates a criterion that optimizes a two-class classification +/// logistic loss between input tensor :math:`x` and target tensor :math:`y` +/// (containing 1 or -1). +/// See https://pytorch.org/docs/main/nn.html#torch.nn.SoftMarginLoss to learn +/// about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::SoftMarginLossOptions` class to learn +/// what constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// SoftMarginLoss model(SoftMarginLossOptions(torch::kNone)); +/// ``` +struct TORCH_API SoftMarginLossImpl : public Cloneable { + explicit SoftMarginLossImpl(SoftMarginLossOptions options_ = {}); + + /// Pretty prints the `SoftMarginLoss` module into the given `stream`. + void pretty_print(std::ostream& stream) const override; + + void reset() override; + + Tensor forward(const Tensor& input, const Tensor& target); + + /// The options with which this `Module` was constructed. + SoftMarginLossOptions options; +}; + +/// A `ModuleHolder` subclass for `SoftMarginLossImpl`. +/// See the documentation for `SoftMarginLossImpl` class to learn what methods +/// it provides, and examples of how to use `SoftMarginLoss` with +/// `torch::nn::SoftMarginLossOptions`. 
See the documentation for `ModuleHolder` +/// to learn about PyTorch's module storage semantics. +TORCH_MODULE(SoftMarginLoss); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MultiLabelSoftMarginLoss +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Creates a criterion that optimizes a multi-label one-versus-all +/// loss based on max-entropy, between input :math:`x` and target :math:`y` of +/// size :math:`(N, C)`. See +/// https://pytorch.org/docs/main/nn.html#torch.nn.MultiLabelSoftMarginLoss to +/// learn about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::MultiLabelSoftMarginLossOptions` class +/// to learn what constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// MultiLabelSoftMarginLoss +/// model(MultiLabelSoftMarginLossOptions().reduction(torch::kNone).weight(weight)); +/// ``` +struct TORCH_API MultiLabelSoftMarginLossImpl + : public Cloneable { + explicit MultiLabelSoftMarginLossImpl( + MultiLabelSoftMarginLossOptions options_ = {}); + + /// Pretty prints the `MultiLabelSoftMarginLoss` module into the given + /// `stream`. + void pretty_print(std::ostream& stream) const override; + + void reset() override; + + Tensor forward(const Tensor& input, const Tensor& target); + + /// The options with which this `Module` was constructed. + MultiLabelSoftMarginLossOptions options; +}; + +/// A `ModuleHolder` subclass for `MultiLabelSoftMarginLossImpl`. +/// See the documentation for `MultiLabelSoftMarginLossImpl` class to learn what +/// methods it provides, and examples of how to use `MultiLabelSoftMarginLoss` +/// with `torch::nn::MultiLabelSoftMarginLossOptions`. See the documentation for +/// `ModuleHolder` to learn about PyTorch's module storage semantics. 
+TORCH_MODULE(MultiLabelSoftMarginLoss); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ TripletMarginLoss +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Creates a criterion that measures the triplet loss given an input +/// tensors :math:`x1`, :math:`x2`, :math:`x3` and a margin with a value greater +/// than :math:`0`. This is used for measuring a relative similarity between +/// samples. A triplet is composed by `a`, `p` and `n` (i.e., `anchor`, +/// `positive examples` and `negative examples` respectively). The +/// shapes of all input tensors should be :math:`(N, D)`. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.TripletMarginLoss to +/// learn about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::TripletMarginLossOptions` class to +/// learn what constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// TripletMarginLoss +/// model(TripletMarginLossOptions().margin(3).p(2).eps(1e-06).swap(false)); +/// ``` +struct TORCH_API TripletMarginLossImpl + : public Cloneable { + explicit TripletMarginLossImpl(TripletMarginLossOptions options_ = {}); + + void reset() override; + + /// Pretty prints the `TripletMarginLoss` module into the given `stream`. + void pretty_print(std::ostream& stream) const override; + + Tensor forward( + const Tensor& anchor, + const Tensor& positive, + const Tensor& negative); + + /// The options with which this `Module` was constructed. + TripletMarginLossOptions options; +}; + +/// A `ModuleHolder` subclass for `TripletMarginLossImpl`. +/// See the documentation for `TripletMarginLossImpl` class to learn what +/// methods it provides, and examples of how to use `TripletMarginLoss` with +/// `torch::nn::TripletMarginLossOptions`. See the documentation for +/// `ModuleHolder` to learn about PyTorch's module storage semantics. 
+TORCH_MODULE(TripletMarginLoss); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ TripletMarginWithDistanceLoss +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Creates a criterion that measures the triplet loss given input +/// tensors :math:`a`, :math:`p`, and :math:`n` (representing anchor, +/// positive, and negative examples, respectively); and a nonnegative, +/// real-valued function +/// ("distance function") used to compute the relationships between the anchor +/// and positive example ("positive distance") and the anchor and negative +/// example ("negative distance"). +/// See +/// https://pytorch.org/docs/main/nn.html#torch.nn.TripletMarginWithDistanceLoss +/// to learn about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::TripletMarginWithDistanceLossOptions` +/// class to learn what constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// TripletMarginWithDistanceLoss +/// model(TripletMarginWithDistanceLossOptions().margin(3).swap(false)); +/// ``` +struct TORCH_API TripletMarginWithDistanceLossImpl + : public Cloneable { + explicit TripletMarginWithDistanceLossImpl( + TripletMarginWithDistanceLossOptions options_ = {}); + + void reset() override; + + /// Pretty prints the `TripletMarginWithDistanceLoss` module into the given + /// `stream`. + void pretty_print(std::ostream& stream) const override; + + Tensor forward( + const Tensor& anchor, + const Tensor& positive, + const Tensor& negative); + + /// The options with which this `Module` was constructed. + TripletMarginWithDistanceLossOptions options; +}; + +/// A `ModuleHolder` subclass for `TripletMarginWithDistanceLossImpl`. +/// See the documentation for `TripletMarginWithDistanceLossImpl` class to learn +/// what methods it provides, and examples of how to use +/// `TripletMarginWithDistanceLoss` with +/// `torch::nn::TripletMarginWithDistanceLossOptions`. 
+/// See the documentation for `ModuleHolder` to learn about PyTorch's +/// module storage semantics. +TORCH_MODULE(TripletMarginWithDistanceLoss); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CTCLoss ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// The Connectionist Temporal Classification loss. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.CTCLoss to learn +/// about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::CTCLossOptions` class to learn what +/// constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// CTCLoss +/// model(CTCLossOptions().blank(42).zero_infinity(false).reduction(torch::kSum)); +/// ``` +struct TORCH_API CTCLossImpl : public Cloneable { + explicit CTCLossImpl(CTCLossOptions options_ = {}); + + void reset() override; + + /// Pretty prints the `CTCLoss` module into the given `stream`. + void pretty_print(std::ostream& stream) const override; + + Tensor forward( + const Tensor& log_probs, + const Tensor& targets, + const Tensor& input_lengths, + const Tensor& target_lengths); + + /// The options with which this `Module` was constructed. + CTCLossOptions options; +}; + +/// A `ModuleHolder` subclass for `CTCLossImpl`. +/// See the documentation for `CTCLossImpl` class to learn what methods it +/// provides, and examples of how to use `CTCLoss` with +/// `torch::nn::CTCLossOptions`. See the documentation for `ModuleHolder` to +/// learn about PyTorch's module storage semantics. +TORCH_MODULE(CTCLoss); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ PoissonNLLLoss +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Negative log likelihood loss with Poisson distribution of target. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.PoissonNLLLoss to learn +/// about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::PoissonNLLLossOptions` class to learn +/// what constructor arguments are supported for this module. 
+/// +/// Example: +/// ``` +/// PoissonNLLLoss +/// model(PoissonNLLLossOptions().log_input(false).full(true).eps(0.42).reduction(torch::kSum)); +/// ``` +struct TORCH_API PoissonNLLLossImpl : public Cloneable { + explicit PoissonNLLLossImpl(PoissonNLLLossOptions options_ = {}); + + void reset() override; + + /// Pretty prints the `PoissonNLLLoss` module into the given `stream`. + void pretty_print(std::ostream& stream) const override; + + Tensor forward(const Tensor& log_input, const Tensor& targets); + + /// The options with which this `Module` was constructed. + PoissonNLLLossOptions options; +}; + +/// A `ModuleHolder` subclass for `PoissonNLLLossImpl`. +/// See the documentation for `PoissonNLLLossImpl` class to learn what methods +/// it provides, and examples of how to use `PoissonNLLLoss` with +/// `torch::nn::PoissonNLLLossOptions`. See the documentation for `ModuleHolder` +/// to learn about PyTorch's module storage semantics. +TORCH_MODULE(PoissonNLLLoss); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MarginRankingLoss +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Creates a criterion that measures the loss given +/// inputs :math:`x1`, :math:`x2`, two 1D mini-batch `Tensors`, +/// and a label 1D mini-batch tensor :math:`y` (containing 1 or -1). +/// See https://pytorch.org/docs/main/nn.html#torch.nn.MarginRankingLoss to +/// learn about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::MarginRankingLossOptions` class to +/// learn what constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// MarginRankingLoss +/// model(MarginRankingLossOptions().margin(0.5).reduction(torch::kSum)); +/// ``` +struct TORCH_API MarginRankingLossImpl + : public Cloneable { + explicit MarginRankingLossImpl(MarginRankingLossOptions options_ = {}); + + void reset() override; + + /// Pretty prints the `MarginRankingLoss` module into the given `stream`. 
+ void pretty_print(std::ostream& stream) const override; + + Tensor forward( + const Tensor& input1, + const Tensor& input2, + const Tensor& targets); + + /// The options with which this `Module` was constructed. + MarginRankingLossOptions options; +}; + +/// A `ModuleHolder` subclass for `MarginRankingLossImpl`. +/// See the documentation for `MarginRankingLossImpl` class to learn what +/// methods it provides, and examples of how to use `MarginRankingLoss` with +/// `torch::nn::MarginRankingLossOptions`. See the documentation for +/// `ModuleHolder` to learn about PyTorch's module storage semantics. +TORCH_MODULE(MarginRankingLoss); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ NLLLoss ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// The negative log likelihood loss. It is useful to train a classification +/// problem with `C` classes. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.NLLLoss to learn +/// about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::NLLLossOptions` class to learn what +/// constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// NLLLoss model(NLLLossOptions().ignore_index(-100).reduction(torch::kMean)); +/// ``` +struct TORCH_API NLLLossImpl : public Cloneable { + explicit NLLLossImpl(NLLLossOptions options_ = {}); + + /// Pretty prints the `NLLLoss` module into the given `stream`. + void pretty_print(std::ostream& stream) const override; + + void reset() override; + + Tensor forward(const Tensor& input, const Tensor& target); + + /// The options with which this `Module` was constructed. + NLLLossOptions options; + + /// A manual rescaling weight given to each class. + Tensor weight; +}; + +/// A `ModuleHolder` subclass for `NLLLossImpl`. +/// See the documentation for `NLLLossImpl` class to learn what methods it +/// provides, and examples of how to use `NLLLoss` with +/// `torch::nn::NLLLossOptions`. 
See the documentation for `ModuleHolder` to +/// learn about PyTorch's module storage semantics. +TORCH_MODULE(NLLLoss); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CrossEntropyLoss +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Creates a criterion that computes cross entropy loss between input and +/// target. See +/// https://pytorch.org/docs/main/nn.html#torch.nn.CrossEntropyLoss to learn +/// about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::CrossEntropyLossOptions` class to +/// learn what constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// CrossEntropyLoss +/// model(CrossEntropyLossOptions().ignore_index(-100).reduction(torch::kMean)); +/// ``` +struct TORCH_API CrossEntropyLossImpl : public Cloneable { + explicit CrossEntropyLossImpl(CrossEntropyLossOptions options_ = {}); + + void reset() override; + + /// Pretty prints the `CrossEntropyLoss` module into the given `stream`. + void pretty_print(std::ostream& stream) const override; + + Tensor forward(const Tensor& input, const Tensor& target); + + /// The options with which this `Module` was constructed. + CrossEntropyLossOptions options; + + /// A manual rescaling weight given to each class. + Tensor weight; +}; + +/// A `ModuleHolder` subclass for `CrossEntropyLossImpl`. +/// See the documentation for `CrossEntropyLossImpl` class to learn what methods +/// it provides, and examples of how to use `CrossEntropyLoss` with +/// `torch::nn::CrossEntropyLossOptions`. See the documentation for +/// `ModuleHolder` to learn about PyTorch's module storage semantics. +TORCH_MODULE(CrossEntropyLoss); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ BCEWithLogitsLoss +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// This loss combines a `Sigmoid` layer and the `BCELoss` in one single +/// class. 
This version is more numerically stable than using a plain `Sigmoid` +/// followed by a `BCELoss` as, by combining the operations into one layer, +/// we take advantage of the log-sum-exp trick for numerical stability. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.BCEWithLogitsLoss to +/// learn about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::BCEWithLogitsLossOptions` class to +/// learn what constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// BCEWithLogitsLoss +/// model(BCEWithLogitsLossOptions().reduction(torch::kNone).weight(weight)); +/// ``` +struct TORCH_API BCEWithLogitsLossImpl + : public Cloneable { + explicit BCEWithLogitsLossImpl(BCEWithLogitsLossOptions options_ = {}); + + void reset() override; + + /// Pretty prints the `BCEWithLogitsLoss` module into the given `stream`. + void pretty_print(std::ostream& stream) const override; + + Tensor forward(const Tensor& input, const Tensor& target); + + /// The options with which this `Module` was constructed. + BCEWithLogitsLossOptions options; + + /// A manual rescaling weight given to the loss of each batch element. + Tensor weight; + + /// A weight of positive examples. + Tensor pos_weight; +}; + +/// A `ModuleHolder` subclass for `BCEWithLogitsLossImpl`. +/// See the documentation for `BCEWithLogitsLossImpl` class to learn what +/// methods it provides, and examples of how to use `BCEWithLogitsLoss` with +/// `torch::nn::BCEWithLogitsLossOptions`. See the documentation for +/// `ModuleHolder` to learn about PyTorch's module storage semantics. 
+TORCH_MODULE(BCEWithLogitsLoss); + +} // namespace nn +} // namespace torch diff --git a/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/normalization.h b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/normalization.h new file mode 100644 index 0000000000000000000000000000000000000000..9bc0b7f9e7fc45208f2b58a79fdc30d6c46b8b8e --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/normalization.h @@ -0,0 +1,198 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace torch { +namespace nn { + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ LayerNorm ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Applies Layer Normalization over a mini-batch of inputs as described in +/// the paper `Layer Normalization`_ . +/// See https://pytorch.org/docs/main/nn.html#torch.nn.LayerNorm to learn +/// about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::LayerNormOptions` class to learn what +/// constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// LayerNorm model(LayerNormOptions({2, +/// 2}).elementwise_affine(false).eps(2e-5)); +/// ``` +class TORCH_API LayerNormImpl : public torch::nn::Cloneable { + public: + LayerNormImpl(std::vector normalized_shape) + : LayerNormImpl(LayerNormOptions(normalized_shape)) {} + explicit LayerNormImpl(LayerNormOptions options_); + + void reset() override; + + void reset_parameters(); + + /// Pretty prints the `LayerNorm` module into the given `stream`. + void pretty_print(std::ostream& stream) const override; + + /// Applies layer normalization over a mini-batch of inputs as described in + /// the paper `Layer Normalization`_ . 
+ /// + /// The mean and standard-deviation are calculated separately over the last + /// certain number dimensions which have to be of the shape specified by + /// input `normalized_shape`. + /// + /// `Layer Normalization`: https://arxiv.org/abs/1607.06450 + Tensor forward(const Tensor& input); + + /// The options with which this module was constructed. + LayerNormOptions options; + + /// The learned weight. + /// Initialized to ones if the `elementwise_affine` option is set to `true` + /// upon construction. + Tensor weight; + + /// The learned bias. + /// Initialized to zeros if the `elementwise_affine` option is set to `true` + /// upon construction. + Tensor bias; +}; + +/// A `ModuleHolder` subclass for `LayerNormImpl`. +/// See the documentation for `LayerNormImpl` class to learn what methods it +/// provides, and examples of how to use `LayerNorm` with +/// `torch::nn::LayerNormOptions`. See the documentation for `ModuleHolder` to +/// learn about PyTorch's module storage semantics. +TORCH_MODULE(LayerNorm); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ LocalResponseNorm +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Applies local response normalization over an input signal composed +/// of several input planes, where channels occupy the second dimension. +/// Applies normalization across channels. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.LocalResponseNorm to +/// learn about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::LocalResponseNormOptions` class to +/// learn what constructor arguments are supported for this module. 
+/// +/// Example: +/// ``` +/// LocalResponseNorm +/// model(LocalResponseNormOptions(2).alpha(0.0002).beta(0.85).k(2.)); +/// ``` +class TORCH_API LocalResponseNormImpl + : public Cloneable { + public: + LocalResponseNormImpl(int64_t size) + : LocalResponseNormImpl(LocalResponseNormOptions(size)) {} + explicit LocalResponseNormImpl(const LocalResponseNormOptions& options_); + + Tensor forward(const Tensor& input); + + void reset() override; + + /// Pretty prints the `LocalResponseNormImpl` module into the given `stream`. + void pretty_print(std::ostream& stream) const override; + + /// The options with which this `Module` was constructed. + LocalResponseNormOptions options; +}; + +/// A `ModuleHolder` subclass for `LocalResponseNormImpl`. +/// See the documentation for `LocalResponseNormImpl` class to learn what +/// methods it provides, and examples of how to use `LocalResponseNorm` with +/// `torch::nn::LocalResponseNormOptions`. See the documentation for +/// `ModuleHolder` to learn about PyTorch's module storage semantics. +TORCH_MODULE(LocalResponseNorm); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CrossMapLRN2d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// See the documentation for `torch::nn::CrossMapLRN2dOptions` class to learn +/// what constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// CrossMapLRN2d model(CrossMapLRN2dOptions(3).alpha(1e-5).beta(0.1).k(10)); +/// ``` +class TORCH_API CrossMapLRN2dImpl + : public torch::nn::Cloneable { + public: + CrossMapLRN2dImpl(int64_t size) + : CrossMapLRN2dImpl(CrossMapLRN2dOptions(size)) {} + explicit CrossMapLRN2dImpl(const CrossMapLRN2dOptions& options_) + : options(options_) {} + + void reset() override; + + /// Pretty prints the `CrossMapLRN2d` module into the given `stream`. + void pretty_print(std::ostream& stream) const override; + + torch::Tensor forward(const torch::Tensor& input); + + CrossMapLRN2dOptions options; +}; + +/// A `ModuleHolder` subclass for `CrossMapLRN2dImpl`. 
+/// See the documentation for `CrossMapLRN2dImpl` class to learn what methods it +/// provides, and examples of how to use `CrossMapLRN2d` with +/// `torch::nn::CrossMapLRN2dOptions`. See the documentation for `ModuleHolder` +/// to learn about PyTorch's module storage semantics. +TORCH_MODULE(CrossMapLRN2d); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ GroupNorm ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Applies Group Normalization over a mini-batch of inputs as described in +/// the paper `Group Normalization`_ . +/// See https://pytorch.org/docs/main/nn.html#torch.nn.GroupNorm to learn +/// about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::GroupNormOptions` class to learn what +/// constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// GroupNorm model(GroupNormOptions(2, 2).eps(2e-5).affine(false)); +/// ``` +class TORCH_API GroupNormImpl : public torch::nn::Cloneable { + public: + GroupNormImpl(int64_t num_groups, int64_t num_channels) + : GroupNormImpl(GroupNormOptions(num_groups, num_channels)) {} + explicit GroupNormImpl(const GroupNormOptions& options_); + + void reset() override; + + void reset_parameters(); + + /// Pretty prints the `GroupNorm` module into the given `stream`. + void pretty_print(std::ostream& stream) const override; + + Tensor forward(const Tensor& input); + + /// The options with which this module was constructed. + GroupNormOptions options; + + /// The learned weight. + Tensor weight; + + /// The learned bias. + Tensor bias; +}; + +/// A `ModuleHolder` subclass for `GroupNormImpl`. +/// See the documentation for `GroupNormImpl` class to learn what methods it +/// provides, and examples of how to use `GroupNorm` with +/// `torch::nn::GroupNormOptions`. See the documentation for `ModuleHolder` to +/// learn about PyTorch's module storage semantics. 
+TORCH_MODULE(GroupNorm); + +} // namespace nn +} // namespace torch diff --git a/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/pooling.h b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/pooling.h new file mode 100644 index 0000000000000000000000000000000000000000..0fac60edbcde40948f7dce1f0cea94a26fab2506 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/pooling.h @@ -0,0 +1,779 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include + +namespace torch { +namespace nn { + +/// Base class for all (dimension-specialized) avgpool modules. +template +class TORCH_API AvgPoolImpl : public torch::nn::Cloneable { + public: + AvgPoolImpl(ExpandingArray kernel_size) + : AvgPoolImpl(AvgPoolOptions(kernel_size)) {} + explicit AvgPoolImpl(const AvgPoolOptions& options_); + + void reset() override; + + /// Pretty prints the `AvgPool{1,2,3}d` module into the given `stream`. + void pretty_print(std::ostream& stream) const override; + + /// The options with which this `Module` was constructed. + AvgPoolOptions options; +}; + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ AvgPool1d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Applies avgpool over a 1-D input. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.AvgPool1d to learn +/// about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::AvgPool1dOptions` class to learn what +/// constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// AvgPool1d model(AvgPool1dOptions(3).stride(2)); +/// ``` +class TORCH_API AvgPool1dImpl : public AvgPoolImpl<1, AvgPool1dImpl> { + public: + using AvgPoolImpl<1, AvgPool1dImpl>::AvgPoolImpl; + Tensor forward(const Tensor& input); +}; + +/// A `ModuleHolder` subclass for `AvgPool1dImpl`. 
+/// See the documentation for `AvgPool1dImpl` class to learn what methods it +/// provides, and examples of how to use `AvgPool1d` with +/// `torch::nn::AvgPool1dOptions`. See the documentation for `ModuleHolder` to +/// learn about PyTorch's module storage semantics. +TORCH_MODULE(AvgPool1d); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ AvgPool2d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Applies avgpool over a 2-D input. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.AvgPool2d to learn +/// about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::AvgPool2dOptions` class to learn what +/// constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// AvgPool2d model(AvgPool2dOptions({3, 2}).stride({2, 2})); +/// ``` +class TORCH_API AvgPool2dImpl : public AvgPoolImpl<2, AvgPool2dImpl> { + public: + using AvgPoolImpl<2, AvgPool2dImpl>::AvgPoolImpl; + Tensor forward(const Tensor& input); +}; + +/// A `ModuleHolder` subclass for `AvgPool2dImpl`. +/// See the documentation for `AvgPool2dImpl` class to learn what methods it +/// provides, and examples of how to use `AvgPool2d` with +/// `torch::nn::AvgPool2dOptions`. See the documentation for `ModuleHolder` to +/// learn about PyTorch's module storage semantics. +TORCH_MODULE(AvgPool2d); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ AvgPool3d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Applies avgpool over a 3-D input. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.AvgPool3d to learn +/// about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::AvgPool3dOptions` class to learn what +/// constructor arguments are supported for this module. 
+/// +/// Example: +/// ``` +/// AvgPool3d model(AvgPool3dOptions(5).stride(2)); +/// ``` +class TORCH_API AvgPool3dImpl : public AvgPoolImpl<3, AvgPool3dImpl> { + public: + using AvgPoolImpl<3, AvgPool3dImpl>::AvgPoolImpl; + Tensor forward(const Tensor& input); +}; + +/// A `ModuleHolder` subclass for `AvgPool3dImpl`. +/// See the documentation for `AvgPool3dImpl` class to learn what methods it +/// provides, and examples of how to use `AvgPool3d` with +/// `torch::nn::AvgPool3dOptions`. See the documentation for `ModuleHolder` to +/// learn about PyTorch's module storage semantics. +TORCH_MODULE(AvgPool3d); + +// ============================================================================ + +/// Base class for all (dimension-specialized) maxpool modules. +template +class TORCH_API MaxPoolImpl : public torch::nn::Cloneable { + public: + MaxPoolImpl(ExpandingArray kernel_size) + : MaxPoolImpl(MaxPoolOptions(kernel_size)) {} + explicit MaxPoolImpl(const MaxPoolOptions& options_); + + void reset() override; + + /// Pretty prints the `MaxPool{1,2,3}d` module into the given `stream`. + void pretty_print(std::ostream& stream) const override; + + /// The options with which this `Module` was constructed. + MaxPoolOptions options; +}; + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MaxPool1d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Applies maxpool over a 1-D input. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.MaxPool1d to learn +/// about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::MaxPool1dOptions` class to learn what +/// constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// MaxPool1d model(MaxPool1dOptions(3).stride(2)); +/// ``` +class TORCH_API MaxPool1dImpl : public MaxPoolImpl<1, MaxPool1dImpl> { + public: + using MaxPoolImpl<1, MaxPool1dImpl>::MaxPoolImpl; + Tensor forward(const Tensor& input); + + /// Returns the outputs and the indices of the max values. 
+ /// Useful for `torch::nn::MaxUnpool1d` later. + std::tuple forward_with_indices(const Tensor& input); +}; + +/// A `ModuleHolder` subclass for `MaxPool1dImpl`. +/// See the documentation for `MaxPool1dImpl` class to learn what methods it +/// provides, and examples of how to use `MaxPool1d` with +/// `torch::nn::MaxPool1dOptions`. See the documentation for `ModuleHolder` to +/// learn about PyTorch's module storage semantics. +TORCH_MODULE(MaxPool1d); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MaxPool2d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Applies maxpool over a 2-D input. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.MaxPool2d to learn +/// about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::MaxPool2dOptions` class to learn what +/// constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// MaxPool2d model(MaxPool2dOptions({3, 2}).stride({2, 2})); +/// ``` +class TORCH_API MaxPool2dImpl : public MaxPoolImpl<2, MaxPool2dImpl> { + public: + using MaxPoolImpl<2, MaxPool2dImpl>::MaxPoolImpl; + Tensor forward(const Tensor& input); + + /// Returns the outputs and the indices of the max values. + /// Useful for `torch::nn::MaxUnpool2d` later. + std::tuple forward_with_indices(const Tensor& input); +}; + +/// A `ModuleHolder` subclass for `MaxPool2dImpl`. +/// See the documentation for `MaxPool2dImpl` class to learn what methods it +/// provides, and examples of how to use `MaxPool2d` with +/// `torch::nn::MaxPool2dOptions`. See the documentation for `ModuleHolder` to +/// learn about PyTorch's module storage semantics. +TORCH_MODULE(MaxPool2d); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MaxPool3d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Applies maxpool over a 3-D input. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.MaxPool3d to learn +/// about the exact behavior of this module. 
+///
+/// See the documentation for `torch::nn::MaxPool3dOptions` class to learn what
+/// constructor arguments are supported for this module.
+///
+/// Example:
+/// ```
+/// MaxPool3d model(MaxPool3dOptions(3).stride(2));
+/// ```
+class TORCH_API MaxPool3dImpl : public MaxPoolImpl<3, MaxPool3dImpl> {
+ public:
+  using MaxPoolImpl<3, MaxPool3dImpl>::MaxPoolImpl;
+  Tensor forward(const Tensor& input);
+
+  /// Returns the outputs and the indices of the max values.
+  /// Useful for `torch::nn::MaxUnpool3d` later.
+  std::tuple<Tensor, Tensor> forward_with_indices(const Tensor& input);
+};
+
+/// A `ModuleHolder` subclass for `MaxPool3dImpl`.
+/// See the documentation for `MaxPool3dImpl` class to learn what methods it
+/// provides, and examples of how to use `MaxPool3d` with
+/// `torch::nn::MaxPool3dOptions`. See the documentation for `ModuleHolder` to
+/// learn about PyTorch's module storage semantics.
+TORCH_MODULE(MaxPool3d);
+
+// ============================================================================
+
+/// Base class for all (dimension-specialized) adaptive maxpool modules.
+/// `D` is the spatial dimensionality, `output_size_t` the expanding-array
+/// type of the target output size, and `Derived` the concrete subclass
+/// (CRTP, required by `Cloneable`).
+template <size_t D, typename output_size_t, typename Derived>
+class TORCH_API AdaptiveMaxPoolImpl : public torch::nn::Cloneable<Derived> {
+ public:
+  AdaptiveMaxPoolImpl(output_size_t output_size)
+      : AdaptiveMaxPoolImpl(
+            AdaptiveMaxPoolOptions<output_size_t>(output_size)) {}
+  explicit AdaptiveMaxPoolImpl(
+      const AdaptiveMaxPoolOptions<output_size_t>& options_)
+      : options(options_) {}
+
+  // No parameters or buffers to (re)initialize; stray trailing semicolon
+  // removed (-Wextra-semi), matching the sibling AdaptiveAvgPoolImpl.
+  void reset() override {}
+
+  /// Pretty prints the `AdaptiveMaxPool{1,2,3}d` module into the given
+  /// `stream`.
+  void pretty_print(std::ostream& stream) const override {
+    stream << "torch::nn::AdaptiveMaxPool" << D << "d"
+           << "(output_size=" << options.output_size() << ")";
+  }
+
+  /// The options with which this `Module` was constructed.
+  AdaptiveMaxPoolOptions<output_size_t> options;
+};
+
+// ~~~~~~~~~~~~~~~~~~~~~~~~~~~ AdaptiveMaxPool1d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+/// Applies adaptive maxpool over a 1-D input.
+/// See https://pytorch.org/docs/main/nn.html#torch.nn.AdaptiveMaxPool1d to +/// learn about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::AdaptiveMaxPool1dOptions` class to +/// learn what constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// AdaptiveMaxPool1d model(AdaptiveMaxPool1dOptions(3)); +/// ``` +class TORCH_API AdaptiveMaxPool1dImpl + : public AdaptiveMaxPoolImpl<1, ExpandingArray<1>, AdaptiveMaxPool1dImpl> { + public: + using AdaptiveMaxPoolImpl<1, ExpandingArray<1>, AdaptiveMaxPool1dImpl>:: + AdaptiveMaxPoolImpl; + + Tensor forward(const Tensor& input); + + /// Returns the indices along with the outputs. + /// Useful to pass to nn.MaxUnpool1d. + std::tuple forward_with_indices(const Tensor& input); +}; + +/// A `ModuleHolder` subclass for `AdaptiveMaxPool1dImpl`. +/// See the documentation for `AdaptiveMaxPool1dImpl` class to learn what +/// methods it provides, and examples of how to use `AdaptiveMaxPool1d` with +/// `torch::nn::AdaptiveMaxPool1dOptions`. See the documentation for +/// `ModuleHolder` to learn about PyTorch's module storage semantics. +TORCH_MODULE(AdaptiveMaxPool1d); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ AdaptiveMaxPool2d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Applies adaptive maxpool over a 2-D input. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.AdaptiveMaxPool2d to +/// learn about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::AdaptiveMaxPool2dOptions` class to +/// learn what constructor arguments are supported for this module. 
+/// +/// Example: +/// ``` +/// AdaptiveMaxPool2d model(AdaptiveMaxPool2dOptions({3, 2})); +/// ``` +class TORCH_API AdaptiveMaxPool2dImpl : public AdaptiveMaxPoolImpl< + 2, + ExpandingArrayWithOptionalElem<2>, + AdaptiveMaxPool2dImpl> { + public: + using AdaptiveMaxPoolImpl< + 2, + ExpandingArrayWithOptionalElem<2>, + AdaptiveMaxPool2dImpl>::AdaptiveMaxPoolImpl; + + Tensor forward(const Tensor& input); + + /// Returns the indices along with the outputs. + /// Useful to pass to nn.MaxUnpool2d. + std::tuple forward_with_indices(const Tensor& input); +}; + +/// A `ModuleHolder` subclass for `AdaptiveMaxPool2dImpl`. +/// See the documentation for `AdaptiveMaxPool2dImpl` class to learn what +/// methods it provides, and examples of how to use `AdaptiveMaxPool2d` with +/// `torch::nn::AdaptiveMaxPool2dOptions`. See the documentation for +/// `ModuleHolder` to learn about PyTorch's module storage semantics. +TORCH_MODULE(AdaptiveMaxPool2d); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ AdaptiveMaxPool3d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Applies adaptive maxpool over a 3-D input. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.AdaptiveMaxPool3d to +/// learn about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::AdaptiveMaxPool3dOptions` class to +/// learn what constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// AdaptiveMaxPool3d model(AdaptiveMaxPool3dOptions(3)); +/// ``` +class TORCH_API AdaptiveMaxPool3dImpl : public AdaptiveMaxPoolImpl< + 3, + ExpandingArrayWithOptionalElem<3>, + AdaptiveMaxPool3dImpl> { + public: + using AdaptiveMaxPoolImpl< + 3, + ExpandingArrayWithOptionalElem<3>, + AdaptiveMaxPool3dImpl>::AdaptiveMaxPoolImpl; + + Tensor forward(const Tensor& input); + + /// Returns the indices along with the outputs. + /// Useful to pass to nn.MaxUnpool3d. + std::tuple forward_with_indices(const Tensor& input); +}; + +/// A `ModuleHolder` subclass for `AdaptiveMaxPool3dImpl`. 
+/// See the documentation for `AdaptiveMaxPool3dImpl` class to learn what +/// methods it provides, and examples of how to use `AdaptiveMaxPool3d` with +/// `torch::nn::AdaptiveMaxPool3dOptions`. See the documentation for +/// `ModuleHolder` to learn about PyTorch's module storage semantics. +TORCH_MODULE(AdaptiveMaxPool3d); + +// ============================================================================ + +/// Base class for all (dimension-specialized) adaptive avgpool modules. +template +class TORCH_API AdaptiveAvgPoolImpl : public torch::nn::Cloneable { + public: + AdaptiveAvgPoolImpl(output_size_t output_size) + : AdaptiveAvgPoolImpl( + AdaptiveAvgPoolOptions(output_size)) {} + explicit AdaptiveAvgPoolImpl( + const AdaptiveAvgPoolOptions& options_) + : options(options_) {} + + void reset() override {} + + /// Pretty prints the `AdaptiveAvgPool{1,2,3}d` module into the given + /// `stream`. + void pretty_print(std::ostream& stream) const override { + stream << "torch::nn::AdaptiveAvgPool" << D << "d" + << "(output_size=" << options.output_size() << ")"; + } + + /// The options with which this `Module` was constructed. + AdaptiveAvgPoolOptions options; +}; + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~ AdaptiveAvgPool1d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Applies adaptive avgpool over a 1-D input. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.AdaptiveAvgPool1d to +/// learn about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::AdaptiveAvgPool1dOptions` class to +/// learn what constructor arguments are supported for this module. 
+/// +/// Example: +/// ``` +/// AdaptiveAvgPool1d model(AdaptiveAvgPool1dOptions(5)); +/// ``` +class TORCH_API AdaptiveAvgPool1dImpl + : public AdaptiveAvgPoolImpl<1, ExpandingArray<1>, AdaptiveAvgPool1dImpl> { + public: + using AdaptiveAvgPoolImpl<1, ExpandingArray<1>, AdaptiveAvgPool1dImpl>:: + AdaptiveAvgPoolImpl; + + Tensor forward(const Tensor& input); +}; + +/// A `ModuleHolder` subclass for `AdaptiveAvgPool1dImpl`. +/// See the documentation for `AdaptiveAvgPool1dImpl` class to learn what +/// methods it provides, and examples of how to use `AdaptiveAvgPool1d` with +/// `torch::nn::AdaptiveAvgPool1dOptions`. See the documentation for +/// `ModuleHolder` to learn about PyTorch's module storage semantics. +TORCH_MODULE(AdaptiveAvgPool1d); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~ AdaptiveAvgPool2d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Applies adaptive avgpool over a 2-D input. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.AdaptiveAvgPool2d to +/// learn about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::AdaptiveAvgPool2dOptions` class to +/// learn what constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// AdaptiveAvgPool2d model(AdaptiveAvgPool2dOptions({3, 2})); +/// ``` +class TORCH_API AdaptiveAvgPool2dImpl : public AdaptiveAvgPoolImpl< + 2, + ExpandingArrayWithOptionalElem<2>, + AdaptiveAvgPool2dImpl> { + public: + using AdaptiveAvgPoolImpl< + 2, + ExpandingArrayWithOptionalElem<2>, + AdaptiveAvgPool2dImpl>::AdaptiveAvgPoolImpl; + + Tensor forward(const Tensor& input); +}; + +/// A `ModuleHolder` subclass for `AdaptiveAvgPool2dImpl`. +/// See the documentation for `AdaptiveAvgPool2dImpl` class to learn what +/// methods it provides, and examples of how to use `AdaptiveAvgPool2d` with +/// `torch::nn::AdaptiveAvgPool2dOptions`. See the documentation for +/// `ModuleHolder` to learn about PyTorch's module storage semantics. 
+TORCH_MODULE(AdaptiveAvgPool2d); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~ AdaptiveAvgPool3d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Applies adaptive avgpool over a 3-D input. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.AdaptiveAvgPool3d to +/// learn about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::AdaptiveAvgPool3dOptions` class to +/// learn what constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// AdaptiveAvgPool3d model(AdaptiveAvgPool3dOptions(3)); +/// ``` +class TORCH_API AdaptiveAvgPool3dImpl : public AdaptiveAvgPoolImpl< + 3, + ExpandingArrayWithOptionalElem<3>, + AdaptiveAvgPool3dImpl> { + public: + using AdaptiveAvgPoolImpl< + 3, + ExpandingArrayWithOptionalElem<3>, + AdaptiveAvgPool3dImpl>::AdaptiveAvgPoolImpl; + + Tensor forward(const Tensor& input); +}; + +/// A `ModuleHolder` subclass for `AdaptiveAvgPool3dImpl`. +/// See the documentation for `AdaptiveAvgPool3dImpl` class to learn what +/// methods it provides, and examples of how to use `AdaptiveAvgPool3d` with +/// `torch::nn::AdaptiveAvgPool3dOptions`. See the documentation for +/// `ModuleHolder` to learn about PyTorch's module storage semantics. +TORCH_MODULE(AdaptiveAvgPool3d); + +// ============================================================================ + +/// Base class for all (dimension-specialized) maxunpool modules. +template +class TORCH_API MaxUnpoolImpl : public torch::nn::Cloneable { + public: + MaxUnpoolImpl(ExpandingArray kernel_size) + : MaxUnpoolImpl(MaxUnpoolOptions(kernel_size)) {} + explicit MaxUnpoolImpl(const MaxUnpoolOptions& options_); + + void reset() override; + + /// Pretty prints the `MaxUnpool{1,2,3}d` module into the given `stream`. + void pretty_print(std::ostream& stream) const override; + + /// The options with which this `Module` was constructed. 
+ MaxUnpoolOptions options; +}; + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MaxUnpool1d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Applies maxunpool over a 1-D input. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.MaxUnpool1d to learn +/// about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::MaxUnpool1dOptions` class to learn +/// what constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// MaxUnpool1d model(MaxUnpool1dOptions(3).stride(2).padding(1)); +/// ``` +class TORCH_API MaxUnpool1dImpl : public MaxUnpoolImpl<1, MaxUnpool1dImpl> { + public: + using MaxUnpoolImpl<1, MaxUnpool1dImpl>::MaxUnpoolImpl; + Tensor forward( + const Tensor& input, + const Tensor& indices, + const std::optional>& output_size = std::nullopt); + + protected: + FORWARD_HAS_DEFAULT_ARGS({2, AnyValue(std::optional>())}) +}; + +/// A `ModuleHolder` subclass for `MaxUnpool1dImpl`. +/// See the documentation for `MaxUnpool1dImpl` class to learn what methods it +/// provides, and examples of how to use `MaxUnpool1d` with +/// `torch::nn::MaxUnpool1dOptions`. See the documentation for `ModuleHolder` to +/// learn about PyTorch's module storage semantics. +TORCH_MODULE(MaxUnpool1d); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MaxUnpool2d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Applies maxunpool over a 2-D input. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.MaxUnpool2d to learn +/// about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::MaxUnpool2dOptions` class to learn +/// what constructor arguments are supported for this module. 
+/// +/// Example: +/// ``` +/// MaxUnpool2d model(MaxUnpool2dOptions(3).stride(2).padding(1)); +/// ``` +class TORCH_API MaxUnpool2dImpl : public MaxUnpoolImpl<2, MaxUnpool2dImpl> { + public: + using MaxUnpoolImpl<2, MaxUnpool2dImpl>::MaxUnpoolImpl; + Tensor forward( + const Tensor& input, + const Tensor& indices, + const std::optional>& output_size = std::nullopt); + + protected: + FORWARD_HAS_DEFAULT_ARGS({2, AnyValue(std::optional>())}) +}; + +/// A `ModuleHolder` subclass for `MaxUnpool2dImpl`. +/// See the documentation for `MaxUnpool2dImpl` class to learn what methods it +/// provides, and examples of how to use `MaxUnpool2d` with +/// `torch::nn::MaxUnpool2dOptions`. See the documentation for `ModuleHolder` to +/// learn about PyTorch's module storage semantics. +TORCH_MODULE(MaxUnpool2d); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MaxUnpool3d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Applies maxunpool over a 3-D input. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.MaxUnpool3d to learn +/// about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::MaxUnpool3dOptions` class to learn +/// what constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// MaxUnpool3d model(MaxUnpool3dOptions(3).stride(2).padding(1)); +/// ``` +class TORCH_API MaxUnpool3dImpl : public MaxUnpoolImpl<3, MaxUnpool3dImpl> { + public: + using MaxUnpoolImpl<3, MaxUnpool3dImpl>::MaxUnpoolImpl; + Tensor forward( + const Tensor& input, + const Tensor& indices, + const std::optional>& output_size = std::nullopt); + + protected: + FORWARD_HAS_DEFAULT_ARGS({2, AnyValue(std::optional>())}) +}; + +/// A `ModuleHolder` subclass for `MaxUnpool3dImpl`. +/// See the documentation for `MaxUnpool3dImpl` class to learn what methods it +/// provides, and examples of how to use `MaxUnpool3d` with +/// `torch::nn::MaxUnpool3dOptions`. See the documentation for `ModuleHolder` to +/// learn about PyTorch's module storage semantics. 
+TORCH_MODULE(MaxUnpool3d); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ FractionalMaxPool2d +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Applies fractional maxpool over a 2-D input. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.FractionalMaxPool2d to +/// learn about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::FractionalMaxPool2dOptions` class to +/// learn what constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// FractionalMaxPool2d model(FractionalMaxPool2dOptions(5).output_size(1)); +/// ``` +class TORCH_API FractionalMaxPool2dImpl + : public torch::nn::Cloneable { + public: + FractionalMaxPool2dImpl(ExpandingArray<2> kernel_size) + : FractionalMaxPool2dImpl(FractionalMaxPool2dOptions(kernel_size)) {} + explicit FractionalMaxPool2dImpl(FractionalMaxPool2dOptions options_); + + void reset() override; + + /// Pretty prints the `FractionalMaxPool2d` module into the given `stream`. + void pretty_print(std::ostream& stream) const override; + + Tensor forward(const Tensor& input); + + /// Returns the outputs and the indices of the max values. + /// Useful for `torch::nn::MaxUnpool2d` later. + std::tuple forward_with_indices(const Tensor& input); + + /// The options with which this `Module` was constructed. + FractionalMaxPool2dOptions options; + + Tensor _random_samples; +}; + +/// A `ModuleHolder` subclass for `FractionalMaxPool2dImpl`. +/// See the documentation for `FractionalMaxPool2dImpl` class to learn what +/// methods it provides, and examples of how to use `FractionalMaxPool2d` with +/// `torch::nn::FractionalMaxPool2dOptions`. See the documentation for +/// `ModuleHolder` to learn about PyTorch's module storage semantics. +TORCH_MODULE(FractionalMaxPool2d); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ FractionalMaxPool3d +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Applies fractional maxpool over a 3-D input. 
+/// See https://pytorch.org/docs/main/nn.html#torch.nn.FractionalMaxPool3d to +/// learn about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::FractionalMaxPool3dOptions` class to +/// learn what constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// FractionalMaxPool3d model(FractionalMaxPool3dOptions(5).output_size(1)); +/// ``` +class TORCH_API FractionalMaxPool3dImpl + : public torch::nn::Cloneable { + public: + FractionalMaxPool3dImpl(ExpandingArray<3> kernel_size) + : FractionalMaxPool3dImpl(FractionalMaxPool3dOptions(kernel_size)) {} + explicit FractionalMaxPool3dImpl(FractionalMaxPool3dOptions options_); + + void reset() override; + + /// Pretty prints the `FractionalMaxPool3d` module into the given `stream`. + void pretty_print(std::ostream& stream) const override; + + Tensor forward(const Tensor& input); + + /// Returns the outputs and the indices of the max values. + /// Useful for `torch::nn::MaxUnpool3d` later. + std::tuple forward_with_indices(const Tensor& input); + + /// The options with which this `Module` was constructed. + FractionalMaxPool3dOptions options; + + Tensor _random_samples; +}; + +/// A `ModuleHolder` subclass for `FractionalMaxPool3dImpl`. +/// See the documentation for `FractionalMaxPool3dImpl` class to learn what +/// methods it provides, and examples of how to use `FractionalMaxPool3d` with +/// `torch::nn::FractionalMaxPool3dOptions`. See the documentation for +/// `ModuleHolder` to learn about PyTorch's module storage semantics. +TORCH_MODULE(FractionalMaxPool3d); + +// ============================================================================ + +/// Base class for all (dimension-specialized) lppool modules. 
+/// `D` is the spatial dimensionality and `Derived` the concrete subclass
+/// (CRTP, required by `Cloneable`).
+template <size_t D, typename Derived>
+class TORCH_API LPPoolImpl : public torch::nn::Cloneable<Derived> {
+ public:
+  LPPoolImpl(double norm_type, ExpandingArray<D> kernel_size)
+      : LPPoolImpl(LPPoolOptions<D>(norm_type, kernel_size)) {}
+  explicit LPPoolImpl(const LPPoolOptions<D>& options_);
+
+  void reset() override;
+
+  /// Pretty prints the `LPPool{1,2,3}d` module into the given `stream`.
+  void pretty_print(std::ostream& stream) const override;
+
+  /// The options with which this `Module` was constructed.
+  LPPoolOptions<D> options;
+};
+
+// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ LPPool1d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+/// Applies the LPPool1d function element-wise.
+/// See https://pytorch.org/docs/main/nn.html#torch.nn.LPPool1d to learn
+/// about the exact behavior of this module.
+///
+/// See the documentation for `torch::nn::LPPool1dOptions` class to learn what
+/// constructor arguments are supported for this module.
+///
+/// Example:
+/// ```
+/// LPPool1d model(LPPool1dOptions(1, 2).stride(5).ceil_mode(true));
+/// ```
+class TORCH_API LPPool1dImpl : public LPPoolImpl<1, LPPool1dImpl> {
+ public:
+  using LPPoolImpl<1, LPPool1dImpl>::LPPoolImpl;
+
+  Tensor forward(const Tensor& input);
+};
+
+/// A `ModuleHolder` subclass for `LPPool1dImpl`.
+/// See the documentation for `LPPool1dImpl` class to learn what methods it
+/// provides, and examples of how to use `LPPool1d` with
+/// `torch::nn::LPPool1dOptions`. See the documentation for `ModuleHolder` to
+/// learn about PyTorch's module storage semantics.
+TORCH_MODULE(LPPool1d);
+
+// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ LPPool2d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+/// Applies the LPPool2d function element-wise.
+/// See https://pytorch.org/docs/main/nn.html#torch.nn.LPPool2d to learn
+/// about the exact behavior of this module.
+///
+/// See the documentation for `torch::nn::LPPool2dOptions` class to learn what
+/// constructor arguments are supported for this module.
+/// +/// Example: +/// ``` +/// LPPool2d model(LPPool2dOptions(1, std::vector({3, 4})).stride({5, +/// 6}).ceil_mode(true)); +/// ``` +class TORCH_API LPPool2dImpl : public LPPoolImpl<2, LPPool2dImpl> { + public: + using LPPoolImpl<2, LPPool2dImpl>::LPPoolImpl; + + Tensor forward(const Tensor& input); +}; + +/// A `ModuleHolder` subclass for `LPPool2dImpl`. +/// See the documentation for `LPPool2dImpl` class to learn what methods it +/// provides, and examples of how to use `LPPool2d` with +/// `torch::nn::LPPool2dOptions`. See the documentation for `ModuleHolder` to +/// learn about PyTorch's module storage semantics. +TORCH_MODULE(LPPool2d); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ LPPool3d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Applies the LPPool3d function element-wise. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.LPPool3d to learn +/// about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::LPPool3dOptions` class to learn what +/// constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// LPPool3d model(LPPool3dOptions(1, std::vector({3, 4, 5})).stride( +/// {5, 6, 7}).ceil_mode(true)); +/// ``` +class TORCH_API LPPool3dImpl : public LPPoolImpl<3, LPPool3dImpl> { + public: + using LPPoolImpl<3, LPPool3dImpl>::LPPoolImpl; + + Tensor forward(const Tensor& input); +}; + +/// A `ModuleHolder` subclass for `LPPool3dImpl`. +/// See the documentation for `LPPool3dImpl` class to learn what methods it +/// provides, and examples of how to use `LPPool3d` with +/// `torch::nn::LPPool3dOptions`. See the documentation for `ModuleHolder` to +/// learn about PyTorch's module storage semantics. 
+TORCH_MODULE(LPPool3d); + +} // namespace nn +} // namespace torch diff --git a/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/upsampling.h b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/upsampling.h new file mode 100644 index 0000000000000000000000000000000000000000..8520bf632f83e6f42228de2152efd90e458845bc --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/upsampling.h @@ -0,0 +1,55 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include + +#include +#include + +namespace torch { +namespace nn { + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Upsample ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Upsamples a given multi-channel 1D (temporal), 2D (spatial) or 3D +/// (volumetric) data. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.Upsample to learn +/// about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::UpsampleOptions` class to learn what +/// constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// Upsample +/// model(UpsampleOptions().scale_factor({3}).mode(torch::kLinear).align_corners(false)); +/// ``` +class TORCH_API UpsampleImpl : public Cloneable { + public: + explicit UpsampleImpl(const UpsampleOptions& options_ = {}); + + void reset() override; + + /// Pretty prints the `Upsample` module into the given `stream`. + void pretty_print(std::ostream& stream) const override; + + Tensor forward(const Tensor& input); + + /// The options with which this `Module` was constructed. + UpsampleOptions options; +}; + +/// A `ModuleHolder` subclass for `UpsampleImpl`. +/// See the documentation for `UpsampleImpl` class to learn what methods it +/// provides, and examples of how to use `Upsample` with +/// `torch::nn::UpsampleOptions`. See the documentation for `ModuleHolder` to +/// learn about PyTorch's module storage semantics. 
+TORCH_MODULE(Upsample); + +} // namespace nn +} // namespace torch diff --git a/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/nn/parallel/data_parallel.h b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/nn/parallel/data_parallel.h new file mode 100644 index 0000000000000000000000000000000000000000..22f8f678a8e74653f22dc70e6fa1c8ae6bc3b3ee --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/nn/parallel/data_parallel.h @@ -0,0 +1,297 @@ +#pragma once + +#include +#include +#include +#include + +#include +#include +#include + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +namespace torch { +namespace nn { + +namespace { + +// Note [Replicating Modules] +// ~~~~~~~~~~~~~~~~~~~~~~~~~~ +// +// Module replication is implemented in the following two steps: +// 1) create a module replica on each destination device using Module.clone(). +// 2) manually add a gradient edge pointing from every parameter X in every +// module replica to the same parameter X in the original module, using +// ReduceAdd as the grad_fn. +// +// ReduceAdd can ONLY be used during the backward pass of data parallel. Forward +// pass cannot use this function as it does not setup gradient function and +// history at all. Do NOT try to use ReduceAdd for any other purposes. +// +// NB: An alternative is to add Broadcast and ReduceAddCoalesce to +// torch/csrc/autograd/functions/comm.cpp as normal autograd functions, +// implement a Replicatable (like cloneable) class and add it as a friend class +// in Module.h. In the forward pass, the Replicatable could use the Broadcast +// function to replicate every module parameter and set gradient functions using +// ReduceAddCoalesce (like how it is implemented in Python). 
However, unlike in +// Python, where changes to Linear._parameters["weight"] would also apply to +// Linear.weight (using Linear as an example), Linear.weight and +// Linear.parameters_["weight"] are two tensor objects pointing to the same +// TensorImpl. Assigning a new tensor to Linear.parameters_["weight"] will not +// change Linear.weight. To make this work, we will have to: +// 1) force every module to also inherit from Replicatable +// 2) force every module to implement an additional function, e.g., +// Replicatable::load_params(), to pick up changes from parameters_ to their +// own member fields. +// This will be an overkill as Replicatable will only be used in data_parallel, +// not even ddp. + +// Autograd function for the replicate step in data parallel. This is only used +// in data parallel, and should not be exposed as a user API. +struct ReduceAdd : public autograd::Node { + explicit ReduceAdd(const at::Device& destination_device) + : destination_device_(destination_device){}; + ~ReduceAdd() override {} + + autograd::variable_list apply(autograd::variable_list&& inputs) override { + TORCH_CHECK( + !torch::autograd::compute_requires_grad(inputs), + "ReduceAdd can only be used during the backward pass of data parallel."); + + Tensor output = torch::zeros_like(inputs[0], {destination_device_}); + + for (auto& input : inputs) { + TORCH_CHECK( + input.sizes() == inputs[0].sizes(), + "All inputs of ReduceAdd must have the same size, but got ", + input.sizes(), + " and ", + inputs[0].sizes()); + + TORCH_CHECK( + input.dtype() == inputs[0].dtype(), + "All inputs of ReduceAdd must have the same dtype, but got ", + input.dtype(), + " and ", + inputs[0].dtype()); + + // TODO: use nccl reduce + output.add_(input.to(destination_device_)); + } + + return {output}; + } + + private: + at::Device destination_device_; +}; + +} // namespace + +// A friend function to Module, it recursively sets gradient edges pointing from +// every parameter X in every module replica 
to the same parameter X in the +// original module. See [Replicating Modules] +template +void replicate_grad_edges( + const std::shared_ptr& module, + const std::vector>& replicas, + const std::vector& devices) { + for (auto& parameter : module->named_parameters(/*recurse=*/false)) { + auto grad_fn = std::make_shared((*parameter).device()); + grad_fn->set_next_edges(autograd::collect_next_edges(*parameter)); + + for (const auto i : c10::irange(devices.size())) { + autograd::set_history(replicas[i]->parameters_[parameter.key()], grad_fn); + } + } + + for (auto& buffer : module->named_buffers(/*recurse=*/false)) { + if (buffer.value().requires_grad()) { + auto grad_fn = std::make_shared((*buffer).device()); + grad_fn->set_next_edges(autograd::collect_next_edges(*buffer)); + + for (const auto i : c10::irange(devices.size())) { + autograd::set_history(replicas[i]->buffers_[buffer.key()], grad_fn); + } + } + } + + for (auto& child : module->children_) { + std::vector> child_replicas; + child_replicas.reserve(devices.size()); + for (auto& replica : replicas) { + child_replicas.push_back(replica->children_[child.key()]); + } + + // recursively set gradient edges for all children + replicate_grad_edges(*child, child_replicas, devices); + } +} + +namespace parallel { + +/// Replicates a module on the given list of devices. +/// A replica is created by calling `clone()` on the module. For this, the +/// module must inherit from `nn::Cloneable`, or define its own `clone()` +/// method, which is expected to perform a deep copy of the module. +template +std::vector> replicate( + const std::shared_ptr& module, + const std::vector& devices) { + std::vector> replicas; + replicas.reserve(devices.size()); + for (const auto& device : devices) { + replicas.push_back( + std::dynamic_pointer_cast(module->clone(device))); + } + // Configure gradient edges to point from replcia parameters to original + // module parameters. 
See [Replicating Modules] + replicate_grad_edges(module, replicas, devices); + return replicas; +} + +/// Replicates a module holder on the given list of devices. +/// This method allows calling `replicate()` with a module holder, such as +/// `Linear`. +template +std::vector> replicate( + const ModuleHolder& module, + const std::vector& devices) { + auto ptrs = replicate(module.ptr(), devices); + return std::vector>(ptrs.begin(), ptrs.end()); +} + +/// Applies the given inputs to the given modules in a parallel fashion. +/// Conceptually, a thread is spawned for each `(module, input)` pair, in which +/// `forward()` is called on the module with its corresponding input. The +/// outputs of the individual calls are stored in a vector and returned. +/// +/// The first exception caught by any thread is stashed and rethrown after all +/// threads have completed their operation. +/// +/// Further remarks: +/// 1. The length of the module container must match the length of the inputs. +/// 2. If a list of devices is supplied, it must match the list of modules in +/// length. Each device will be set to the current default device during the +/// invocation of the respective module. This means any tensors allocated on the +/// default device inside the module will be constructed on this device. +template +std::vector parallel_apply( + std::vector& modules, + const std::vector& inputs, + const std::optional>& devices = std::nullopt) { + TORCH_CHECK( + modules.size() == inputs.size(), "Must have as many inputs as modules"); + if (devices) { + TORCH_CHECK( + modules.size() == devices->size(), + "Must have as many devices as modules"); + } + + std::vector outputs(modules.size()); + std::mutex mutex; + + // std::exception_ptr can be passed between threads: + // > An instance of std::exception_ptr may be passed to another function, + // > possibly on another thread, where the exception may be rethrown [...]. 
+ // https://en.cppreference.com/w/cpp/error/exception_ptr + std::exception_ptr exception; + + at::parallel_for( + /*begin=*/0, + /*end=*/modules.size(), + /*grain_size=*/1, + [&modules, &inputs, &devices, &outputs, &mutex, &exception]( + int64_t index, int64_t stop) { + for (; index < stop; ++index) { + try { + auto output = modules[index]->forward(inputs[index]); + output = + output.to(devices ? (*devices)[index] : inputs[index].device()); + std::lock_guard lock(mutex); + outputs[index] = output; + } catch (...) { + std::lock_guard lock(mutex); + if (!exception) { + exception = std::current_exception(); + } + } + } + }); + + if (exception) { + std::rethrow_exception(exception); + } + + return outputs; +} + +/// Evaluates `module(input)` in parallel across the given `devices`. If +/// `devices` is not supplied, the invocation is parallelized across all +/// available CUDA devices. If `output_device` is supplied, the final, combined +/// tensor will be placed on this device. If not, it defaults to the first +/// device in `devices`. +/// +/// In detail, this method performs the following four distinct steps: +/// 1. *Scatter* the input to the given devices, +/// 2. *Replicate* (deep clone) the model on each device, +/// 3. *Evaluate* each module with its input on its device, +/// 4. *Gather* the outputs of each replica into a single output tensor, located +/// on the `output_device`. 
+template +Tensor data_parallel( + ModuleType module, + Tensor input, + std::optional> devices = std::nullopt, + std::optional output_device = std::nullopt, + int64_t dim = 0) { + if (!devices) { + const auto device_count = torch::cuda::device_count(); + TORCH_CHECK( + device_count > 0, "Expected at least one CUDA device to be available"); + devices = std::vector(); + devices->reserve(device_count); + for (const auto index : c10::irange(device_count)) { + devices->emplace_back(kCUDA, static_cast(index)); + } + } + if (!output_device) { + output_device = devices->front(); + } + + if (devices->size() == 1) { + module->to(devices->front()); + input = input.to(devices->front()); + return module->forward(std::move(input)).to(*output_device); + } + + autograd::Scatter scatter(*devices, /*chunk_sizes=*/nullopt, dim); + auto scattered_inputs = fmap(scatter.apply({std::move(input)})); + // Input tensor might not be big enough to scale across all available devices + if (scattered_inputs.size() < devices->size()) { + devices->resize( + scattered_inputs.size(), + Device(DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES)); + } + + auto replicas = replicate(module, *devices); + auto outputs = parallel_apply(replicas, scattered_inputs, *devices); + return autograd::Gather(*output_device, dim) + .apply(fmap(std::move(outputs))) + .front(); +} + +} // namespace parallel +} // namespace nn +} // namespace torch diff --git a/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/nn/utils/clip_grad.h b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/nn/utils/clip_grad.h new file mode 100644 index 0000000000000000000000000000000000000000..8a2a569c03335cf60e3be785a5bebb4c821d237b --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/nn/utils/clip_grad.h @@ -0,0 +1,147 @@ +#pragma once + +#include + +#include + +namespace torch { +namespace nn { +namespace utils { + +// Clips gradient norm of a 
vector of Tensors. +// See +// https://pytorch.org/docs/stable/nn.html?highlight=clip_grad_norm#torch.nn.utils.clip_grad_norm_ +// for more details about this module. +// +// Difference with the python version: unlike the python version, even when +// skipping the finiteness checks (error_if_nonfinite = false), this function +// will introduce a device <=> CPU synchronization (for devices where that makes +// sense!) in order to return a CPU-side `double`. This C++ version therefore +// cannot be run fully asynchronously w.r.t. the device of the gradients. +inline double clip_grad_norm_( + const std::vector& parameters, + double max_norm, + double norm_type = 2.0, + bool error_if_nonfinite = false) { + std::vector params_with_grad; + + for (const auto& param : parameters) { + auto& grad = param.grad(); + if (grad.defined()) { + params_with_grad.push_back(param); + } + } + + if (params_with_grad.empty()) { + return 0.0; + } + + Tensor total_norm_tensor; + if (norm_type == std::numeric_limits::infinity()) { + std::vector norms; + norms.reserve(params_with_grad.size()); + + for (const auto& param : params_with_grad) { + norms.emplace_back(param.grad().data().abs().max()); + } + total_norm_tensor = + (norms.size() == 1) ? norms[0] : torch::max(torch::stack(norms)); + } else if (norm_type == 0) { + total_norm_tensor = + torch::full({}, static_cast(params_with_grad.size())); + } else { + std::vector norms; + norms.reserve(params_with_grad.size()); + + for (const auto& param : params_with_grad) { + norms.emplace_back(param.grad().data().norm(norm_type)); + } + total_norm_tensor = + (norms.size() == 1) ? norms[0] : torch::stack(norms).norm(norm_type); + } + + // When possible (ie when skipping the finiteness check), we avoid + // synchronizing the CPU and the gradients' device until the very end to + // preserve async execution on the device. When checking for finite-ness, this + // optional ensures we only sync once. 
+ std::optional total_norm = std::nullopt; + if (error_if_nonfinite) { + total_norm = total_norm_tensor.item().toDouble(); + TORCH_CHECK( + std::isfinite(*total_norm), + "The total norm of order ", + norm_type, + " for gradients from `parameters` ", + "is non-finite, so it cannot be clipped. To disable this error and scale ", + "the gradients with the non-finite norm anyway, set ", + "`error_if_nonfinite=false`"); + } + + auto clip_coef = max_norm / (total_norm_tensor + 1e-6); + auto clip_coef_clamped = + torch::clamp(clip_coef, std::nullopt /* min */, 1.0 /* max */); + for (auto& param : params_with_grad) { + param.grad().data().mul_(clip_coef_clamped); + } + + if (!total_norm.has_value()) { + total_norm = total_norm_tensor.item().toDouble(); + } + return *total_norm; +} + +// A wrapper around clip_grad_norm_ that allows us to call the function with a +// braced-init-list of Tensors. +inline double clip_grad_norm_( + std::initializer_list parameters, + double max_norm, + double norm_type = 2.0, + bool error_if_nonfinite = false) { + return clip_grad_norm_( + std::vector(parameters), max_norm, norm_type, error_if_nonfinite); +} + +// A wrapper around clip_grad_norm_ that allows us to call the function with a +// single Tensor. +inline double clip_grad_norm_( + Tensor parameter, + double max_norm, + double norm_type = 2.0, + bool error_if_nonfinite = false) { + std::vector params = {std::move(parameter)}; + return clip_grad_norm_( + std::move(params), max_norm, norm_type, error_if_nonfinite); +} + +// Clips gradient of an iterable of parameters at specified value. +// Gradients are modified in-place. +// See https://pytorch.org/docs/stable/nn.html#clip-grad-value +// for more details about this module. 
+inline void clip_grad_value_( + const std::vector& parameters, + double clip_value) { + for (const auto& param : parameters) { + if (param.grad().defined()) { + param.grad().data().clamp_(-clip_value, clip_value); + } + } +} + +// A wrapper around clip_grad_value_ that allows us to call the function with a +// braced-init-list of Tensors. +inline void clip_grad_value_( + std::initializer_list parameters, + double clip_value) { + clip_grad_value_(std::vector(parameters), clip_value); +} + +// A wrapper around clip_grad_value_ that allows us to call the function with a +// single Tensor. +inline void clip_grad_value_(Tensor parameter, double clip_value) { + std::vector params = {std::move(parameter)}; + clip_grad_value_(std::move(params), clip_value); +} + +} // namespace utils +} // namespace nn +} // namespace torch diff --git a/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/nn/utils/convert_parameters.h b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/nn/utils/convert_parameters.h new file mode 100644 index 0000000000000000000000000000000000000000..b8bfee33473f2a4ee6cb4acf45d5940b7f06850d --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/nn/utils/convert_parameters.h @@ -0,0 +1,82 @@ +#pragma once + +#include +#include + +namespace torch { +namespace nn { +namespace utils { + +// This helper function is to check if the parameters are located +// in the same device. Currently, the conversion between model parameters +// and single vector form is not supported for multiple allocations, +// e.g. parameters in different GPUs, or mixture of CPU/GPU. +inline std::optional _check_param_device( + const torch::Tensor& param, + std::optional old_param_device) { + // Meet the first parameter + if (old_param_device == std::nullopt) { + old_param_device = param.is_cuda() ? 
param.get_device() : -1; + } else { + bool warn = false; + if (param.is_cuda()) { // Check if in same GPU + warn = (param.get_device() != old_param_device.value()); + } else { // Check if in CPU + warn = (old_param_device.value() != -1); + } + if (warn) { + TORCH_CHECK( + false, + "Found two parameters on different devices, ", + "this is currently not supported."); + } + } + + return old_param_device; +} + +// Convert parameters to one vector +inline torch::Tensor parameters_to_vector( + const std::vector& parameters) { + std::optional param_device; + + std::vector vec; + vec.reserve(parameters.size()); + + for (const torch::Tensor& param : parameters) { + // Ensure the parameters are located in the same device + param_device = _check_param_device(param, param_device); + + vec.push_back(param.view(-1)); + } + + return torch::cat(vec); +} + +// Convert one vector to the parameters +inline void vector_to_parameters( + const torch::Tensor& vec, + const std::vector& parameters) { + // Flag for the device where the parameter is located + std::optional param_device; + + // Pointer for slicing the vector for each parameter + int64_t pointer = 0; + for (const torch::Tensor& param : parameters) { + // Ensure the parameters are located in the same device + param_device = _check_param_device(param, param_device); + + // The length of the parameter + auto num_param = param.numel(); + // Slice the vector, reshape it, and replace the old data of the parameter + param.set_data( + vec.slice(0, pointer, pointer + num_param).view_as(param).data()); + + // Increment the pointer + pointer += num_param; + } +} + +} // namespace utils +} // namespace nn +} // namespace torch diff --git a/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/nn/utils/rnn.h b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/nn/utils/rnn.h new file mode 100644 index 0000000000000000000000000000000000000000..6f2a68984c80ac9c6355664bd86d624f00cd3203 --- 
/dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/nn/utils/rnn.h @@ -0,0 +1,354 @@ +#pragma once + +#include +#include + +#include + +namespace torch { +namespace nn { +namespace utils { +namespace rnn { + +inline Tensor invert_permutation(const Tensor& permutation) { + if (!permutation.defined()) { + return torch::Tensor(); + } + Tensor output = + torch::empty_like(permutation, torch::MemoryFormat::Contiguous); + output.scatter_( + 0, + permutation, + torch::arange(0, permutation.numel(), permutation.device())); + return output; +} + +/// Holds the data and list of `batch_sizes` of a packed sequence. +/// +/// All RNN modules accept packed sequences as inputs. +/// +/// Note: +/// Instances of this class should never be created manually. They are meant +/// to be instantiated by functions like `pack_padded_sequence`. +/// +/// Batch sizes represent the number elements at each sequence step in +/// the batch, not the varying sequence lengths passed to +/// `pack_padded_sequence`. For instance, given data ``abc`` and ``x`` +/// the :class:`PackedSequence` would contain data ``axbc`` with +/// ``batch_sizes=[2,1,1]``. +/// +/// Attributes: +/// data (Tensor): Tensor containing packed sequence +/// batch_sizes (Tensor): Tensor of integers holding +/// information about the batch size at each sequence step +/// sorted_indices (Tensor, optional): Tensor of integers holding how this +/// :class:`PackedSequence` is constructed from sequences. +/// unsorted_indices (Tensor, optional): Tensor of integers holding how this +/// to recover the original sequences with correct order. +/// +/// .. note:: +/// `data` can be on arbitrary device and of arbitrary dtype. +/// `sorted_indices` and `unsorted_indices` must be ``torch::kInt64`` +/// tensors on the same device as `data`. +/// +/// However, `batch_sizes` should always be a CPU ``torch::kInt64`` tensor. 
+/// +/// This invariant is maintained throughout `PackedSequence` class, +/// and all functions that construct a `PackedSequence` in libtorch +/// (i.e., they only pass in tensors conforming to this constraint). +class PackedSequence { + public: + explicit PackedSequence( + Tensor data, + Tensor batch_sizes, + Tensor sorted_indices = {}, + Tensor unsorted_indices = {}) { + // NB: if unsorted_indices is provided, it should be the inverse permutation + // to sorted_indices. Don't assert it here because the PackedSequence ctor + // should only be used internally. + if (!unsorted_indices.defined()) { + unsorted_indices = invert_permutation(sorted_indices); + } + TORCH_CHECK( + batch_sizes.device().type() == kCPU, + "batch_sizes should always be on CPU. " + "Instances of PackedSequence should never be created manually. " + "They should be instantiated by functions like pack_sequence " + "and pack_padded_sequences in nn::utils::rnn. " + "https://pytorch.org/docs/stable/nn.html#torch.nn.utils.rnn.pack_sequence"); + data_ = std::move(data); + batch_sizes_ = std::move(batch_sizes); + sorted_indices_ = std::move(sorted_indices); + unsorted_indices_ = std::move(unsorted_indices); + } + + const Tensor& data() const { + return data_; + } + + const Tensor& batch_sizes() const { + return batch_sizes_; + } + + const Tensor& sorted_indices() const { + return sorted_indices_; + } + + const Tensor& unsorted_indices() const { + return unsorted_indices_; + } + + PackedSequence pin_memory() const { + // Why not convert `batch_sizes`? + // See NOTE [ device and dtype of a PackedSequence ] + return PackedSequence( + data_.pin_memory(), + batch_sizes_, + sorted_indices_.defined() ? sorted_indices_.pin_memory() : Tensor(), + unsorted_indices_.defined() ? unsorted_indices_.pin_memory() + : Tensor()); + } + + PackedSequence to(TensorOptions options) const { + // Performs dtype and/or device conversion on `data_`. 
+ // + // If the ``data_`` Tensor already has the correct `torch::Dtype` + // and `torch::Device`, then ``self`` is returned. + // Otherwise, returns a copy with the desired configuration. + + // Why not convert `batch_sizes`? + // See NOTE [ device and dtype of a PackedSequence ] + Tensor data = data_.to(options); + if (data.is_same(data_)) { + return *this; + } else { + // Does not forward device or dtype args, device is set from data.device() + Tensor sorted_indices = sorted_indices_.defined() + ? sorted_indices_.to( + options.device(data.device()).dtype(sorted_indices_.dtype())) + : Tensor(); + Tensor unsorted_indices = unsorted_indices_.defined() + ? unsorted_indices_.to( + options.device(data.device()).dtype(unsorted_indices_.dtype())) + : Tensor(); + return PackedSequence( + std::move(data), + batch_sizes_, + std::move(sorted_indices), + std::move(unsorted_indices)); + } + } + + PackedSequence cuda() const { + return to(kCUDA); + } + + PackedSequence cpu() const { + return to(kCPU); + } + + /// Returns true if `data_` stored on a gpu + bool is_cuda() const { + return data_.is_cuda(); + } + + /// Returns true if `data_` stored on in pinned memory + bool is_pinned() const { + return data_.is_pinned(); + } + + private: + Tensor data_; + Tensor batch_sizes_; + Tensor sorted_indices_; + Tensor unsorted_indices_; +}; + +/// Packs a Tensor containing padded sequences of variable length. +/// +/// `input` can be of size ``T x B x *`` where `T` is the length of the +/// longest sequence (equal to ``lengths[0]``), ``B`` is the batch size, and +/// ``*`` is any number of dimensions (including 0). If ``batch_first`` is +/// ``true``, ``B x T x *`` `input` is expected. +/// +/// For unsorted sequences, use `enforce_sorted = false`. If `enforce_sorted` is +/// ``true``, the sequences should be sorted by length in a decreasing order, +/// i.e. +/// ``input[:,0]`` should be the longest sequence, and ``input[:,B-1]`` the +/// shortest one. 
+/// +/// Note: +/// This function accepts any input that has at least two dimensions. You +/// can apply it to pack the labels, and use the output of the RNN with +/// them to compute the loss directly. A Tensor can be retrieved from +/// a `PackedSequence` object by calling its ``.data()`` function. +/// +/// Arguments: +/// input (Tensor): padded batch of variable length sequences. +/// lengths (Tensor): list of sequences lengths of each batch element. +/// batch_first (bool, optional): if ``true``, the input is expected in ``B +/// x T x *`` +/// format. Default: ``false``. +/// enforce_sorted (bool, optional): if ``true``, the input is expected to +/// contain sequences sorted by length in a decreasing order. If +/// ``false``, this condition is not checked. Default: ``true``. +/// +/// Returns: +/// a `PackedSequence` object +inline PackedSequence pack_padded_sequence( + Tensor input, + Tensor lengths, + bool batch_first = false, + bool enforce_sorted = true) { + lengths = lengths.to(kInt64); + Tensor sorted_indices; + if (enforce_sorted) { + sorted_indices = Tensor(); + } else { + std::tie(lengths, sorted_indices) = + torch::sort(lengths, /*dim=*/-1, /*descending=*/true); + sorted_indices = sorted_indices.to(input.device()); + int64_t batch_dim = batch_first ? 0 : 1; + input = input.index_select(batch_dim, sorted_indices); + } + + auto [data, batch_sizes] = + torch::_pack_padded_sequence(input, lengths, batch_first); + return PackedSequence( + std::move(data), std::move(batch_sizes), std::move(sorted_indices), {}); +} + +/// Pads a packed batch of variable length sequences. +/// +/// It is an inverse operation to `pack_padded_sequence`. +/// +/// The returned Tensor's data will be of size ``T x B x *``, where `T` is the +/// length of the longest sequence and `B` is the batch size. If ``batch_first`` +/// is true, the data will be transposed into ``B x T x *`` format. +/// +/// Batch elements will be ordered decreasingly by their length. 
+/// +/// Arguments: +/// sequence (PackedSequence): batch to pad +/// batch_first (bool, optional): if ``true``, the output will be in ``B x T +/// x *`` +/// format. +/// padding_value (double, optional): values for padded elements. +/// total_length (int64_t, optional): if specified, the output will be +/// padded to +/// have length `total_length`. This method will throw error +/// if `total_length` is less than the max sequence length in +/// `sequence`. +/// +/// Returns: +/// Tuple of Tensor containing the padded sequence, and a Tensor +/// containing the list of lengths of each sequence in the batch. +inline std::tuple pad_packed_sequence( + PackedSequence sequence, + bool batch_first = false, + double padding_value = 0.0, + std::optional total_length = torch::nullopt) { + int64_t max_seq_length = sequence.batch_sizes().size(0); + if (total_length.has_value()) { + int64_t total_length_val = total_length.value(); + TORCH_CHECK( + total_length_val >= max_seq_length, + "Expected total_length to be at least the length " + "of the longest sequence in input, but got " + "total_length=", + total_length_val, + " and max sequence length being ", + max_seq_length); + max_seq_length = total_length_val; + } + auto [padded_output, lengths] = torch::_pad_packed_sequence( + sequence.data(), + sequence.batch_sizes(), + batch_first, + padding_value, + max_seq_length); + const Tensor& unsorted_indices = sequence.unsorted_indices(); + if (unsorted_indices.defined()) { + int64_t batch_dim = batch_first ? 0 : 1; + return std::make_tuple( + padded_output.index_select(batch_dim, unsorted_indices), + lengths.index({unsorted_indices.cpu()})); + } + return std::make_tuple(padded_output, lengths); +} + +/// Pad a list of variable length Tensors with ``padding_value`` +/// +/// ``pad_sequence`` stacks a list of Tensors along a new dimension, +/// and pads them to equal length. 
For example, if the input is list of +/// sequences with size ``L x *`` and if batch_first is false, and ``T x B x *`` +/// otherwise. +/// +/// `B` is batch size. It is equal to the number of elements in ``sequences``. +/// `T` is length of the longest sequence. +/// `L` is length of the sequence. +/// `*` is any number of trailing dimensions, including none. +/// +/// Note: +/// This function returns a Tensor of size ``T x B x *`` or ``B x T x *`` +/// where `T` is the length of the longest sequence. This function assumes +/// trailing dimensions and type of all the Tensors in sequences are same. +/// +/// Arguments: +/// sequences (torch::ArrayRef): list of variable length sequences. +/// batch_first (bool, optional): output will be in ``B x T x *`` if true, +/// or in +/// ``T x B x *`` otherwise +/// padding_value (double, optional): value for padded elements. Default: 0. +/// padding_side (str, optional): the side to pad the sequences on. Default: +/// "right". +/// +/// Returns: +/// Tensor of size ``T x B x *`` if `batch_first` is ``false``. +/// Tensor of size ``B x T x *`` otherwise +inline Tensor pad_sequence( + ArrayRef sequences, + bool batch_first = false, + double padding_value = 0, + c10::string_view padding_side = "right") { + return at::pad_sequence(sequences, batch_first, padding_value, padding_side); +} + +/// Packs a list of variable length Tensors +/// +/// ``sequences`` should be a list of Tensors of size ``L x *``, where `L` is +/// the length of a sequence and `*` is any number of trailing dimensions, +/// including zero. +/// +/// For unsorted sequences, use `enforce_sorted = false`. If ``enforce_sorted`` +/// is ``true``, the sequences should be sorted in the order of decreasing +/// length. +/// +/// +/// Arguments: +/// sequences (torch::ArrayRef): A list of sequences of decreasing +/// length. enforce_sorted (bool, optional): if ``true``, checks that the +/// input +/// contains sequences sorted by length in a decreasing order. 
If +/// ``false``, this condition is not checked. Default: ``true``. +/// +/// Returns: +/// a `PackedSequence` object +inline PackedSequence pack_sequence( + ArrayRef sequences, + bool enforce_sorted = true) { + Tensor lengths = torch::empty({(int64_t)sequences.size()}, kInt64); + for (const auto i : c10::irange(sequences.size())) { + lengths[i] = sequences[i].size(0); + } + return pack_padded_sequence( + at::pad_sequence(sequences), + std::move(lengths), + /*batch_first=*/false, + /*enforce_sorted=*/enforce_sorted); +} + +} // namespace rnn +} // namespace utils +} // namespace nn +} // namespace torch diff --git a/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/optim/adagrad.h b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/optim/adagrad.h new file mode 100644 index 0000000000000000000000000000000000000000..4b2ff3c676b3d7b35485ace2b37c1545b6a66381 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/optim/adagrad.h @@ -0,0 +1,109 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include +#include + +namespace torch { +namespace serialize { +class OutputArchive; +class InputArchive; +} // namespace serialize +} // namespace torch + +namespace torch { +namespace optim { + +struct TORCH_API AdagradOptions + : public OptimizerCloneableOptions { + AdagradOptions(double lr = 1e-2); + TORCH_ARG(double, lr) = 1e-2; + TORCH_ARG(double, lr_decay) = 0; + TORCH_ARG(double, weight_decay) = 0; + TORCH_ARG(double, initial_accumulator_value) = 0; + TORCH_ARG(double, eps) = 1e-10; + + public: + void serialize(torch::serialize::InputArchive& archive) override; + void serialize(torch::serialize::OutputArchive& archive) const override; + TORCH_API friend bool operator==( + const AdagradOptions& lhs, + const AdagradOptions& rhs); + double get_lr() const override; + void set_lr(const double lr) override; +}; + +struct TORCH_API AdagradParamState + : public 
OptimizerCloneableParamState { + TORCH_ARG(torch::Tensor, sum); + TORCH_ARG(int64_t, step) = 0; + + public: + AdagradParamState() = default; + AdagradParamState(const AdagradParamState&) = default; + AdagradParamState& operator=(const AdagradParamState&) = default; + AdagradParamState(AdagradParamState&&) noexcept = default; + AdagradParamState& operator=(AdagradParamState&&) noexcept = default; + void serialize(torch::serialize::InputArchive& archive) override; + void serialize(torch::serialize::OutputArchive& archive) const override; + TORCH_API friend bool operator==( + const AdagradParamState& lhs, + const AdagradParamState& rhs); +}; + +class TORCH_API Adagrad : public Optimizer { + public: + explicit Adagrad( + std::vector param_groups, + AdagradOptions defaults = {}) + : Optimizer( + std::move(param_groups), + std::make_unique(defaults)) { + TORCH_CHECK(defaults.lr() >= 0, "Invalid learning rate: ", defaults.lr()); + TORCH_CHECK( + defaults.lr_decay() >= 0, + "Invalid lr_decay value: ", + defaults.lr_decay()); + TORCH_CHECK( + defaults.weight_decay() >= 0, + "Invalid weight_decay value: ", + defaults.weight_decay()); + TORCH_CHECK( + defaults.initial_accumulator_value() >= 0, + "Invalid initial_accumulator_value value: ", + defaults.initial_accumulator_value()); + TORCH_CHECK(defaults.eps() >= 0, "Invalid epsilon value: ", defaults.eps()); + + for (const auto& group : param_groups_) { + for (const auto& p : group.params()) { + auto state = std::make_unique(); + state->step(0); + state->sum(torch::full_like( + p.data(), + defaults.initial_accumulator_value(), + at::MemoryFormat::Preserve)); + state_[p.unsafeGetTensorImpl()] = std::move(state); + } + } + } + + explicit Adagrad(std::vector params, AdagradOptions defaults = {}) + : Adagrad({OptimizerParamGroup(std::move(params))}, defaults) {} + + torch::Tensor step(LossClosure closure = nullptr) override; + void save(serialize::OutputArchive& archive) const override; + void load(serialize::InputArchive& 
archive) override; + + private: + template + static void serialize(Self& self, Archive& archive) { + _TORCH_OPTIM_SERIALIZE_WITH_TEMPLATE_ARG(Adagrad); + } +}; +} // namespace optim +} // namespace torch diff --git a/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/optim/adam.h b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/optim/adam.h new file mode 100644 index 0000000000000000000000000000000000000000..6e5e02d82c5442e1b007dd65a9240b5f959efe75 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/optim/adam.h @@ -0,0 +1,92 @@ +#pragma once + +#include +#include +#include + +#include +#include + +namespace torch { +namespace serialize { +class OutputArchive; +class InputArchive; +} // namespace serialize +} // namespace torch + +namespace torch { +namespace optim { + +struct TORCH_API AdamOptions : public OptimizerCloneableOptions { + AdamOptions(double lr = 1e-3); + TORCH_ARG(double, lr) = 1e-3; + typedef std::tuple betas_t; + TORCH_ARG(betas_t, betas) = std::make_tuple(0.9, 0.999); + TORCH_ARG(double, eps) = 1e-8; + TORCH_ARG(double, weight_decay) = 0; + TORCH_ARG(bool, amsgrad) = false; + + public: + void serialize(torch::serialize::InputArchive& archive) override; + void serialize(torch::serialize::OutputArchive& archive) const override; + TORCH_API friend bool operator==( + const AdamOptions& lhs, + const AdamOptions& rhs); + double get_lr() const override; + void set_lr(const double lr) override; +}; + +struct TORCH_API AdamParamState + : public OptimizerCloneableParamState { + TORCH_ARG(int64_t, step) = 0; + TORCH_ARG(torch::Tensor, exp_avg); + TORCH_ARG(torch::Tensor, exp_avg_sq); + TORCH_ARG(torch::Tensor, max_exp_avg_sq) = {}; + + public: + void serialize(torch::serialize::InputArchive& archive) override; + void serialize(torch::serialize::OutputArchive& archive) const override; + TORCH_API friend bool operator==( + const AdamParamState& lhs, + const 
AdamParamState& rhs); +}; + +class TORCH_API Adam : public Optimizer { + public: + explicit Adam( + std::vector param_groups, + AdamOptions defaults = {}) + : Optimizer( + std::move(param_groups), + std::make_unique(defaults)) { + TORCH_CHECK(defaults.lr() >= 0, "Invalid learning rate: ", defaults.lr()); + TORCH_CHECK(defaults.eps() >= 0, "Invalid epsilon value: ", defaults.eps()); + auto betas = defaults.betas(); + TORCH_CHECK( + 0 <= std::get<0>(betas) && std::get<0>(betas) < 1.0, + "Invalid beta parameter at index 0: ", + std::get<0>(betas)); + TORCH_CHECK( + 0 <= std::get<1>(betas) && std::get<1>(betas) < 1.0, + "Invalid beta parameter at index 1: ", + std::get<1>(betas)); + TORCH_CHECK( + defaults.weight_decay() >= 0, + "Invalid weight_decay value: ", + defaults.weight_decay()); + } + explicit Adam(std::vector params, AdamOptions defaults = {}) + : Adam({OptimizerParamGroup(std::move(params))}, defaults) {} + + torch::Tensor step(LossClosure closure = nullptr) override; + void save(serialize::OutputArchive& archive) const override; + void load(serialize::InputArchive& archive) override; + + private: + template + static void serialize(Self& self, Archive& archive) { + _TORCH_OPTIM_SERIALIZE_WITH_TEMPLATE_ARG(Adam); + } +}; +} // namespace optim +} // namespace torch diff --git a/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/optim/adamw.h b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/optim/adamw.h new file mode 100644 index 0000000000000000000000000000000000000000..a63d7fc32d455425fbb6967534e72c36ac2830c8 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/optim/adamw.h @@ -0,0 +1,92 @@ +#pragma once + +#include +#include +#include + +#include +#include + +namespace torch { +namespace serialize { +class OutputArchive; +class InputArchive; +} // namespace serialize +} // namespace torch + +namespace torch { +namespace optim { + +struct TORCH_API 
AdamWOptions : public OptimizerCloneableOptions { + AdamWOptions(double lr = 1e-3); + TORCH_ARG(double, lr) = 1e-3; + typedef std::tuple betas_t; + TORCH_ARG(betas_t, betas) = std::make_tuple(0.9, 0.999); + TORCH_ARG(double, eps) = 1e-8; + TORCH_ARG(double, weight_decay) = 1e-2; + TORCH_ARG(bool, amsgrad) = false; + + public: + void serialize(torch::serialize::InputArchive& archive) override; + void serialize(torch::serialize::OutputArchive& archive) const override; + TORCH_API friend bool operator==( + const AdamWOptions& lhs, + const AdamWOptions& rhs); + double get_lr() const override; + void set_lr(const double lr) override; +}; + +struct TORCH_API AdamWParamState + : public OptimizerCloneableParamState { + TORCH_ARG(int64_t, step) = 0; + TORCH_ARG(torch::Tensor, exp_avg); + TORCH_ARG(torch::Tensor, exp_avg_sq); + TORCH_ARG(torch::Tensor, max_exp_avg_sq) = {}; + + public: + void serialize(torch::serialize::InputArchive& archive) override; + void serialize(torch::serialize::OutputArchive& archive) const override; + TORCH_API friend bool operator==( + const AdamWParamState& lhs, + const AdamWParamState& rhs); +}; + +class TORCH_API AdamW : public Optimizer { + public: + explicit AdamW( + std::vector param_groups, + AdamWOptions defaults = {}) + : Optimizer( + std::move(param_groups), + std::make_unique(defaults)) { + TORCH_CHECK(defaults.lr() >= 0, "Invalid learning rate: ", defaults.lr()); + TORCH_CHECK(defaults.eps() >= 0, "Invalid epsilon value: ", defaults.eps()); + auto betas = defaults.betas(); + TORCH_CHECK( + 0 <= std::get<0>(betas) && std::get<0>(betas) < 1.0, + "Invalid beta parameter at index 0: ", + std::get<0>(betas)); + TORCH_CHECK( + 0 <= std::get<1>(betas) && std::get<1>(betas) < 1.0, + "Invalid beta parameter at index 1: ", + std::get<1>(betas)); + TORCH_CHECK( + defaults.weight_decay() >= 0, + "Invalid weight_decay value: ", + defaults.weight_decay()); + } + explicit AdamW(std::vector params, AdamWOptions defaults = {}) + : 
AdamW({OptimizerParamGroup(std::move(params))}, defaults) {} + + torch::Tensor step(LossClosure closure = nullptr) override; + void save(serialize::OutputArchive& archive) const override; + void load(serialize::InputArchive& archive) override; + + private: + template + static void serialize(Self& self, Archive& archive) { + _TORCH_OPTIM_SERIALIZE_WITH_TEMPLATE_ARG(AdamW); + } +}; +} // namespace optim +} // namespace torch diff --git a/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/optim/lbfgs.h b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/optim/lbfgs.h new file mode 100644 index 0000000000000000000000000000000000000000..0832afff5f8f2026bcdca8dab726673ec2710fb1 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/optim/lbfgs.h @@ -0,0 +1,103 @@ +#pragma once + +#include +#include +#include +#include + +#include +#include +#include +#include + +namespace torch { +namespace optim { + +struct TORCH_API LBFGSOptions : public OptimizerCloneableOptions { + LBFGSOptions(double lr = 1); + TORCH_ARG(double, lr) = 1; + TORCH_ARG(int64_t, max_iter) = 20; + TORCH_ARG(std::optional, max_eval) = std::nullopt; + TORCH_ARG(double, tolerance_grad) = 1e-7; + TORCH_ARG(double, tolerance_change) = 1e-9; + TORCH_ARG(int64_t, history_size) = 100; + TORCH_ARG(std::optional, line_search_fn) = std::nullopt; + + public: + void serialize(torch::serialize::InputArchive& archive) override; + void serialize(torch::serialize::OutputArchive& archive) const override; + TORCH_API friend bool operator==( + const LBFGSOptions& lhs, + const LBFGSOptions& rhs); + double get_lr() const override; + void set_lr(const double lr) override; +}; + +struct TORCH_API LBFGSParamState + : public OptimizerCloneableParamState { + TORCH_ARG(int64_t, func_evals) = 0; + TORCH_ARG(int64_t, n_iter) = 0; + TORCH_ARG(double, t) = 0; + TORCH_ARG(double, prev_loss) = 0; + TORCH_ARG(Tensor, d) = {}; + 
TORCH_ARG(Tensor, H_diag) = {}; + TORCH_ARG(Tensor, prev_flat_grad) = {}; + TORCH_ARG(std::deque, old_dirs); + TORCH_ARG(std::deque, old_stps); + TORCH_ARG(std::deque, ro); + TORCH_ARG(std::optional>, al) = std::nullopt; + + public: + void serialize(torch::serialize::InputArchive& archive) override; + void serialize(torch::serialize::OutputArchive& archive) const override; + TORCH_API friend bool operator==( + const LBFGSParamState& lhs, + const LBFGSParamState& rhs); +}; + +class TORCH_API LBFGS : public Optimizer { + public: + explicit LBFGS( + std::vector param_groups, + LBFGSOptions defaults = {}) + : Optimizer( + std::move(param_groups), + std::make_unique(defaults)) { + TORCH_CHECK( + param_groups_.size() == 1, + "LBFGS doesn't support per-parameter options (parameter groups)"); + if (defaults.max_eval() == std::nullopt) { + auto max_eval_val = (defaults.max_iter() * 5) / 4; + static_cast(param_groups_[0].options()) + .max_eval(max_eval_val); + static_cast(*defaults_.get()).max_eval(max_eval_val); + } + _numel_cache = std::nullopt; + } + explicit LBFGS(std::vector params, LBFGSOptions defaults = {}) + : LBFGS({OptimizerParamGroup(std::move(params))}, defaults) {} + + Tensor step(LossClosure closure) override; + void save(serialize::OutputArchive& archive) const override; + void load(serialize::InputArchive& archive) override; + + private: + std::optional _numel_cache; + int64_t _numel(); + Tensor _gather_flat_grad(); + void _add_grad(const double step_size, const Tensor& update); + std::tuple _directional_evaluate( + const LossClosure& closure, + const std::vector& x, + double t, + const Tensor& d); + void _set_param(const std::vector& params_data); + std::vector _clone_param(); + + template + static void serialize(Self& self, Archive& archive) { + _TORCH_OPTIM_SERIALIZE_WITH_TEMPLATE_ARG(LBFGS); + } +}; +} // namespace optim +} // namespace torch diff --git a/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/optim/optimizer.h 
b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/optim/optimizer.h new file mode 100644 index 0000000000000000000000000000000000000000..f6599248244a24363b58e786bc7cca6baa3f9672 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/optim/optimizer.h @@ -0,0 +1,219 @@ +#pragma once + +#include +#include +#include + +#include +#include + +#include +#include +#include +#include +#include +#include + +// Forward declarations confuse Doxygen +#ifndef DOXYGEN_SHOULD_SKIP_THIS +namespace at { +class Tensor; +} // namespace at + +namespace torch { +using at::Tensor; +namespace serialize { +class OutputArchive; +class InputArchive; +} // namespace serialize +} // namespace torch +#endif // DOXYGEN_SHOULD_SKIP_THIS + +namespace torch { +namespace optim { + +class TORCH_API OptimizerParamState { + public: + OptimizerParamState() = default; + OptimizerParamState(const OptimizerParamState&) = default; + OptimizerParamState& operator=(const OptimizerParamState&) = default; + OptimizerParamState(OptimizerParamState&&) noexcept = default; + OptimizerParamState& operator=(OptimizerParamState&&) noexcept = default; + virtual std::unique_ptr clone() const; + virtual void serialize(torch::serialize::InputArchive& archive); + virtual void serialize(torch::serialize::OutputArchive& archive) const; + virtual ~OptimizerParamState() = default; +}; + +template +class OptimizerCloneableParamState : public OptimizerParamState { + std::unique_ptr clone() const override { + return std::make_unique(static_cast(*this)); + } +}; + +class TORCH_API OptimizerOptions { + public: + OptimizerOptions() = default; + OptimizerOptions(const OptimizerOptions&) = default; + OptimizerOptions& operator=(const OptimizerOptions&) = default; + OptimizerOptions(OptimizerOptions&&) noexcept = default; + OptimizerOptions& operator=(OptimizerOptions&&) noexcept = default; + virtual std::unique_ptr clone() const; + virtual void 
serialize(torch::serialize::InputArchive& archive); + virtual void serialize(torch::serialize::OutputArchive& archive) const; + virtual ~OptimizerOptions() = default; + virtual double get_lr() const; + virtual void set_lr(const double lr); +}; + +template +class OptimizerCloneableOptions : public OptimizerOptions { + private: + std::unique_ptr clone() const override { + return std::make_unique(static_cast(*this)); + } +}; + +/// Stores parameters in the param_group and stores a pointer to the +/// OptimizerOptions +class TORCH_API OptimizerParamGroup { + public: + // NOTE: In order to store `OptimizerParamGroup` in a `std::vector`, it has to + // be copy-constructible. + OptimizerParamGroup(const OptimizerParamGroup& param_group) + : params_(param_group.params()), + options_( + param_group.has_options() ? param_group.options().clone() + : nullptr) {} + OptimizerParamGroup(std::vector params) + : params_(std::move(params)) {} + OptimizerParamGroup( + std::vector params, + std::unique_ptr options) + : params_(std::move(params)), options_(std::move(options)) {} + + OptimizerParamGroup& operator=(const OptimizerParamGroup& param_group) = + delete; + bool has_options() const; + OptimizerOptions& options(); + const OptimizerOptions& options() const; + void set_options(std::unique_ptr options); + std::vector& params(); + const std::vector& params() const; + + protected: + std::vector params_; + std::unique_ptr options_; +}; + +class TORCH_API Optimizer { + public: + // The copy constructor is deleted, because the user should use the + // `state_dict` / `load_state_dict` API to copy an optimizer instead. + Optimizer(const Optimizer& optimizer) = delete; + Optimizer(Optimizer&& optimizer) = default; + + explicit Optimizer( + std::vector param_groups, + std::unique_ptr defaults) + : defaults_(std::move(defaults)) { + for (const auto& param_group : param_groups) { + add_param_group(param_group); + } + } + + /// Constructs the `Optimizer` from a vector of parameters. 
+ explicit Optimizer( + std::vector parameters, + std::unique_ptr defaults) + : Optimizer( + {OptimizerParamGroup(std::move(parameters))}, + std::move(defaults)){}; + + /// Adds the given param_group to the optimizer's param_group list. + void add_param_group(const OptimizerParamGroup& param_group); + + virtual ~Optimizer() = default; + + using LossClosure = std::function; + /// A loss function closure, which is expected to return the loss value. + virtual Tensor step(LossClosure closure = nullptr) = 0; + + /// Adds the given vector of parameters to the optimizer's parameter list. + void add_parameters(const std::vector& parameters); + + /// Zeros out the gradients of all parameters. + void zero_grad(bool set_to_none = true); + + /// Provides a const reference to the parameters in the first param_group this + /// optimizer holds. + const std::vector& parameters() const noexcept; + + /// Provides a reference to the parameters in the first param_group this + /// optimizer holds. + std::vector& parameters() noexcept; + + /// Returns the number of parameters referenced by the optimizer. + size_t size() const noexcept; + + OptimizerOptions& defaults() noexcept; + + const OptimizerOptions& defaults() const noexcept; + + /// Provides a reference to the param_groups this optimizer holds. + std::vector& param_groups() noexcept; + + /// Provides a const reference to the param_groups this optimizer holds. + const std::vector& param_groups() const noexcept; + + /// Provides a reference to the state this optimizer holds + ska::flat_hash_map>& + state() noexcept; + + /// Provides a const reference to the state this optimizer holds + const ska::flat_hash_map>& state() + const noexcept; + + /// Serializes the optimizer state into the given `archive`. + virtual void save(serialize::OutputArchive& archive) const; + + /// Deserializes the optimizer state from the given `archive`. 
+ virtual void load(serialize::InputArchive& archive); + + protected: + std::vector param_groups_; + ska::flat_hash_map> state_; + std::unique_ptr defaults_; +}; + +/* How do we decide whether to serialize undefined tensors or + std::nullopt values into the output archive? +Answer: we strictly follow the behavior of Python API. To be more specific: + +For optimizer options: +a) For undefined tensor: currently no tensor is used as an options argument in +Python API, so we don't need to worry about it now. b) For std::nullopt value: +we serialize std::nullopt values into the output archive, to follow the exact +same behavior as Python API. + +For optimizer param state: +a) For undefined tensor: in param state, undefined tensor in C++ impl is +equivalent to missing key in Python impl. Since we don't serialize missing keys +in Python API, we skip undefined tensors when serializing the param state. b) +For std::nullopt value: in param state, std::nullopt value in C++ impl is +equivalent to missing key in Python impl. Since we don't serialize missing keys +in Python API, we skip std::nullopt values when serializing the param state. */ + +/// Serializes an `Optimizer` into an `OutputArchive`. +TORCH_API serialize::OutputArchive& operator<<( + serialize::OutputArchive& archive, + const Optimizer& optimizer); + +/// Deserializes a `Tensor` from an `InputArchive`. 
+TORCH_API serialize::InputArchive& operator>>( + serialize::InputArchive& archive, + Optimizer& optimizer); + +} // namespace optim +} // namespace torch diff --git a/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/optim/rmsprop.h b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/optim/rmsprop.h new file mode 100644 index 0000000000000000000000000000000000000000..69a2e27993d5b76165cb268cb0186e632b1e05f1 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/optim/rmsprop.h @@ -0,0 +1,95 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +namespace torch { +namespace serialize { +class OutputArchive; +class InputArchive; +} // namespace serialize +} // namespace torch + +namespace torch { +namespace optim { + +struct TORCH_API RMSpropOptions + : public OptimizerCloneableOptions { + RMSpropOptions(double lr = 1e-2); + TORCH_ARG(double, lr) = 1e-2; + TORCH_ARG(double, alpha) = 0.99; + TORCH_ARG(double, eps) = 1e-8; + TORCH_ARG(double, weight_decay) = 0; + TORCH_ARG(double, momentum) = 0; + TORCH_ARG(bool, centered) = false; + + public: + void serialize(torch::serialize::InputArchive& archive) override; + void serialize(torch::serialize::OutputArchive& archive) const override; + TORCH_API friend bool operator==( + const RMSpropOptions& lhs, + const RMSpropOptions& rhs); + double get_lr() const override; + void set_lr(const double lr) override; +}; + +struct TORCH_API RMSpropParamState + : public OptimizerCloneableParamState { + TORCH_ARG(int64_t, step) = 0; + TORCH_ARG(torch::Tensor, square_avg); + TORCH_ARG(torch::Tensor, momentum_buffer) = {}; + TORCH_ARG(torch::Tensor, grad_avg) = {}; + + public: + void serialize(torch::serialize::InputArchive& archive) override; + void serialize(torch::serialize::OutputArchive& archive) const override; + TORCH_API friend bool operator==( + const RMSpropParamState& lhs, 
+ const RMSpropParamState& rhs); +}; + +class TORCH_API RMSprop : public Optimizer { + public: + explicit RMSprop( + std::vector param_groups, + RMSpropOptions defaults = {}) + : Optimizer( + std::move(param_groups), + std::make_unique(defaults)) { + TORCH_CHECK(defaults.lr() >= 0, "Invalid learning rate: ", defaults.lr()); + TORCH_CHECK(defaults.eps() >= 0, "Invalid epsilon value: ", defaults.eps()); + TORCH_CHECK( + defaults.momentum() >= 0, + "Invalid momentum value: ", + defaults.momentum()); + TORCH_CHECK( + defaults.weight_decay() >= 0, + "Invalid weight_decay value: ", + defaults.weight_decay()); + TORCH_CHECK( + defaults.alpha() >= 0, "Invalid alpha value: ", defaults.alpha()); + } + + explicit RMSprop(std::vector params, RMSpropOptions defaults = {}) + : RMSprop({OptimizerParamGroup(std::move(params))}, defaults) {} + + torch::Tensor step(LossClosure closure = nullptr) override; + void save(serialize::OutputArchive& archive) const override; + void load(serialize::InputArchive& archive) override; + + private: + template + static void serialize(Self& self, Archive& archive) { + _TORCH_OPTIM_SERIALIZE_WITH_TEMPLATE_ARG(RMSprop); + } +}; +} // namespace optim +} // namespace torch diff --git a/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/optim/schedulers/lr_scheduler.h b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/optim/schedulers/lr_scheduler.h new file mode 100644 index 0000000000000000000000000000000000000000..26d324fbecce166c19e315ab41142b5e9e4cf4de --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/optim/schedulers/lr_scheduler.h @@ -0,0 +1,39 @@ +#pragma once + +#include + +#include + +namespace torch { +namespace optim { + +class TORCH_API LRScheduler { + public: + // This class needs to take a reference of an optimizer from outside such that + // it can modify its learning rates; due to this the lifetime of said + // optimizer must be 
maintained + LRScheduler(torch::optim::Optimizer& optimizer); + + virtual ~LRScheduler() = default; + + void step(); + + protected: + // A vector of learning rates is calculated and returned from the specific + // subclass. A vector is returned with each element being a separate learning + // rate for each param group - although the normal use case would be to return + // a vector of identical elements. + virtual std::vector get_lrs() = 0; + + // Get current learning rates from the optimizer + std::vector get_current_lrs() const; + + unsigned step_count_{}; + + private: + void set_optimizer_lrs(const std::vector& learning_rates); + + torch::optim::Optimizer& optimizer_; +}; +} // namespace optim +} // namespace torch diff --git a/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/optim/schedulers/reduce_on_plateau_scheduler.h b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/optim/schedulers/reduce_on_plateau_scheduler.h new file mode 100644 index 0000000000000000000000000000000000000000..ae8892ff4fda6b129178f092cef0f9a0892c24c7 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/optim/schedulers/reduce_on_plateau_scheduler.h @@ -0,0 +1,64 @@ +#pragma once + +#include +#include + +#include + +#include + +#include + +#include + +namespace torch { +namespace optim { + +class TORCH_API ReduceLROnPlateauScheduler { + public: + enum SchedulerMode { min, max }; + enum ThresholdMode { rel, abs }; + ReduceLROnPlateauScheduler( + Optimizer& optimizer, + SchedulerMode mode = min, + float factor = 0.1, + int patience = 10, + double threshold = 1e-4, + ThresholdMode threshold_mode = rel, + int cooldown = 0, + const std::vector& min_lr = std::vector(), + double eps = 1e-8, + bool verbose = false); + + virtual ~ReduceLROnPlateauScheduler() = default; + + void step(float metric); + + private: + void reset(); + void reduce_lr(int epoch); + bool in_cooldown(); + bool 
is_better(float a); + void init_is_better( + SchedulerMode mode, + double threshold, + ThresholdMode threshold_mode); + + Optimizer& optimizer; + SchedulerMode mode; + float mode_worse; + float factor; + int patience; + double threshold; + ThresholdMode threshold_mode; + int cooldown; + int cooldown_counter; + std::vector min_lrs; + double eps; + float best; + bool verbose; + int last_epoch; + int num_bad_epochs; +}; +} // namespace optim +} // namespace torch diff --git a/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/optim/schedulers/step_lr.h b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/optim/schedulers/step_lr.h new file mode 100644 index 0000000000000000000000000000000000000000..289bb4bd84e54e995bfc6581aa0c76724661c7ca --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/optim/schedulers/step_lr.h @@ -0,0 +1,22 @@ +#pragma once + +#include + +namespace torch { +namespace optim { + +class TORCH_API StepLR : public LRScheduler { + public: + StepLR( + torch::optim::Optimizer& optimizer, + const unsigned step_size, + const double gamma = 0.1); + + private: + std::vector get_lrs() override; + + const unsigned step_size_; + const double gamma_; +}; +} // namespace optim +} // namespace torch diff --git a/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/optim/serialize.h b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/optim/serialize.h new file mode 100644 index 0000000000000000000000000000000000000000..7c34450999b6215a40131699d028c55052a83e50 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/optim/serialize.h @@ -0,0 +1,315 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace torch { +namespace optim { +namespace detail { +// Utility function to save state +template +void serialize( 
+ serialize::OutputArchive& archive, + const ska::flat_hash_map>& + state) { + for (const auto& item : state) { + serialize::OutputArchive param_state_archive(archive.compilation_unit()); + std::string tensorimpl_key = + std::to_string(reinterpret_cast(item.first)); + const DerivedOptimizerParamState& curr_state = + static_cast(*(item.second.get())); + curr_state.serialize(param_state_archive); + archive.write(tensorimpl_key, param_state_archive); + } +} + +// Utility function to load state +template +void serialize( + serialize::InputArchive& archive, + ska::flat_hash_map>& state) { + std::vector tensorimpl_keys = archive.keys(); + for (const std::string& tensorimpl_key : tensorimpl_keys) { + serialize::InputArchive param_state_archive; + archive.read(tensorimpl_key, param_state_archive); + DerivedOptimizerParamState param_state; + param_state.serialize(param_state_archive); + state[reinterpret_cast(std::stoull(tensorimpl_key))] = + std::make_unique(param_state); + } +} + +// Utility function to save param_groups +template +void serialize( + serialize::OutputArchive& archive, + const std::vector& param_groups) { + archive.write( + "param_groups/size", + torch::tensor(static_cast(param_groups.size()))); + for (const auto i : c10::irange(param_groups.size())) { + serialize::OutputArchive param_group_archive(archive.compilation_unit()); + std::vector params = param_groups[i].params(); + param_group_archive.write( + "params/size", torch::tensor(static_cast(params.size()))); + for (const auto index : c10::irange(params.size())) { + param_group_archive.write( + "params/" + std::to_string(index), + IValue(std::to_string( + reinterpret_cast(params[index].unsafeGetTensorImpl())))); + } + const DerivedOptimizerParamOptions& param_group_options = + static_cast( + param_groups[i].options()); + serialize::OutputArchive param_group_options_archive( + param_group_archive.compilation_unit()); + param_group_options.serialize(param_group_options_archive); + 
param_group_archive.write("options", param_group_options_archive); + archive.write("param_groups/" + std::to_string(i), param_group_archive); + } +} + +// Utility function to load param_groups +// We take as input vector of pair of string and unique_ptr to optimizer options +// so that we can retain the state for each param by using the old tensor impl +// keys (saved during serialization) and map the new tensor impl keys to the +// correct state for each param +template +void serialize( + serialize::InputArchive& archive, + std::vector< + std::pair, std::unique_ptr>>& + param_groups) { + torch::Tensor param_groups_size_tensor; + archive.read("param_groups/size", param_groups_size_tensor); + const int64_t param_groups_size = param_groups_size_tensor.item(); + for (const auto i : c10::irange(param_groups_size)) { + serialize::InputArchive param_group_archive; + archive.read("param_groups/" + std::to_string(i), param_group_archive); + torch::Tensor size_tensor; + param_group_archive.read("params/size", size_tensor); + const int64_t size = size_tensor.item(); + std::vector params; + for (const auto index : c10::irange(size)) { + IValue ivalue; + param_group_archive.read("params/" + std::to_string(index), ivalue); + std::string element = ivalue.toStringRef(); + params.emplace_back(element); + } + serialize::InputArchive param_group_options_archive; + param_group_archive.read("options", param_group_options_archive); + DerivedOptimizerParamOptions param_group_options(0); + param_group_options.serialize(param_group_options_archive); + param_groups.emplace_back(std::make_pair( + params, + std::make_unique(param_group_options))); + } +} +} // namespace detail + +// Note: These functions are all called `serialize()` so they can be called +// inside a template where the archive type is a template type and can thus be +// passed such that the appropriate overload is selected. + +/// Utility function to save a value of `int64_t` type. 
+void serialize( + serialize::OutputArchive& archive, + const std::string& key, + const int64_t& value); + +/// Utility function to load a value of `int64_t` type. +void serialize( + serialize::InputArchive& archive, + const std::string& key, + int64_t& value); + +/// Utility function to save a vector of step buffers. +void serialize( + serialize::OutputArchive& archive, + const std::string& key, + const std::vector& steps); + +/// Utility function to load a vector of step buffers. +void serialize( + serialize::InputArchive& archive, + const std::string& key, + std::vector& steps); + +// Utility function to save state and param_groups +template < + typename DerivedOptimizerParamState, + typename DerivedOptimizerParamOptions> +void serialize(serialize::OutputArchive& archive, const Optimizer& optimizer) { + archive.write("pytorch_version", IValue("1.5.0")); + serialize::OutputArchive state_archive(archive.compilation_unit()); + detail::serialize( + state_archive, optimizer.state()); + archive.write("state", state_archive); + + serialize::OutputArchive param_groups_archive(archive.compilation_unit()); + detail::serialize( + param_groups_archive, optimizer.param_groups()); + archive.write("param_groups", param_groups_archive); +} + +// Utility function to load state and param_groups and update state +template < + typename DerivedOptimizerParamState, + typename DerivedOptimizerParamOptions> +void serialize(serialize::InputArchive& archive, Optimizer& optimizer) { + IValue pytorch_version; + archive.read("pytorch_version", pytorch_version); + TORCH_INTERNAL_ASSERT(pytorch_version.toStringRef() == "1.5.0"); + serialize::InputArchive state_archive; + archive.read("state", state_archive); + ska::flat_hash_map> saved_state; + detail::serialize(state_archive, saved_state); + + serialize::InputArchive param_groups_archive; + archive.read("param_groups", param_groups_archive); + std::vector< + std::pair, std::unique_ptr>> + saved_param_groups; + detail::serialize( + 
param_groups_archive, saved_param_groups); + + // update state and optimizer options + TORCH_CHECK( + saved_param_groups.size() == optimizer.param_groups().size(), + "loaded state dict has a different number of parameter groups"); + for (const auto i : c10::irange(saved_param_groups.size())) { + std::vector param_group_old_keys = saved_param_groups[i].first; + std::vector params = optimizer.param_groups()[i].params(); + TORCH_CHECK( + param_group_old_keys.size() == params.size(), + "loaded state dict contains a parameter group that has a different size than the optimizer's parameter group"); + + for (const auto idx : c10::irange(params.size())) { + auto param_group_old_key = + reinterpret_cast(std::stoull(param_group_old_keys[idx])); + if (saved_state.find(param_group_old_key) != saved_state.end()) { + optimizer.state()[params[idx].unsafeGetTensorImpl()] = + std::move(saved_state[param_group_old_key]); + } + } + + auto& saved_options = reinterpret_cast( + *saved_param_groups[i].second); + auto& current_options = reinterpret_cast( + optimizer.param_groups()[i].options()); + current_options = saved_options; + } +} + +/// Utility function to save a vector of buffers. +template +void serialize( + serialize::OutputArchive& archive, + const std::string& key, + const BufferContainer& buffers) { + archive.write( + key + "/size", torch::tensor(static_cast(buffers.size()))); + for (const auto index : c10::irange(buffers.size())) { + archive.write( + key + "/" + std::to_string(index), buffers[index], /*is_buffer=*/true); + } +} + +/// Utility function to load a vector of buffers. 
+template +void serialize( + serialize::InputArchive& archive, + const std::string& key, + BufferContainer& buffers) { + buffers.clear(); + torch::Tensor size_tensor; + archive.read(key + "/size", size_tensor); + const size_t size = size_tensor.item(); + for (const auto index : c10::irange(size)) { + buffers.emplace_back(); + archive.read( + key + "/" + std::to_string(index), buffers.back(), /*is_buffer=*/true); + } +} + +template +c10::List deque_to_list(const std::deque& dq) { + c10::List list; + list.reserve(dq.size()); + for (const auto& e : dq) { + list.emplace_back(e); + } + return list; +} + +template +std::deque list_to_deque(const c10::List& list) { + std::deque dq; + for (const auto& e : list) { + dq.emplace_back(e); + } + return dq; +} + +#define _TORCH_OPTIM_SERIALIZE(name) \ + torch::optim::serialize(archive, #name, self.name) + +#define _TORCH_OPTIM_SERIALIZE_WITH_TEMPLATE_ARG(OptimizerName) \ + torch::optim::serialize( \ + archive, self) + +#define _TORCH_OPTIM_SERIALIZE_TORCH_ARG(name) \ + { \ + auto ivalue = torch::IValue(name()); \ + /* do not serialize if name is an undefined tensor*/ \ + if (!(ivalue.isTensor() && \ + ivalue.unsafeToTensorImpl() == \ + at::UndefinedTensorImpl::singleton())) { \ + archive.write(#name, ivalue); \ + } \ + } + +#define _TORCH_OPTIM_SERIALIZE_TORCH_ARG_DEQUE(name) \ + { \ + c10::IValue ivalue = torch::IValue(deque_to_list(name())); \ + archive.write(#name, ivalue); \ + } + +#define _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(T, name) \ + { \ + c10::IValue ivalue; \ + bool exists = archive.try_read(#name, ivalue); \ + if (exists) { \ + name(ivalue.to()); \ + } else { \ + bool is_tensor_type = std::is_base_of::value; \ + TORCH_INTERNAL_ASSERT(is_tensor_type); \ + } \ + } + +#define _TORCH_OPTIM_DESERIALIZE_TORCH_ARG_OPTIONAL(T, name) \ + { \ + c10::IValue ivalue; \ + bool exists = archive.try_read(#name, ivalue); \ + if (exists) { \ + name(ivalue.toOptional()); \ + } \ + } + +#define _TORCH_OPTIM_DESERIALIZE_TORCH_ARG_DEQUE(T, 
name) \ + { \ + c10::IValue ivalue; \ + archive.read(#name, ivalue); \ + auto list = ivalue.to>(); \ + name(list_to_deque(list)); \ + } + +} // namespace optim +} // namespace torch diff --git a/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/optim/sgd.h b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/optim/sgd.h new file mode 100644 index 0000000000000000000000000000000000000000..85e9aba7ba48f751d0ae00f8356ca8d47d7b0ad2 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/optim/sgd.h @@ -0,0 +1,91 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace torch { +namespace serialize { +class OutputArchive; +class InputArchive; +} // namespace serialize +} // namespace torch + +namespace torch { +namespace optim { + +struct TORCH_API SGDOptions : public OptimizerCloneableOptions { + SGDOptions(double lr); + TORCH_ARG(double, lr); + TORCH_ARG(double, momentum) = 0; + TORCH_ARG(double, dampening) = 0; + TORCH_ARG(double, weight_decay) = 0; + TORCH_ARG(bool, nesterov) = false; + + public: + void serialize(torch::serialize::InputArchive& archive) override; + void serialize(torch::serialize::OutputArchive& archive) const override; + TORCH_API friend bool operator==( + const SGDOptions& lhs, + const SGDOptions& rhs); + double get_lr() const override; + void set_lr(const double lr) override; +}; + +struct TORCH_API SGDParamState + : public OptimizerCloneableParamState { + TORCH_ARG(torch::Tensor, momentum_buffer); + + public: + void serialize(torch::serialize::InputArchive& archive) override; + void serialize(torch::serialize::OutputArchive& archive) const override; + TORCH_API friend bool operator==( + const SGDParamState& lhs, + const SGDParamState& rhs); +}; + +class TORCH_API SGD : public Optimizer { + public: + explicit SGD( + std::vector param_groups, + SGDOptions defaults) + : Optimizer( + 
std::move(param_groups), + std::make_unique(defaults)) { + TORCH_CHECK(defaults.lr() >= 0, "Invalid learning rate: ", defaults.lr()); + TORCH_CHECK( + defaults.momentum() >= 0, + "Invalid momentum value: ", + defaults.momentum()); + TORCH_CHECK( + defaults.weight_decay() >= 0, + "Invalid weight_decay value: ", + defaults.weight_decay()); + TORCH_CHECK( + !defaults.nesterov() || + (defaults.momentum() > 0 && defaults.dampening() == 0), + "Nesterov momentum requires a momentum and zero dampening"); + } + + explicit SGD(std::vector params, SGDOptions defaults) + : SGD({OptimizerParamGroup(std::move(params))}, defaults) {} + + torch::Tensor step(LossClosure closure = nullptr) override; + + void save(serialize::OutputArchive& archive) const override; + void load(serialize::InputArchive& archive) override; + + private: + template + static void serialize(Self& self, Archive& archive) { + _TORCH_OPTIM_SERIALIZE_WITH_TEMPLATE_ARG(SGD); + } +}; +} // namespace optim +} // namespace torch diff --git a/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/serialize/archive.h b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/serialize/archive.h new file mode 100644 index 0000000000000000000000000000000000000000..d4ebe8e9d54cc127dd2df4ad1ccbcd226b037326 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/serialize/archive.h @@ -0,0 +1,4 @@ +#pragma once + +#include +#include diff --git a/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/serialize/input-archive.h b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/serialize/input-archive.h new file mode 100644 index 0000000000000000000000000000000000000000..3650cfcfea23f9b7ebac6afc726025df5953342e --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/serialize/input-archive.h @@ -0,0 +1,117 @@ +#pragma once + +#include +#include 
+#include +#include +#include + +#include +#include +#include +#include + +namespace at { +class Tensor; +} // namespace at + +namespace torch { +using at::Tensor; +namespace jit { +struct Module; +} // namespace jit +} // namespace torch + +namespace torch { +namespace serialize { + +/// A recursive representation of tensors that can be deserialized from a file +/// or stream. In most cases, users should not have to interact with this class, +/// and should instead use `torch::load`. +class TORCH_API InputArchive final { + public: + /// Default-constructs the `InputArchive`. + InputArchive(); + + // Move is allowed. + InputArchive(InputArchive&&) = default; + InputArchive& operator=(InputArchive&&) = default; + + // Copy is disallowed. + InputArchive(InputArchive&) = delete; + InputArchive& operator=(InputArchive&) = delete; + + ~InputArchive() = default; + + /// Reads an `IValue` associated with a given `key`. + void read(const std::string& key, c10::IValue& ivalue); + + /// Reads an `IValue` associated with a given `key`. If there is no `IValue` + /// associated with the `key`, this returns false, otherwise it returns true. + bool try_read(const std::string& key, c10::IValue& ivalue); + + /// Reads a `tensor` associated with a given `key`. If there is no `tensor` + /// associated with the `key`, this returns false, otherwise it returns true. + /// If the tensor is expected to be a buffer (not differentiable), `is_buffer` + /// must be `true`. + bool try_read(const std::string& key, Tensor& tensor, bool is_buffer = false); + + /// Reads a `tensor` associated with a given `key`. + /// If the tensor is expected to be a buffer (not differentiable), `is_buffer` + /// must be `true`. + void read(const std::string& key, Tensor& tensor, bool is_buffer = false); + + /// Reads a `InputArchive` associated with a given `key`. If there is no + /// `InputArchive` associated with the `key`, this returns false, otherwise + /// it returns true. 
+ bool try_read(const std::string& key, InputArchive& archive); + + /// Reads an `InputArchive` associated with a given `key`. + /// The archive can thereafter be used for further deserialization of the + /// nested data. + void read(const std::string& key, InputArchive& archive); + + /// Loads the `InputArchive` from a serialized representation stored in the + /// file at `filename`. Storage are remapped using device option. If device + /// is not specified, the module is loaded to the original device. + void load_from( + const std::string& filename, + std::optional device = std::nullopt); + + /// Loads the `InputArchive` from a serialized representation stored in the + /// given `stream`. Storage are remapped using device option. If device + /// is not specified, the module is loaded to the original device. + void load_from( + std::istream& stream, + std::optional device = std::nullopt); + + // Loads given the specified flat array. + void load_from( + const char* data, + size_t size, + std::optional device = std::nullopt); + + // Loads given the specified read and size functions. + void load_from( + const std::function& + read_func, + const std::function& size_func, + std::optional device = std::nullopt); + + // Returns the vector of keys in the input archive. + std::vector keys(); + + /// Forwards all arguments to `read()`. + /// Useful for generic code that can be re-used for both `InputArchive` and + /// `OutputArchive` (where `operator()` forwards to `write()`). + template + void operator()(Ts&&... 
ts) { + read(std::forward(ts)...); + } + + private: + jit::Module module_; + std::string hierarchy_prefix_; +}; +} // namespace serialize +} // namespace torch diff --git a/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/serialize/output-archive.h b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/serialize/output-archive.h new file mode 100644 index 0000000000000000000000000000000000000000..12e0f54971cb3912fb5a54334e2d4f6ac06d3022 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/serialize/output-archive.h @@ -0,0 +1,82 @@ +#pragma once + +#include +#include + +#include +#include +#include +#include + +namespace at { +class Tensor; +} // namespace at + +namespace torch { +using at::Tensor; +namespace jit { +struct Module; +} // namespace jit +} // namespace torch + +namespace torch { +namespace serialize { +class TORCH_API OutputArchive final { + public: + explicit OutputArchive(std::shared_ptr cu); + explicit OutputArchive() + : cu_(std::make_shared()), + module_("__torch__.Module", cu_) {} + + // Move is allowed. + OutputArchive(OutputArchive&&) = default; + OutputArchive& operator=(OutputArchive&&) = default; + + // Copy is disallowed. + OutputArchive(OutputArchive&) = delete; + OutputArchive& operator=(OutputArchive&) = delete; + + std::shared_ptr compilation_unit() const { + return cu_; + } + + /// Writes an `IValue` to the `OutputArchive`. + void write(const std::string& key, const c10::IValue& ivalue); + + /// Writes a `(key, tensor)` pair to the `OutputArchive`, and marks it as + /// being or not being a buffer (non-differentiable tensor). + void write( + const std::string& key, + const Tensor& tensor, + bool is_buffer = false); + + /// Writes a nested `OutputArchive` under the given `key` to this + /// `OutputArchive`. 
+ void write(const std::string& key, OutputArchive& nested_archive); + + /// Saves the `OutputArchive` into a serialized representation in a file at + /// `filename`. + void save_to(const std::string& filename); + + /// Saves the `OutputArchive` into a serialized representation into the given + /// `stream`. + void save_to(std::ostream& stream); + + /// Saves the `OutputArchive` into a serialized representation using the + /// given writer function. + void save_to(const std::function& func); + + /// Forwards all arguments to `write()`. + /// Useful for generic code that can be re-used for both `OutputArchive` and + /// `InputArchive` (where `operator()` forwards to `read()`). + template + void operator()(Ts&&... ts) { + write(std::forward(ts)...); + } + + private: + std::shared_ptr cu_; + jit::Module module_; +}; +} // namespace serialize +} // namespace torch diff --git a/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/serialize/tensor.h b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/serialize/tensor.h new file mode 100644 index 0000000000000000000000000000000000000000..9f77ed170db32a497c23feed05aae8c266ab282e --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/serialize/tensor.h @@ -0,0 +1,20 @@ +#pragma once + +#include +#include + +namespace torch { +inline serialize::OutputArchive& operator<<( + serialize::OutputArchive& archive, + const Tensor& tensor) { + archive.write("0", tensor); + return archive; +} + +inline serialize::InputArchive& operator>>( + serialize::InputArchive& archive, + Tensor& tensor) { + archive.read("0", tensor); + return archive; +} +} // namespace torch diff --git a/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/utils/pycfunction_helpers.h b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/utils/pycfunction_helpers.h new file mode 100644 index 
0000000000000000000000000000000000000000..745e1842e682c8a2fb3cc9d94e77122505016571 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/utils/pycfunction_helpers.h @@ -0,0 +1,13 @@ +#pragma once + +#include + +#include + +inline PyCFunction castPyCFunctionWithKeywords(PyCFunctionWithKeywords func) { + C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wcast-function-type") + C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wcast-function-type-strict") + return reinterpret_cast(func); + C10_DIAGNOSTIC_POP() + C10_DIAGNOSTIC_POP() +} diff --git a/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/utils/python_numbers.h b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/utils/python_numbers.h new file mode 100644 index 0000000000000000000000000000000000000000..d5b772b768e223e1af8705e118f7a56ccd87c39b --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/utils/python_numbers.h @@ -0,0 +1,205 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +// largest integer that can be represented consecutively in a double +const int64_t DOUBLE_INT_MAX = 9007199254740992; + +inline PyObject* THPUtils_packDeviceIndex(c10::DeviceIndex value) { + return PyLong_FromLong(value); +} + +inline PyObject* THPUtils_packInt32(int32_t value) { + return PyLong_FromLong(value); +} + +inline PyObject* THPUtils_packInt64(int64_t value) { + return PyLong_FromLongLong(value); +} + +inline PyObject* THPUtils_packUInt32(uint32_t value) { + return PyLong_FromUnsignedLong(value); +} + +inline PyObject* THPUtils_packUInt64(uint64_t value) { + return PyLong_FromUnsignedLongLong(value); +} + +inline PyObject* THPUtils_packDoubleAsInt(double value) { + return PyLong_FromDouble(value); +} + +inline bool THPUtils_checkLongExact(PyObject* obj) { + return PyLong_CheckExact(obj) && !PyBool_Check(obj); +} + +inline bool THPUtils_checkLong(PyObject* obj) { + // Fast path + if 
(THPUtils_checkLongExact(obj)) { + return true; + } + +#ifdef USE_NUMPY + if (torch::utils::is_numpy_int(obj)) { + return true; + } +#endif + + return PyLong_Check(obj) && !PyBool_Check(obj); +} + +inline int32_t THPUtils_unpackInt(PyObject* obj) { + int overflow = 0; + long value = PyLong_AsLongAndOverflow(obj, &overflow); + if (value == -1 && PyErr_Occurred()) { + throw python_error(); + } + if (overflow != 0) { + throw std::runtime_error("Overflow when unpacking long"); + } + if (value > std::numeric_limits::max() || + value < std::numeric_limits::min()) { + throw std::runtime_error("Overflow when unpacking long"); + } + return (int32_t)value; +} + +inline int64_t THPUtils_unpackLong(PyObject* obj) { + int overflow = 0; + long long value = PyLong_AsLongLongAndOverflow(obj, &overflow); + if (value == -1 && PyErr_Occurred()) { + throw python_error(); + } + if (overflow != 0) { + throw std::runtime_error("Overflow when unpacking long"); + } + return (int64_t)value; +} + +inline uint32_t THPUtils_unpackUInt32(PyObject* obj) { + unsigned long value = PyLong_AsUnsignedLong(obj); + if (PyErr_Occurred()) { + throw python_error(); + } + if (value > std::numeric_limits::max()) { + throw std::runtime_error("Overflow when unpacking unsigned long"); + } + return (uint32_t)value; +} + +inline uint64_t THPUtils_unpackUInt64(PyObject* obj) { + unsigned long long value = PyLong_AsUnsignedLongLong(obj); + if (PyErr_Occurred()) { + throw python_error(); + } + return (uint64_t)value; +} + +bool THPUtils_checkIndex(PyObject* obj); + +inline int64_t THPUtils_unpackIndex(PyObject* obj) { + if (!THPUtils_checkLong(obj)) { + auto index = THPObjectPtr(PyNumber_Index(obj)); + if (index == nullptr) { + throw python_error(); + } + // NB: This needs to be called before `index` goes out of scope and the + // underlying object's refcount is decremented + return THPUtils_unpackLong(index.get()); + } + return THPUtils_unpackLong(obj); +} + +inline bool THPUtils_unpackBool(PyObject* obj) { + if 
(obj == Py_True) { + return true; + } else if (obj == Py_False) { + return false; + } else { + throw std::runtime_error("couldn't convert python object to boolean"); + } +} + +inline bool THPUtils_checkBool(PyObject* obj) { +#ifdef USE_NUMPY + if (torch::utils::is_numpy_bool(obj)) { + return true; + } +#endif + return PyBool_Check(obj); +} + +inline bool THPUtils_checkDouble(PyObject* obj) { +#ifdef USE_NUMPY + if (torch::utils::is_numpy_scalar(obj)) { + return true; + } +#endif + return PyFloat_Check(obj) || PyLong_Check(obj); +} + +inline double THPUtils_unpackDouble(PyObject* obj) { + if (PyFloat_Check(obj)) { + return PyFloat_AS_DOUBLE(obj); + } + double value = PyFloat_AsDouble(obj); + if (value == -1 && PyErr_Occurred()) { + throw python_error(); + } + return value; +} + +inline c10::complex THPUtils_unpackComplexDouble(PyObject* obj) { + Py_complex value = PyComplex_AsCComplex(obj); + if (value.real == -1.0 && PyErr_Occurred()) { + throw python_error(); + } + + return c10::complex(value.real, value.imag); +} + +inline bool THPUtils_unpackNumberAsBool(PyObject* obj) { + if (PyFloat_Check(obj)) { + return (bool)PyFloat_AS_DOUBLE(obj); + } + + if (PyComplex_Check(obj)) { + double real_val = PyComplex_RealAsDouble(obj); + double imag_val = PyComplex_ImagAsDouble(obj); + return !(real_val == 0 && imag_val == 0); + } + + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + int overflow; + long long value = PyLong_AsLongLongAndOverflow(obj, &overflow); + if (value == -1 && PyErr_Occurred()) { + throw python_error(); + } + // No need to check overflow, because when overflow occured, it should + // return true in order to keep the same behavior of numpy. 
+ return (bool)value; +} + +inline c10::DeviceIndex THPUtils_unpackDeviceIndex(PyObject* obj) { + int overflow = 0; + long value = PyLong_AsLongAndOverflow(obj, &overflow); + if (value == -1 && PyErr_Occurred()) { + throw python_error(); + } + if (overflow != 0) { + throw std::runtime_error("Overflow when unpacking DeviceIndex"); + } + if (value > std::numeric_limits::max() || + value < std::numeric_limits::min()) { + throw std::runtime_error("Overflow when unpacking DeviceIndex"); + } + return (c10::DeviceIndex)value; +} diff --git a/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/utils/python_strings.h b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/utils/python_strings.h new file mode 100644 index 0000000000000000000000000000000000000000..cca161399c447037c05c1d306e195912cde82c56 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/utils/python_strings.h @@ -0,0 +1,128 @@ +#pragma once + +#include +#include +#include +#include +#include + +// Utilities for handling Python strings. Note that PyString, when defined, is +// the same as PyBytes. + +// Returns true if obj is a bytes/str or unicode object +// As of Python 3.6, this does not require the GIL +inline bool THPUtils_checkString(PyObject* obj) { + return PyBytes_Check(obj) || PyUnicode_Check(obj); +} + +// Unpacks PyBytes (PyString) or PyUnicode as std::string +// PyBytes are unpacked as-is. PyUnicode is unpacked as UTF-8. 
+// NOTE: this method requires the GIL +inline std::string THPUtils_unpackString(PyObject* obj) { + if (PyBytes_Check(obj)) { + size_t size = PyBytes_GET_SIZE(obj); + return std::string(PyBytes_AS_STRING(obj), size); + } + if (PyUnicode_Check(obj)) { + Py_ssize_t size = 0; + const char* data = PyUnicode_AsUTF8AndSize(obj, &size); + if (!data) { + throw std::runtime_error("error unpacking string as utf-8"); + } + return std::string(data, (size_t)size); + } + throw std::runtime_error("unpackString: expected bytes or unicode object"); +} + +// Unpacks PyBytes (PyString) or PyUnicode as c10::string_view +// PyBytes are unpacked as-is. PyUnicode is unpacked as UTF-8. +// NOTE: If `obj` is destroyed, then the non-owning c10::string_view will +// become invalid. If the string needs to be accessed at any point after +// `obj` is destroyed, then the c10::string_view should be copied into +// a std::string, or another owning object, and kept alive. For an example, +// look at how IValue and autograd nodes handle c10::string_view arguments. 
+// NOTE: this method requires the GIL +inline c10::string_view THPUtils_unpackStringView(PyObject* obj) { + if (PyBytes_Check(obj)) { + size_t size = PyBytes_GET_SIZE(obj); + return c10::string_view(PyBytes_AS_STRING(obj), size); + } + if (PyUnicode_Check(obj)) { + Py_ssize_t size = 0; + const char* data = PyUnicode_AsUTF8AndSize(obj, &size); + if (!data) { + throw std::runtime_error("error unpacking string as utf-8"); + } + return c10::string_view(data, (size_t)size); + } + throw std::runtime_error("unpackString: expected bytes or unicode object"); +} + +inline PyObject* THPUtils_packString(const char* str) { + return PyUnicode_FromString(str); +} + +inline PyObject* THPUtils_packString(const std::string& str) { + return PyUnicode_FromStringAndSize(str.c_str(), str.size()); +} + +inline PyObject* THPUtils_internString(const std::string& str) { + return PyUnicode_InternFromString(str.c_str()); +} + +// Precondition: THPUtils_checkString(obj) must be true +inline bool THPUtils_isInterned(PyObject* obj) { + return PyUnicode_CHECK_INTERNED(obj); +} + +// Precondition: THPUtils_checkString(obj) must be true +inline void THPUtils_internStringInPlace(PyObject** obj) { + PyUnicode_InternInPlace(obj); +} + +/* + * Reference: + * https://github.com/numpy/numpy/blob/f4c497c768e0646df740b647782df463825bfd27/numpy/core/src/common/get_attr_string.h#L42 + * + * Stripped down version of PyObject_GetAttrString, + * avoids lookups for None, tuple, and List objects, + * and doesn't create a PyErr since this code ignores it. + * + * This can be much faster then PyObject_GetAttrString where + * exceptions are not used by caller. + * + * 'obj' is the object to search for attribute. + * + * 'name' is the attribute to search for. + * + * Returns a py::object wrapping the return value. If the attribute lookup + * failed the value will be NULL. 
+ * + */ + +inline py::object PyObject_FastGetAttrString(PyObject* obj, const char* name) { + PyTypeObject* tp = Py_TYPE(obj); + PyObject* res = (PyObject*)nullptr; + + /* Attribute referenced by (char *)name */ + if (tp->tp_getattr != nullptr) { + // This is OK per https://bugs.python.org/issue39620 + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) + res = (*tp->tp_getattr)(obj, const_cast(name)); + if (res == nullptr) { + PyErr_Clear(); + } + } + /* Attribute referenced by (PyObject *)name */ + else if (tp->tp_getattro != nullptr) { + auto w = py::reinterpret_steal(THPUtils_internString(name)); + if (w.ptr() == nullptr) { + return py::object(); + } + res = (*tp->tp_getattro)(obj, w.ptr()); + if (res == nullptr) { + PyErr_Clear(); + } + } + return py::reinterpret_steal(res); +} diff --git a/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/utils/tensor_new.h b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/utils/tensor_new.h new file mode 100644 index 0000000000000000000000000000000000000000..088f8d1927c4732d8543ca82a39c08247257066a --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/torch/csrc/utils/tensor_new.h @@ -0,0 +1,136 @@ +#pragma once + +#include +#include + +#include + +namespace torch::utils { + +// NOTE: [torch.tensor, lift_fresh, and device movement] +// +// The `only_lift_cpu_tensors` flag controls what happens on torch.tensor([1, 2, +// 3], device="cuda") (or any non-CPU devices). +// +// If false (default): +// - the data gets moved into a CPU Tensor +// - then, it gets moved to cuda (via .to) +// - finally, we call lift_fresh() on it. +// Steps 1 and 2 happen with all modes disabled. +// +// If true: +// - the data gets moved into a CPU Tensor (with correct dtype) +// - we call lift_fresh() on it +// - finally, we move it to cuda (via .to) +// Step 1 happens with all modes disabled. 
+// +// `only_lift_cpu_tensors=true` is useful to prevent CUDA initialization under +// FakeTensorMode because it avoids moving concrete data to CUDA. +TORCH_API bool only_lift_cpu_tensors(); +TORCH_API void set_only_lift_cpu_tensors(bool value); + +at::Tensor base_tensor_ctor(PyObject* args, PyObject* kwargs); +at::Tensor legacy_tensor_ctor( + c10::DispatchKey dispatch_key, + at::ScalarType scalar_type, + PyObject* args, + PyObject* kwargs); +at::Tensor legacy_tensor_new( + c10::DispatchKey dispatch_key, + at::ScalarType scalar_type, + PyObject* args, + PyObject* kwargs); +at::Tensor indexing_tensor_from_data( + c10::TensorOptions options, + at::ScalarType scalar_type, + std::optional device, + PyObject* data); +at::Tensor sparse_coo_tensor_ctor( + c10::DispatchKey dispatch_key, + at::ScalarType scalar_type, + PythonArgs& r); +void _validate_sparse_coo_tensor_args( + c10::DispatchKey dispatch_key, + at::ScalarType scalar_type, + PyObject* args, + PyObject* kwargs); + +at::Tensor sparse_compressed_tensor_ctor( + c10::DispatchKey dispatch_key, + at::ScalarType scalar_type, + PythonArgs& r); +at::Tensor sparse_csr_tensor_ctor( + c10::DispatchKey dispatch_key, + at::ScalarType scalar_type, + PythonArgs& r); +at::Tensor sparse_csc_tensor_ctor( + c10::DispatchKey dispatch_key, + at::ScalarType scalar_type, + PythonArgs& r); +at::Tensor sparse_bsr_tensor_ctor( + c10::DispatchKey dispatch_key, + at::ScalarType scalar_type, + PythonArgs& r); +at::Tensor sparse_bsc_tensor_ctor( + c10::DispatchKey dispatch_key, + at::ScalarType scalar_type, + PythonArgs& r); + +void _validate_sparse_compressed_tensor_args( + c10::DispatchKey dispatch_key, + at::ScalarType scalar_type, + PyObject* args, + PyObject* kwargs); +void _validate_sparse_csr_tensor_args( + c10::DispatchKey dispatch_key, + at::ScalarType scalar_type, + PyObject* args, + PyObject* kwargs); +void _validate_sparse_csc_tensor_args( + c10::DispatchKey dispatch_key, + at::ScalarType scalar_type, + PyObject* args, + 
PyObject* kwargs); +void _validate_sparse_bsr_tensor_args( + c10::DispatchKey dispatch_key, + at::ScalarType scalar_type, + PyObject* args, + PyObject* kwargs); +void _validate_sparse_bsc_tensor_args( + c10::DispatchKey dispatch_key, + at::ScalarType scalar_type, + PyObject* args, + PyObject* kwargs); + +at::Tensor tensor_ctor( + c10::DispatchKey dispatch_key, + at::ScalarType scalar_type, + PythonArgs& r); +at::Tensor as_tensor( + c10::DispatchKey dispatch_key, + at::ScalarType scalar_type, + PythonArgs& r); +at::Tensor new_tensor( + c10::DispatchKey dispatch_key, + at::ScalarType scalar_type, + PyObject* args, + PyObject* kwargs); +at::Tensor new_ones( + c10::DispatchKey dispatch_key, + at::ScalarType scalar_type, + PyObject* args, + PyObject* kwargs); +at::Tensor tensor_frombuffer( + PyObject* buffer, + at::ScalarType dtype, + int64_t count, + int64_t offset, + bool requires_grad); +at::Tensor tensor_fromDLPack(PyObject* data); +at::Tensor asarray( + PyObject* obj, + std::optional dtype, + std::optional device, + std::optional copy, + bool requires_grad); +} // namespace torch::utils