Safetensors
English
llava
video-retrieval
text-to-video-search
multimodal-embedding
TARA / shared /utils /misc.py
bpiyush's picture
Update TARA to latest Tarsier2 checkpoint and runnable demo.
7daf628
"""Misc utils."""
import os
from shared.utils.log import tqdm_iterator
import numpy as np
from termcolor import colored
class AttrDict(dict):
    """Dictionary whose entries are also reachable as attributes.

    The instance's ``__dict__`` is bound to the dictionary itself, so
    ``d.key`` and ``d["key"]`` read and write the same entry.
    """

    def __init__(self, *args, **kwargs):
        # Takes exactly the arguments the ``dict`` constructor takes.
        super().__init__(*args, **kwargs)
        # Make attribute storage and mapping storage one and the same object.
        self.__dict__ = self
class DictToObj:
    """Plain object exposing the entries of a dictionary as attributes.

    Values that are themselves ``dict`` instances are converted recursively;
    every other value (including lists that contain dictionaries) is stored
    unchanged.
    """

    def __init__(self, dictionary):
        for name, item in dictionary.items():
            # Nested dictionaries become nested DictToObj instances.
            attribute = DictToObj(item) if isinstance(item, dict) else item
            setattr(self, name, attribute)

    def __repr__(self):
        # Show the attributes the same way a dictionary would be printed.
        return str(vars(self))
def ignore_warnings(type="ignore"):
    """Install a global warnings filter and disable tokenizer parallelism.

    Args:
        type: Action passed to ``warnings.filterwarnings`` (for example
            ``"ignore"`` or ``"default"``).

    Side effects:
        Sets the ``TOKENIZERS_PARALLELISM`` environment variable to
        ``"false"`` for the current process.
    """
    from warnings import filterwarnings

    filterwarnings(type)
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
def download_youtube_video(youtube_id, ext='mp4', resolution="360p", **kwargs):
    """Download a YouTube video to ``/tmp/sample.mp4`` (best effort).

    Args:
        youtube_id: Video identifier, i.e. the ``v=`` part of the watch URL.
        ext: File extension used to filter the available streams.
        resolution: Resolution label used to filter the available streams.
        **kwargs: Extra filters forwarded to ``yt.streams.filter``.

    Returns:
        ``"/tmp/sample.mp4"`` on success, ``None`` if no matching stream was
        found or the download failed.

    Note:
        The output path is fixed: the file is always written as
        ``/tmp/sample.mp4`` regardless of ``ext``, and each call overwrites
        the previous download.
    """
    import pytube

    video_url = f"https://www.youtube.com/watch?v={youtube_id}"
    yt = pytube.YouTube(video_url)
    try:
        streams = yt.streams.filter(
            file_extension=ext, res=resolution, progressive=True, **kwargs,
        )
        # Indexing an empty result raises and is reported as a failure below.
        streams[0].download(output_path='/tmp', filename='sample.mp4')
    except Exception:
        # Deliberately broad: any download problem is reported as a failure.
        # ``Exception`` (not a bare ``except``) lets KeyboardInterrupt and
        # SystemExit propagate so the user can still abort a download.
        print("Failed to download video: ", video_url)
        return None
    return "/tmp/sample.mp4"
def check_audio(video_path):
    """Return whether the video at ``video_path`` has an audio track.

    Args:
        video_path: Path of the video file to inspect.

    Returns:
        ``True`` if the clip opened by MoviePy exposes a non-``None``
        ``audio`` attribute; ``False`` otherwise, including when the file
        cannot be opened at all.
    """
    from moviepy.video.io.VideoFileClip import VideoFileClip

    try:
        clip = VideoFileClip(video_path)
    except Exception:
        # Missing or unreadable files simply count as "no audio".
        # KeyboardInterrupt / SystemExit are intentionally not swallowed.
        return False

    try:
        return clip.audio is not None
    except Exception:
        return False
    finally:
        # Close the clip so its reader resources are released; this matters
        # when the function is called for many files in a row.
        try:
            clip.close()
        except Exception:
            # Cleanup is best effort and must not change the answer.
            pass
def check_audio_multiple(video_paths, n_jobs=8):
    """Run ``check_audio`` over many videos using joblib workers.

    Args:
        video_paths: Iterable of video file paths.
        n_jobs: Number of parallel jobs handed to ``joblib.Parallel``.

    Returns:
        The result of the ``joblib.Parallel`` call, i.e. the collected
        ``check_audio`` outcomes for the given paths.
    """
    # Wrap the paths so progress is displayed while tasks are dispatched.
    progress = tqdm_iterator(video_paths, desc="Checking audio")
    from joblib import Parallel, delayed

    pool = Parallel(n_jobs=n_jobs)
    tasks = (delayed(check_audio)(path) for path in progress)
    return pool(tasks)
