// Source: stable-diffusion-2-1-ait / AutoencoderKL / constant_folder-generated.h
// Last commit 17db41a ("add clip") by yuvalkirstain.
#pragma once
#include "logging.h"
#include "device_functions-generated.h"
#include "model_interface.h"
#include "raii_wrapper.h"
#include "model.h"
#include "macros.h"
#include <algorithm>
#include <deque>
#include <fstream>
#include <iostream>
#include <string>
#include <unordered_map>
#include <math.h>
namespace ait {
// ConstantFolder is the generated constant-folding graph for this model.
// It derives from ModelBase<ConstantFolder> (CRTP); presumably ModelBase
// dispatches to the *Impl / SetUpInputsOutputs / ResetConstants hooks below
// — confirm against model.h.
//
// This generated instance contains no folding work:
//   * Create() sizes every buffer and count as 0,
//   * the set-up, constant-reset and copy hooks are empty,
//   * so RunImpl() performs no operations.
class ConstantFolder : public ModelBase<ConstantFolder> {
 public:
  // Forwards every size and the constants pointer unchanged to ModelBase.
  //
  // @param blob_size              forwarded to ModelBase.
  // @param workspace_size         forwarded to ModelBase.
  // @param unique_workspace_size  forwarded to ModelBase.
  // @param num_inputs             forwarded to ModelBase.
  // @param num_outputs            forwarded to ModelBase.
  // @param num_unbound_constants  forwarded to ModelBase.
  // @param constants              forwarded to ModelBase.
  // @param allocator              forwarded to ModelBase.
  ConstantFolder(
      size_t blob_size,
      size_t workspace_size,
      size_t unique_workspace_size,
      size_t num_inputs,
      size_t num_outputs,
      size_t num_unbound_constants,
      uint8_t* constants,
      AITemplateAllocator& allocator)
      : ModelBase(
            blob_size,
            workspace_size,
            unique_workspace_size,
            num_inputs,
            num_outputs,
            num_unbound_constants,
            constants,
            allocator) {
    // This graph lays out nothing inside blob_, so the body is empty.
  }

  // Binds graph inputs/outputs. No-op: this graph has none to bind.
  void SetUpInputsOutputs() {}

  // Can be called to make the folder consume constants from a different
  // piece of memory. No-op in this generated graph.
  //
  // @param constants  replacement constants buffer (unused here).
  void ResetConstants(uint8_t* constants) {}

  // Device-to-device copies for this graph. There are none, so this is empty.
  //
  // @param stream  stream the copies would be issued on (unused here).
  void DeviceToDeviceCopies(StreamType stream) {}

  // Runs the graph on `stream`. The graph has no ops, so the only step is
  // the (empty) device-to-device copy pass.
  void RunImpl(StreamType stream) {
    DeviceToDeviceCopies(stream);
  }

  // Writes an (empty) per-op profiling report as a JSON object to
  // `filename`, runs the copy pass, and prints a completion message to
  // stdout.
  //
  // @param stream    stream used for the copy pass.
  // @param iters     number of profiling iterations (unused: there are no ops).
  // @param filename  path of the JSON report to create/overwrite.
  // @throws std::runtime_error if `filename` cannot be opened for writing.
  void ProfileImpl(StreamType stream, size_t iters, const std::string& filename) {
    std::ofstream ss(filename);
    if (!ss) {
      throw std::runtime_error(std::string("Could not open file ") + filename);
    }
    // No ops to report, so the report is an empty JSON object.
    ss << "{\n";
    ss << "}\n";
    DeviceToDeviceCopies(stream);
    std::cout << "AIT per op profiling finished." << std::endl;
  }

  // Factory for this graph. Every size and count is 0 because the graph
  // folds nothing.
  //
  // @param allocator  allocator forwarded to ModelBase.
  // @param constants  constants buffer forwarded to ModelBase.
  // @return a newly allocated ConstantFolder owned by the caller.
  static std::unique_ptr<ConstantFolder> Create(
      AITemplateAllocator& allocator,
      uint8_t* constants) {
    return std::make_unique<ConstantFolder>(
        /*blob_size=*/0,
        /*workspace_size=*/0,
        /*unique_workspace_size=*/0,
        /*num_inputs=*/0,
        /*num_outputs=*/0,
        /*num_unbound_constants=*/0,
        constants,
        allocator);
  }
};
} // namespace ait