audio-classifier / inference.py
ahmedtarekabd's picture
Add Models & files.
4c8f740
import argparse
import os
import time
from glob import glob
import pandas as pd
from pathlib import Path
from modules.preprocessing import AudioPreprocessor
from modules.feature_extraction import FeatureExtractor
from models.lightgbm import LightGBMModel
from models.xgboost import XGBoostModel
from modules.pipelines import ModelPipeline
import warnings
warnings.filterwarnings("ignore")
MODEL_NAME = {
"XGBoost": XGBoostModel,
"LightGBM": LightGBMModel,
}
def run_batch_inference(model, input_folder, output_folder, sr=16000, feature_mode="traditional"):
preprocessor = AudioPreprocessor()
extractor = FeatureExtractor()
# Sort files in the correct order
files = sorted(glob(os.path.join(input_folder, "*")), key=lambda x: int(Path(x).stem))
# Overwrite if exsists
results_path = os.path.join(output_folder, "results.txt")
time_path = os.path.join(output_folder, "time.txt")
with open(results_path, "w") as f: pass
with open(time_path, "w") as f: pass
pred = 0
for file in files:
# Measure inference time
start_time = time.time()
y = preprocessor.preprocess(preprocessor.load_audio(str(file), sr=sr))
if y is not None:
x = extractor.extract(y, sr=sr, mode=feature_mode, n_mfcc=20)
pred = model.predict([x])[0]
end_time = time.time()
# Save results to results.txt
with open(results_path, "a") as f:
f.write(f"{pred}\n")
# Save inference time to time.txt
with open(time_path, "a") as f:
f.write(f"{end_time - start_time:.6f}\n")
print(f"✅ Results saved to {results_path}")
print(f"✅ Inference time saved to {time_path}")
def main(input_path, model_name, output_folder):
if not os.path.exists(input_path):
raise FileNotFoundError(f"Input path {input_path} does not exist.")
if model_name not in MODEL_NAME.keys():
raise ValueError(f"Model name {model_name} is not valid. Choose from {list(MODEL_NAME.keys())}.")
if not os.path.exists(output_folder):
os.makedirs(output_folder, exist_ok=True)
print(f"Output folder {output_folder} created.")
model = ModelPipeline(model=MODEL_NAME[model_name])
model.load_model_from_registry(model_name=model_name)
print("✅ Model loaded successfully")
run_batch_inference(model, input_path, output_folder)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--input-path', type=str, default="/data", help="Path to the input folder containing test audio files. Default is '/data'.")
parser.add_argument('--model-name', type=str, default="XGBoost", help="Name of the model to use for inference. Default is 'XGBoost'.")
parser.add_argument('--team_id', type=str, required=True, help="Team ID for output folder.")
args = parser.parse_args()
output_folder = os.path.join("/results", args.team_id)
print(f"Input Path: {args.input_path}")
print(f"Model Name: {args.model_name}")
print(f"Output Folder: {output_folder}")
main(args.input_path, args.model_name, output_folder)