multimodal-hw / src /utils.py
AlekMan's picture
Upload 41 files
61d3625 verified
raw
history blame contribute delete
821 Bytes
"""Utility helpers for reproducibility and tensor handling."""
from __future__ import annotations
import random
from contextlib import contextmanager
from typing import Iterable, Iterator, Sequence
import numpy as np
import torch
def set_seed(seed: int) -> None:
    """Seed every random number generator this module knows about.

    Applies ``seed`` to Python's ``random`` module, NumPy's global RNG and
    PyTorch's default generator, in that order. When CUDA is available the
    seed is also applied to all CUDA devices.

    Args:
        seed: Integer seed handed unchanged to each library's seeding call.
    """
    # Same value goes to each library's global seeding entry point.
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    # Nothing more to do on machines without CUDA.
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)
@contextmanager
def torch_no_grad() -> Iterator[None]:
    """Context manager that runs its body inside ``torch.no_grad()``.

    Yields:
        ``None``; the ``with`` body executes while the ``torch.no_grad()``
        context is active, and that context is exited when the body finishes
        or raises.
    """
    grad_off = torch.no_grad()
    with grad_off:
        yield None
def batched(iterable: Sequence | Iterable, batch_size: int) -> Iterator[list]:
    """Yield successive lists of up to ``batch_size`` items from ``iterable``.

    Items are consumed lazily and in order. Every batch except possibly the
    last contains exactly ``batch_size`` items; a trailing partial batch is
    yielded if any items remain. An empty ``iterable`` yields nothing.

    Args:
        iterable: Any iterable (sequence, generator, ...) to split up.
        batch_size: Maximum number of items per batch; must be at least 1.

    Yields:
        A fresh ``list`` for each batch.

    Raises:
        ValueError: If ``batch_size`` is less than 1. Because this is a
            generator, the error surfaces when iteration starts, not when
            ``batched`` is called.
    """
    # A non-positive size previously degraded silently into one-item batches
    # (``len(batch) >= batch_size`` was true after every append); reject it.
    if batch_size < 1:
        raise ValueError(f"batch_size must be at least 1, got {batch_size!r}")

    batch: list = []
    for item in iterable:
        batch.append(item)
        if len(batch) >= batch_size:
            yield batch
            # Start a new list so the batch already handed out is not mutated.
            batch = []
    # Flush the trailing partial batch, if any.
    if batch:
        yield batch