File size: 5,386 Bytes
7daf628 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 | """Misc utils."""
import os
from shared.utils.log import tqdm_iterator
import numpy as np
from termcolor import colored
class AttrDict(dict):
    """A ``dict`` whose items can also be read and written as attributes.

    The instance's attribute namespace is the dictionary itself, so
    ``d.key`` and ``d["key"]`` refer to the same storage.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Alias the attribute dict to ourselves: attribute access and item
        # access now share one mapping (this creates a self-reference cycle).
        self.__dict__ = self
class DictToObj:
    """Attribute-style view of a (possibly nested) dictionary.

    Every key becomes an attribute; values that are themselves ``dict``
    instances are converted recursively, so ``obj.a.b`` works for
    ``{"a": {"b": ...}}``. Non-dict values (including lists) are stored as-is.
    """

    def __init__(self, dictionary):
        for name, item in dictionary.items():
            converted = DictToObj(item) if isinstance(item, dict) else item
            setattr(self, name, converted)

    def __repr__(self):
        # Nested DictToObj values render through this same __repr__.
        return f"{vars(self)}"
def ignore_warnings(type="ignore"):
    """Install a process-wide warnings filter and disable tokenizers parallelism.

    Args:
        type: Action string passed to ``warnings.filterwarnings`` (e.g.
            ``"ignore"``, ``"default"``, ``"error"``). The name shadows the
            builtin only inside this function and is kept for callers.

    Side effects:
        Sets ``os.environ["TOKENIZERS_PARALLELISM"] = "false"`` (presumably
        to silence HuggingFace ``tokenizers`` fork warnings -- confirm).
    """
    import warnings

    action = type
    warnings.filterwarnings(action)
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
def download_youtube_video(youtube_id, ext='mp4', resolution="360p", save_dir="/tmp", **kwargs):
    """Download a progressive YouTube stream with ``pytube``.

    Args:
        youtube_id (str): The video id (the ``v=`` value of a watch URL).
        ext (str): Container extension used to filter streams; also used as
            the extension of the saved file.
        resolution (str): Resolution used to filter streams, e.g. ``"360p"``.
        save_dir (str): Directory the file is written to (default ``"/tmp"``).
        **kwargs: Extra filters forwarded to ``yt.streams.filter``.

    Returns:
        str | None: Path of the downloaded file, ``<save_dir>/sample.<ext>``,
        or ``None`` if no matching stream was found or the download failed.
        The same filename is reused on every call, so earlier downloads in
        ``save_dir`` are overwritten.
    """
    import pytube

    video_url = f"https://www.youtube.com/watch?v={youtube_id}"
    yt = pytube.YouTube(video_url)
    # Name the file after the requested extension (previously always ".mp4").
    filename = f"sample.{ext}"
    try:
        streams = yt.streams.filter(
            file_extension=ext, res=resolution, progressive=True, **kwargs,
        )
        # An empty result raises IndexError here and is reported below.
        streams[0].download(output_path=save_dir, filename=filename)
    except Exception as err:  # narrow from bare except: let Ctrl-C / SystemExit through
        print("Failed to download video: ", video_url, f"({err!r})")
        return None
    return os.path.join(save_dir, filename)
def check_audio(video_path):
    """Return True if ``video_path`` opens in moviepy and has an audio track.

    Any error while opening or probing the file (missing, unreadable or
    corrupt video) is treated as "no audio" and yields ``False``.
    """
    from moviepy.video.io.VideoFileClip import VideoFileClip

    try:
        # Use the clip as a context manager so its readers are closed; the
        # original left every clip open, leaking resources across many calls
        # (e.g. from check_audio_multiple).
        with VideoFileClip(video_path) as clip:
            return clip.audio is not None
    except Exception:  # was a bare except, which also swallowed KeyboardInterrupt
        return False
def check_audio_multiple(video_paths, n_jobs=8):
    """Run ``check_audio`` over ``video_paths`` in parallel with joblib.

    Args:
        video_paths: Iterable of video file paths.
        n_jobs (int): Number of joblib workers.

    Returns:
        One ``check_audio`` result (bool) per path, as collected by
        ``joblib.Parallel``.
    """
    # tqdm_iterator is a project helper, presumably a tqdm progress wrapper.
    # NOTE(review): the bar advances as joblib pulls tasks from the generator,
    # which may run ahead of actual completion.
    progress = tqdm_iterator(video_paths, desc="Checking audio")
    from joblib import Parallel, delayed

    pool = Parallel(n_jobs=n_jobs)
    return pool(delayed(check_audio)(path) for path in progress)
def num_trainable_params(model, round=3, verbose=True, return_count=False):
    """Count parameters of ``model`` with ``requires_grad`` set, optionally printing.

    Args:
        model: Object exposing ``parameters()`` (e.g. a ``torch.nn.Module``).
        round (int | None): If not None, print the count in millions rounded
            to this many decimals with an ``M`` suffix; if None, print the raw
            count. (Name shadows the builtin; kept for callers.)
        verbose (bool): Whether to print the summary line.
        return_count (bool): If True, return the exact integer count.

    Returns:
        int | None: The trainable parameter count when ``return_count`` is
        True, otherwise None.
    """
    n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    label = model.__class__.__name__

    if round is None:
        shown, suffix = n_params, ""
    else:
        shown, suffix = np.round(n_params / 1e6, round), "M"

    if verbose:
        print(f"::: Number of trainable parameters in {label}: {shown} {suffix}")
    return n_params if return_count else None
def num_params(model, round=3, verbose=True, return_count=False):
    """Count all parameters of ``model`` (trainable or not), optionally printing.

    Mirrors ``num_trainable_params``: the new ``verbose``/``return_count``
    keywords default to the old behaviour (print, return None), so existing
    callers are unaffected.

    Args:
        model: Object exposing ``parameters()`` (e.g. a ``torch.nn.Module``).
        round (int | None): If not None, print the count in millions rounded
            to this many decimals with an ``M`` suffix; if None, print the raw
            count. (Name shadows the builtin; kept for callers.)
        verbose (bool): Whether to print the summary line.
        return_count (bool): If True, return the exact integer count.

    Returns:
        int | None: The total parameter count when ``return_count`` is True,
        otherwise None.
    """
    n_params = sum(p.numel() for p in model.parameters())
    model_name = model.__class__.__name__
    if round is not None:
        value, unit = np.round(n_params / 1e6, round), "M"
    else:
        value, unit = n_params, ""
    if verbose:
        print(f"::: Number of total parameters in {model_name}: {value}{unit}")
    if return_count:
        return n_params
    return None
def fix_seed(seed=42):
    """Seed Python ``random``, NumPy's global RNG and PyTorch (CPU + all CUDA devices).

    Also configures cuDNN for reproducibility: deterministic algorithms on and
    benchmark-mode autotuning off. With ``benchmark`` enabled, cuDNN may pick
    different (possibly non-deterministic) algorithms between runs, which
    undermines ``deterministic = True``.

    Args:
        seed (int): Seed applied to every generator.
    """
    import random
    import torch

    random.seed(seed)
    np.random.seed(seed)  # module-level numpy import; the local re-import was redundant
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # safe without CUDA: seeding is deferred
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
def check_tensor(x):
    """Debug helper: print ``x.shape``, ``x.min()`` and ``x.max()`` on one line.

    Works for any object exposing those members (e.g. NumPy arrays or torch
    tensors).
    """
    stats = (x.shape, x.min(), x.max())
    print(*stats)
import hashlib
def encode_string(input_string, num_chars=4):
    """Return a short, deterministic hex code derived from ``input_string``.

    The string is UTF-8 encoded and hashed with SHA-256. The first ``num_chars``
    characters of the hexadecimal digest are returned.

    Args:
        input_string (str): The string to encode.
        num_chars (int): Number of leading hex characters to keep (default 4).
            A SHA-256 hex digest is 64 characters long, so larger values
            return the full digest.

    Returns:
        str: A code of ``min(num_chars, 64)`` lowercase hex characters (for
        non-negative ``num_chars``).

    Note:
        Short codes collide easily (4 hex chars give only 16**4 = 65536
        values), so use them as labels, not unique identifiers.
    """
    hex_digest = hashlib.sha256(input_string.encode("utf-8")).hexdigest()
    return hex_digest[:num_chars]
def flatten_list_of_lists(xss):
    """Concatenate the inner iterables of ``xss`` into one flat list (one level deep)."""
    flat = []
    for inner in xss:
        flat.extend(inner)
    return flat
import textwrap
def get_terminal_width():
    """Return the terminal width in columns, as reported by ``shutil.get_terminal_size``.

    ``shutil`` consults the ``COLUMNS`` environment variable and the attached
    terminal, and falls back to its default (80 columns) when neither is
    available.
    """
    from shutil import get_terminal_size

    size = get_terminal_size()
    return size.columns
def wrap_text(text: str, max_length: int = 100) -> str:
    """Wrap ``text`` so that no line exceeds the effective width, for easier printing.

    The effective width is ``min(max_length, get_terminal_width())``, so output
    never exceeds the current terminal width.

    Args:
        text (str): The input string to wrap.
        max_length (int): Maximum line length. Default is 100.

    Returns:
        str: The wrapped text with lines at most the effective width long.

    Note:
        ``textwrap.fill`` replaces existing newlines and other whitespace with
        spaces, so callers that need to keep line breaks should wrap each line
        separately.
    """
    width = min(max_length, get_terminal_width())
    return textwrap.fill(text, width=width)
def print_colored(heading, text, color="blue", warp=True):
    """Print a colored, dot-padded heading, then ``text``, then a dotted rule.

    Args:
        heading (str): Title line. It is padded with dots to the terminal width
            and colored via ``termcolor.colored``. If it is longer than the
            width, no padding is added.
        text (str): Body to print.
        color (str): termcolor color name for the heading.
        warp (bool): If truthy, wrap each line of ``text`` separately with
            ``wrap_text`` so existing newlines are preserved. (The name looks
            like a typo for "wrap"; kept for callers.)
    """
    width = get_terminal_width()
    banner = heading + "." * (width - len(heading))
    print(colored(banner, color))

    if not warp:
        print(text)
    else:
        print("\n".join(wrap_text(line) for line in text.split("\n")))

    print("." * width)
def in_notebook():
    """Return True when running inside an IPython kernel (e.g. a Jupyter notebook).

    Detection checks whether the active IPython shell's config contains an
    ``IPKernelApp`` section. Returns False if IPython is not installed
    (ImportError). It also returns False when no shell is active:
    ``get_ipython()`` then returns None, and reading ``.config`` raises
    AttributeError.
    """
    try:
        from IPython import get_ipython

        shell = get_ipython()
        in_kernel = 'IPKernelApp' in shell.config  # pragma: no cover
    except (ImportError, AttributeError):
        return False
    return in_kernel
def get_run_id():
    """Return a timestamp-based run id such as ``"20240131-235959"``.

    Uses naive local wall-clock time at one-second resolution, so runs started
    within the same second receive the same id.
    """
    from datetime import datetime

    return f"{datetime.now():%Y%m%d-%H%M%S}"
|