File size: 6,406 Bytes
5669b22 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 | import os
import requests
import tarfile
from pathlib import Path
from tqdm import tqdm
from loguru import logger
def _select_release_asset(assets, filename_without_ext):
    """
    Pick the release asset that best matches ``filename_without_ext``.

    An asset named exactly ``filename_without_ext``, or ``filename_without_ext``
    followed by an extension (``"<name>.<ext>"``), is preferred. Failing that,
    the first asset whose name merely starts with ``filename_without_ext`` is
    returned (the original matching rule). This way ``"model"`` no longer picks
    ``"model-int8.tar.bz2"`` over ``"model.tar.bz2"`` just because it is listed
    first.

    Args:
        assets (list[dict]): The ``"assets"`` entries of a GitHub release.
        filename_without_ext (str): The filename to search for (without extension).

    Returns:
        dict | None: The chosen asset, or None if no asset name starts with
        ``filename_without_ext``.
    """
    fallback = None
    for asset in assets:
        name = asset.get("name", "")
        if not name.startswith(filename_without_ext):
            continue
        rest = name[len(filename_without_ext):]
        if not rest or rest.startswith("."):
            return asset  # exact stem match wins immediately
        if fallback is None:
            fallback = asset  # remember the first plain prefix match
    return fallback


def get_github_asset_url(owner, repo, release_tag, filename_without_ext, *, timeout=30):
    """
    Fetch the URL of a GitHub release asset by its filename (without extension).

    Args:
        owner (str): The owner of the repository.
        repo (str): The name of the repository.
        release_tag (str): The tag of the release.
        filename_without_ext (str): The filename to search for (without extension).
        timeout (float): Seconds ``requests`` waits to connect to / hear back
            from the GitHub API. Without a timeout a stalled connection can
            hang the caller indefinitely.

    Returns:
        str | None: The download URL of the matched asset, or None if the
        request fails or no asset matches.
    """
    url = f"https://api.github.com/repos/{owner}/{repo}/releases/tags/{release_tag}"
    # Media type recommended by the GitHub REST API. Add authentication headers if needed.
    headers = {"Accept": "application/vnd.github+json"}
    try:
        response = requests.get(url, headers=headers, timeout=timeout)
        response.raise_for_status()
        release_data = response.json()
    except (requests.exceptions.RequestException, ValueError) as e:
        # ValueError: invalid JSON on requests versions whose JSONDecodeError
        # is not a RequestException subclass.
        logger.error(f"An error occurred while fetching release data: {e}")
        return None

    asset = _select_release_asset(release_data.get("assets", []), filename_without_ext)
    if asset is None:
        logger.error(
            f"No match found for filename: {filename_without_ext} in release {release_tag}."
        )
        return None
    logger.info(f"Match found: {asset['name']}")
    return asset["browser_download_url"]
def download_and_extract(url: str, output_dir: str, *, timeout: float = 60) -> Path:
    """
    Download a file from a URL and extract it if it is a tar.bz2 archive.

    If ``output_dir/<file name without .tar.bz2>`` already exists, nothing is
    downloaded and that path is returned. Otherwise:

    - The file is streamed to a temporary ``<file name>.part``.
    - It is renamed into place only after the transfer completes, so an
      interrupted download never leaves a truncated file under the final name.
    - A ``.tar.bz2`` archive is then extracted into ``output_dir`` and deleted.
    - If extraction fails, the partially extracted directory is removed, so a
      later run does not mistake it for a finished model.

    Args:
        url (str): The URL to download the file from. Its last ``/`` segment is
            used as the file name.
        output_dir (str): The directory to save the downloaded file.
        timeout (float): Seconds ``requests`` waits to connect and between
            received bytes. This is not a limit on the whole transfer.

    Returns:
        Path: For a tar.bz2 file, ``output_dir/<file name without .tar.bz2>``.
        This assumes the archive's top-level directory carries that name, which
        is not verified here. Otherwise, the path to the downloaded file.

    Raises:
        requests.exceptions.RequestException: On network or HTTP errors.
        tarfile.TarError, OSError, EOFError: If the archive cannot be extracted.
    """
    import shutil  # only needed to clean up after a failed extraction

    # Create the output directory if it doesn't exist
    Path(output_dir).mkdir(parents=True, exist_ok=True)

    # Get the file name from the URL
    file_name = url.split("/")[-1]
    file_path = os.path.join(output_dir, file_name)
    part_path = file_path + ".part"
    # Expected extraction root: the file name without its .tar.bz2 suffix
    extracted_dir_path = Path(output_dir) / file_name.removesuffix(".tar.bz2")

    # Check if the extracted directory already exists
    if extracted_dir_path.exists():
        logger.info(
            f"✅ The directory {extracted_dir_path} already exists. I would assume that the model is already downloaded and we are ready to go. Skipping download and extraction."
        )
        return extracted_dir_path

    logger.info(f"🏃♂️Downloading {url} to {file_path}...")
    # Using the response as a context manager releases the streamed connection.
    with requests.get(url, stream=True, timeout=timeout) as response:
        response.raise_for_status()  # Raise an error for bad status codes
        total_size = int(response.headers.get("content-length", 0))
        logger.debug(f"Total file size: {total_size / 1024 / 1024:.2f} MB")
        try:
            with (
                open(part_path, "wb") as f,
                tqdm(
                    desc=file_name,
                    # None = unknown size (no content-length) rather than a bogus 0 total
                    total=total_size or None,
                    unit="iB",
                    unit_scale=True,
                    unit_divisor=1024,
                ) as pbar,
            ):
                for chunk in response.iter_content(chunk_size=8192):
                    pbar.update(f.write(chunk))
        except BaseException:
            # Never leave a truncated download behind (also on Ctrl+C).
            Path(part_path).unlink(missing_ok=True)
            raise
    # Only a complete download ever appears under the final name.
    os.replace(part_path, file_path)
    logger.info(f"Downloaded {file_name} successfully.")

    if not file_name.endswith(".tar.bz2"):
        logger.warning("The downloaded file is not a tar.bz2 archive.")
        return Path(file_path)

    logger.info(f"Extracting {file_name}...")
    try:
        with tarfile.open(file_path, "r:bz2") as tar:
            # Use the PEP 706 "data" filter when available. It rejects absolute
            # paths, ".." members and links escaping output_dir. Older runtimes
            # fall back to the unfiltered extraction.
            if hasattr(tarfile, "data_filter"):
                tar.extractall(path=output_dir, filter="data")
            else:
                tar.extractall(path=output_dir)
    except BaseException:
        # A half-extracted directory would make the next run skip the download.
        # It did not exist before (checked above), so removing it is safe.
        shutil.rmtree(extracted_dir_path, ignore_errors=True)
        raise
    logger.info("Extraction completed.")
    # Delete the compressed file
    os.remove(file_path)
    logger.debug(f"Deleted the compressed file: {file_name}")
    return extracted_dir_path
