#include "wayy_db/ops/joins.hpp"

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <functional>
#include <limits>
#include <string>
#include <unordered_map>
#include <vector>

namespace wayy_db::ops {

namespace {

// Join key of one row: one 64-bit value per column listed in `on`.
using JoinKey = std::vector<int64_t>;

// Marker stored in a right-row index list when a left row has no match.
constexpr size_t kNoMatch = static_cast<size_t>(-1);

// Hash combine for multi-key joins
struct KeyHash {
    size_t operator()(const JoinKey& key) const {
        size_t hash = 0;
        for (auto val : key) {
            hash ^= std::hash<int64_t>{}(val) + 0x9e3779b9 + (hash << 6) + (hash >> 2);
        }
        return hash;
    }
};

// Row indices of a table, bucketed by join key. Rows are appended in
// ascending index order, so each bucket is ascending by row index.
using RowGroups = std::unordered_map<JoinKey, std::vector<size_t>, KeyHash>;

// Extract join key values from a row.
// Throws InvalidOperation when a key column is not Int64, Timestamp or Symbol.
JoinKey extract_key(const Table& table, const std::vector<std::string>& on, size_t row) {
    JoinKey key;
    key.reserve(on.size());

    for (const auto& col_name : on) {
        const Column& col = table.column(col_name);
        switch (col.dtype()) {
            case DType::Int64:
            case DType::Timestamp:
                // The typed accessors are called through a non-const reference,
                // as in the original code; the column is only read here.
                key.push_back(const_cast<Column&>(col).as_int64()[row]);
                break;
            case DType::Symbol:
                key.push_back(const_cast<Column&>(col).as_symbol()[row]);
                break;
            default:
                throw InvalidOperation("Join key column must be Int64, Timestamp, or Symbol");
        }
    }
    return key;
}

// Group row indices by key values
RowGroups group_by_key(const Table& table, const std::vector<std::string>& on) {
    RowGroups groups;
    for (size_t i = 0; i < table.num_rows(); ++i) {
        groups[extract_key(table, on, i)].push_back(i);
    }
    return groups;
}

// Throws InvalidOperation unless `table` reports being sorted by `as_of`.
// `side` ("Left" / "Right") is only used to build the error message.
void require_sorted_by(const Table& table, const std::string& as_of, const char* side) {
    if (!table.is_sorted() || table.sorted_by() != as_of) {
        throw InvalidOperation(std::string(side) + " table must be sorted by " + as_of);
    }
}

// a + b, clamped to the int64_t range instead of overflowing (signed overflow is UB).
int64_t saturating_add(int64_t a, int64_t b) {
    constexpr int64_t kMax = std::numeric_limits<int64_t>::max();
    constexpr int64_t kMin = std::numeric_limits<int64_t>::min();
    if (b > 0 && a > kMax - b) return kMax;
    if (b < 0 && a < kMin - b) return kMin;
    return a + b;
}

// a - b, clamped to the int64_t range instead of overflowing.
int64_t saturating_sub(int64_t a, int64_t b) {
    constexpr int64_t kMax = std::numeric_limits<int64_t>::max();
    constexpr int64_t kMin = std::numeric_limits<int64_t>::min();
    if (b < 0 && a > kMax + b) return kMax;
    if (b > 0 && a < kMin + b) return kMin;
    return a - b;
}

// Appends to `result` a column named `name` with the dtype of `src`, whose
// i-th element is a byte copy of element rows[i] of `src`. Entries equal to
// kNoMatch are left as zero bytes (null representation).
void append_gathered_column(Table& result, const std::string& name, const Column& src,
                            const std::vector<size_t>& rows) {
    const size_t elem_size = dtype_size(src.dtype());
    std::vector<uint8_t> data(rows.size() * elem_size, 0);
    const uint8_t* src_data = static_cast<const uint8_t*>(src.data());

    for (size_t i = 0; i < rows.size(); ++i) {
        if (rows[i] == kNoMatch) continue;
        std::memcpy(data.data() + i * elem_size, src_data + rows[i] * elem_size, elem_size);
    }
    result.add_column(Column(name, src.dtype(), std::move(data)));
}

// Fills `result` with every left column followed by the right columns that
// are neither already present in `result` nor listed in `on`. Output row i is
// built from left row left_rows[i] and right row right_rows[i].
void append_joined_columns(Table& result, const Table& left, const Table& right,
                           const std::vector<std::string>& on,
                           const std::vector<size_t>& left_rows,
                           const std::vector<size_t>& right_rows) {
    // Add left columns
    for (const auto& col_name : left.column_names()) {
        append_gathered_column(result, col_name, left.column(col_name), left_rows);
    }

    // Add right columns (excluding join keys and anything the left side provided,
    // which includes the as_of column)
    for (const auto& col_name : right.column_names()) {
        if (result.has_column(col_name)) continue;
        if (std::find(on.begin(), on.end(), col_name) != on.end()) continue;
        append_gathered_column(result, col_name, right.column(col_name), right_rows);
    }
}

}  // namespace

// As-of join.
//
// Produces one output row per left row. Each left row is paired with the right
// row that has the same `on` key values and the largest `as_of` value that is
// <= the left row's `as_of` value; when several right rows share that value the
// one with the highest row index is used. Left rows without such a right row
// are kept, with the right-hand columns zero-filled.
//
// Throws InvalidOperation if either table is not sorted by `as_of`, or if a
// key column has an unsupported dtype.
Table aj(const Table& left, const Table& right, const std::vector<std::string>& on,
         const std::string& as_of) {
    // Validate inputs
    require_sorted_by(left, as_of, "Left");
    require_sorted_by(right, as_of, "Right");

    // Group right table by join keys
    const auto right_groups = group_by_key(right, on);

    // Get timestamp columns
    auto left_ts = const_cast<Table&>(left).column(as_of).as_int64();
    auto right_ts = const_cast<Table&>(right).column(as_of).as_int64();

    // Result builders - collect matching indices
    std::vector<size_t> left_indices;
    std::vector<size_t> right_indices;  // kNoMatch means no match
    left_indices.reserve(left.num_rows());
    right_indices.reserve(left.num_rows());

    // For each left row, find the most recent right row
    for (size_t i = 0; i < left.num_rows(); ++i) {
        const auto key = extract_key(left, on, i);
        const int64_t ts = left_ts[i];
        size_t match = kNoMatch;

        const auto group_it = right_groups.find(key);
        if (group_it != right_groups.end()) {
            const auto& group = group_it->second;

            // Binary search for largest timestamp <= ts. The group is ascending
            // by row index and the right table is sorted by as_of, so the group
            // is ascending by timestamp as well.
            const auto it = std::upper_bound(
                group.begin(), group.end(), ts,
                [&right_ts](int64_t t, size_t idx) { return t < right_ts[idx]; });
            if (it != group.begin()) {
                match = *(it - 1);
            }
        }

        left_indices.push_back(i);
        right_indices.push_back(match);
    }

    // Build result table
    Table result("aj_result");
    append_joined_columns(result, left, right, on, left_indices, right_indices);
    result.set_sorted_by(as_of);
    return result;
}

// Window join.
//
// For each left row with as_of value ts, emits one output row per right row
// that has the same `on` key values and an as_of value inside the closed range
// [ts - window_before, ts + window_after]. Left rows with no right row in the
// window produce no output. The range bounds saturate at the int64_t limits,
// and an inverted range (lower bound above upper bound) matches nothing.
//
// Throws InvalidOperation if either table is not sorted by `as_of`, or if a
// key column has an unsupported dtype.
Table wj(const Table& left, const Table& right, const std::vector<std::string>& on,
         const std::string& as_of, int64_t window_before, int64_t window_after) {
    // Validate inputs
    require_sorted_by(left, as_of, "Left");
    require_sorted_by(right, as_of, "Right");

    // Group right table by join keys
    const auto right_groups = group_by_key(right, on);

    // Get timestamp columns
    auto left_ts = const_cast<Table&>(left).column(as_of).as_int64();
    auto right_ts = const_cast<Table&>(right).column(as_of).as_int64();

    // Result builders
    std::vector<size_t> left_indices;
    std::vector<size_t> right_indices;

    // For each left row, find all right rows in window
    for (size_t i = 0; i < left.num_rows(); ++i) {
        const auto key = extract_key(left, on, i);
        const int64_t ts = left_ts[i];
        const int64_t ts_min = saturating_sub(ts, window_before);
        const int64_t ts_max = saturating_add(ts, window_after);

        // An inverted window would put `lower` after `upper` and make the
        // iteration below run past the end of the group.
        if (ts_min > ts_max) continue;

        const auto group_it = right_groups.find(key);
        if (group_it == right_groups.end()) {
            continue;  // No matching key
        }
        const auto& group = group_it->second;

        // Find range [ts_min, ts_max]
        const auto lower = std::lower_bound(
            group.begin(), group.end(), ts_min,
            [&right_ts](size_t idx, int64_t t) { return right_ts[idx] < t; });
        const auto upper = std::upper_bound(
            group.begin(), group.end(), ts_max,
            [&right_ts](int64_t t, size_t idx) { return t < right_ts[idx]; });

        for (auto it = lower; it != upper; ++it) {
            left_indices.push_back(i);
            right_indices.push_back(*it);
        }
    }

    // Build result table
    Table result("wj_result");
    append_joined_columns(result, left, right, on, left_indices, right_indices);

    if (!result.column_names().empty()) {
        result.set_sorted_by(as_of);
    }
    return result;
}

// Inner join on `on`. Not implemented: always throws InvalidOperation.
Table inner_join(const Table& /*left*/, const Table& /*right*/,
                 const std::vector<std::string>& /*on*/) {
    // TODO: Implement inner join
    throw InvalidOperation("inner_join not yet implemented");
}

// Left join on `on`. Not implemented: always throws InvalidOperation.
Table left_join(const Table& /*left*/, const Table& /*right*/,
                const std::vector<std::string>& /*on*/) {
    // TODO: Implement left join
    throw InvalidOperation("left_join not yet implemented");
}

}  // namespace wayy_db::ops