def num_trainable_params(model, round=3, verbose=True, return_count=False):
    """Count the parameters of ``model`` that have ``requires_grad`` set.

    Args:
        model: Object exposing ``parameters()``; each parameter must provide
            ``numel()`` and ``requires_grad``.
        round: Number of decimals used when reporting the count in millions.
            ``None`` reports the raw count instead.
        verbose: Whether to print the count.
        return_count: Whether to return the raw (unrounded) count.

    Returns:
        The raw parameter count if ``return_count`` is true, else ``None``.
    """
    n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    model_name = model.__class__.__name__

    # Report in millions unless rounding has been switched off.
    if round is None:
        value, unit = n_params, ""
    else:
        value, unit = np.round(n_params / 1e6, round), "M"

    if verbose:
        print(f"::: Number of trainable parameters in {model_name}: {value} {unit}")

    return n_params if return_count else None
def num_params(model, round=3):
    """Print the total number of parameters of ``model``.

    Args:
        model: Object exposing ``parameters()``; each parameter must provide
            ``numel()``.
        round: Number of decimals used when reporting the count in millions.
            ``None`` prints the raw count instead.
    """
    n_params = sum(p.numel() for p in model.parameters())
    model_name = model.__class__.__name__

    # Report in millions unless rounding has been switched off.
    if round is None:
        value, unit = n_params, ""
    else:
        value, unit = np.round(n_params / 1e6, round), "M"

    print(f"::: Number of total parameters in {model_name}: {value}{unit}")
def fix_seed(seed=42):
    """Seed the ``random``, NumPy and PyTorch random number generators.

    Also sets ``torch.backends.cudnn.deterministic`` to ``True``.

    Args:
        seed: Seed value handed to every generator.
    """
    import random
    import torch
    import numpy as np

    # Apply the same seed to each library in turn.
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    torch.backends.cudnn.deterministic = True
def check_tensor(x):
    """Print the shape, minimum and maximum of ``x`` on a single line.

    Args:
        x: Object exposing ``shape``, ``min()`` and ``max()``.
    """
    shape = x.shape
    lowest = x.min()
    highest = x.max()
    print(shape, lowest, highest)
import hashlib
def encode_string(input_string, num_chars=4):
    """Derive a short hexadecimal code from the SHA-256 hash of a string.

    Args:
        input_string (str): The string to encode (hashed as UTF-8 bytes).
        num_chars (int): How many leading characters of the hexadecimal
            digest to keep.

    Returns:
        str: The first ``num_chars`` characters of the hexadecimal digest.
    """
    digest = hashlib.sha256(input_string.encode('utf-8')).hexdigest()
    # The code is simply a prefix of the digest.
    return digest[:num_chars]
def flatten_list_of_lists(xss):
    """Concatenate the items of every inner iterable into one flat list.

    Only one level of nesting is removed.
    """
    flat = []
    for xs in xss:
        flat.extend(xs)
    return flat
import textwrap
def get_terminal_width():
    """Return the terminal width in columns, as given by ``shutil.get_terminal_size``."""
    import shutil

    columns, _lines = shutil.get_terminal_size()
    return columns
def wrap_text(text: str, max_length: int = 100) -> str:
    """
    Wrap a long string so it is easier to print.

    Args:
        text (str): The input string to wrap.
        max_length (int): The maximum length of each line. Default is 100.
            The effective width is the smaller of this value and the
            current terminal width.

    Returns:
        str: The text re-flowed by ``textwrap.fill`` at the effective width.
    """
    # Never wrap wider than the terminal itself.
    width = min(max_length, get_terminal_width())
    return textwrap.fill(text, width=width)
def print_colored(heading, text, color="blue", warp=True):
    """Print ``text`` between a colored heading line and a dotted footer.

    Args:
        heading: Title printed first, padded with dots up to the terminal
            width (no padding is added if the heading is already wider).
        text: Body to print.
        color: Color name passed to ``termcolor.colored`` for the heading.
        warp: If true, each line of ``text`` is wrapped with ``wrap_text``
            before printing; otherwise ``text`` is printed unchanged.
    """
    width = get_terminal_width()
    banner = heading + "." * (width - len(heading))
    print(colored(banner, color))

    if warp:
        # Wrap line by line so the existing line breaks are preserved.
        text = "\n".join(wrap_text(line) for line in text.split("\n"))
    print(text)

    print("." * width)
def in_notebook():
    """Return whether the code is running under an IPython kernel.

    Returns:
        ``True`` if IPython is importable and the active shell's config
        contains ``'IPKernelApp'``; ``False`` otherwise, including when
        IPython is not installed or there is no active shell.
    """
    try:
        from IPython import get_ipython

        # Outside IPython, get_ipython() returns None and ``.config`` fails.
        return 'IPKernelApp' in get_ipython().config  # pragma: no cover
    except (ImportError, AttributeError):
        return False
def get_run_id():
    """Return the current local time formatted as ``YYYYMMDD-HHMMSS``."""
    import datetime

    now = datetime.datetime.now()
    return f"{now:%Y%m%d-%H%M%S}"