def check_and_extract_local_file(url: str, output_dir: str) -> Path | None:
    """
    Check if a local file exists and extract it if it is a tar.bz2 archive.

    The local file looked for is ``output_dir/<last "/" segment of url>``.

    - If the expected extracted directory (that file name without
      ``.tar.bz2``) already exists, it is returned as-is.
    - Otherwise a local ``.tar.bz2`` archive is extracted into ``output_dir``
      and then deleted.
    - Extraction errors are logged and reported as ``None`` so the caller can
      fall back to downloading. Any partially extracted directory is removed
      first, so it is not mistaken for a complete one on the next run.

    Args:
        url (str): The URL of the file. Only its last path segment is used.
        output_dir (str): The directory to save the extracted files.

    Returns:
        Path | None: Path to the extracted directory if it already existed or
        was just extracted, otherwise None.
    """
    import shutil  # only needed to clean up after a failed extraction

    # Get the file name from the URL
    file_name = url.split("/")[-1]
    base_dir = Path(output_dir)
    compressed_path = base_dir / file_name
    extracted_dir = base_dir / file_name.removesuffix(".tar.bz2")

    if extracted_dir.exists():
        logger.info(
            f"✅ Extracted directory exists: {extracted_dir}, no operation needed."
        )
        return extracted_dir

    if not (compressed_path.exists() and file_name.endswith(".tar.bz2")):
        logger.warning(f"Local file not found or not a tar.bz2 archive: {compressed_path}")
        return None

    logger.info(f"🔍 Found local archive file: {compressed_path}")
    try:
        logger.info("⏳ Extracting archive file...")
        with tarfile.open(compressed_path, "r:bz2") as tar:
            # Use the PEP 706 "data" filter when available. It rejects absolute
            # paths, ".." members and links escaping output_dir.
            if hasattr(tarfile, "data_filter"):
                tar.extractall(path=output_dir, filter="data")
            else:
                tar.extractall(path=output_dir)
    except Exception as e:
        # Deliberate best effort: a corrupt or partial archive yields None so the
        # caller can re-download. Drop whatever was half-extracted first; it did
        # not exist before (checked above).
        shutil.rmtree(extracted_dir, ignore_errors=True)
        logger.error(f"Fail to extract file: {str(e)}")
        return None

    logger.success(f"Extracted archive to the path: {extracted_dir}")
    try:
        compressed_path.unlink()  # Remove the compressed file
    except OSError as e:
        # Extraction already succeeded; a leftover archive is harmless.
        logger.warning(f"Could not remove archive {compressed_path}: {e}")
    return extracted_dir
if __name__ == "__main__":
    # Example run: fetch the SenseVoice (zh/en/ja/ko/yue) ASR model from the
    # sherpa-onnx GitHub releases into ./models.
    model_url = "https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17.tar.bz2"
    models_dir = "./models"

    # Reuse an extracted directory or a local archive when possible;
    # only go to the network when neither is usable.
    if check_and_extract_local_file(model_url, models_dir) is not None:
        logger.info("Extraction completed using local file.")
    else:
        logger.info("Local archive not found. Starting download...")
        download_and_extract(model_url, models_dir)
|