import json
import logging
import os
import re
from dataclasses import dataclass, field
from typing import Optional

import torch
from transformers import HfArgumentParser, Qwen2_5OmniProcessor

from swift.llm import InferRequest, PtEngine, RequestConfig, get_template

from dataset.dataset2 import AudioDataset


@dataclass
class TestArguments:
    """
    Arguments controlling the model paths, test data, and inference output.
    """

    MODEL_PATH = "/root/autodl-tmp/Qwen2.5-Omni-7B"  # base model path
    LORA_PATH = "/root/autodl-tmp/output_7B_Lora/v2-20250608-171618/checkpoint-324"  # LoRA checkpoint path
    DATA_FILE = "/root/ms-swift/silence_overlaps/test"  # test data directory
    OUTPUT_DIR = "omini_inference_7B_overlap5sVal_SFT_allset.json"  # inference results output file

    model_path: Optional[str] = field(default=MODEL_PATH, metadata={"help": "base model dir"})
    lora_path: Optional[str] = field(default=LORA_PATH, metadata={"help": "lora model dir"})
    out_file: Optional[str] = field(default=OUTPUT_DIR, metadata={"help": "output file for test"})
    data_dir: Optional[str] = field(default=DATA_FILE, metadata={"help": "test data directory"})
    DEVICE: Optional[str] = field(default="cuda:0", metadata={"help": "device to use"})
    force: Optional[bool] = field(default=False, metadata={"help": "force test"})
    batch_size: Optional[int] = field(default=2, metadata={"help": "Batch size for processing"})

    def __post_init__(self):
        if self.model_path is None:
            raise ValueError("model path should not be None")
        if self.data_dir is None:
            raise ValueError("data directory should not be None")


def get_prompt_templates():
    prompt_template = (
        "You are an expert at analyzing overlapping speech in conversations. "
        "Please analyze the speech dialogue and focus specifically on:\n"
        "Please summarize if any overlaps exceed the 3-second threshold."
    )
    return prompt_template


def extract_overall_score(output_str):
    """Extract the first integer score from the model output."""
    score_pattern = r"(\d+)"
    match = re.search(score_pattern, output_str)
    if match:
        try:
            return int(match.group(1))
        except ValueError:
            pass
    return None


def main():
    parser = HfArgumentParser(TestArguments)
    data_args = parser.parse_args_into_dataclasses()[0]

    logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
    logging.info("Starting inference with arguments: %s", data_args)

    if not data_args.force and os.path.exists(data_args.out_file) and os.path.getsize(data_args.out_file) > 0:
        logging.info("Output file %s already exists; not regenerating (pass --force to override).", data_args.out_file)
Do not regenerate it.") return # 设置GPU设备 device = torch.device(data_args.DEVICE if torch.cuda.is_available() else "cpu") logging.info(f"Using device: {device}") # 初始化音频处理器 logging.info("Loading processor...") processor = Qwen2_5OmniProcessor.from_pretrained(data_args.model_path) # 初始化推理引擎 logging.info("Initializing inference engine...") engine = PtEngine(data_args.model_path, adapters=[data_args.lora_path]) engine.processor = processor template = get_template(engine.model.model_meta.template, processor, default_system="You are a helpful assistant.") engine.default_template = template template.processor = processor # 初始化数据集 logging.info("Initializing dataset from %s", data_args.data_dir) dataset = AudioDataset(data_args.data_dir) logging.info(f"Dataset loaded successfully with {len(dataset)} samples") # 获取提示模板 prompt_template = get_prompt_templates() all_outputs = [] batch_size = data_args.batch_size total_batches = (len(dataset) + batch_size - 1) // batch_size logging.info(f"Starting batch processing with batch size {batch_size}, total batches: {total_batches}") for i in range(0, len(dataset), batch_size): current_batch = i // batch_size + 1 logging.info(f"Processing batch {current_batch}/{total_batches}") batch_data = [dataset[j] for j in range(i, min(i + batch_size, len(dataset)))] # Process each sample batch_outputs = [] for bd in batch_data: # 构建推理请求 infer_request = InferRequest( messages=bd["prompt"], audios=[bd["audio"]] ) # 设置推理配置 request_config = RequestConfig( max_tokens=512, temperature=0, do_sample=False, num_beams=1 ) # 执行推理 resp_list = engine.infer([infer_request], request_config) response = resp_list[0].choices[0].message.content batch_outputs.append(response) all_outputs.extend(batch_outputs) logging.info(f"Completed batch {current_batch}/{total_batches}") final_output = [] correct_count = 0 total_count = 0 true_positive = 0 false_positive = 0 false_negative = 0 for input_example, model_output in zip(dataset, all_outputs): pred_score = extract_overall_score(model_output) gt_score = input_example.get("solution", None) result = { "id": input_example.get("id", None), "gt_score": gt_score, "model_output": model_output, "predicted_score": pred_score } final_output.append(result) if pred_score is not None and gt_score is not None: total_count += 1 if pred_score == gt_score: correct_count += 1 true_positive += 1 else: false_positive += 1 false_negative += 1 accuracy = correct_count / total_count if total_count > 0 else 0 precision = true_positive / (true_positive + false_positive) if (true_positive + false_positive) > 0 else 0 recall = true_positive / (true_positive + false_negative) if (true_positive + false_negative) > 0 else 0 # 添加准确率指标到最终输出 metrics = { "accuracy": accuracy, "precision": precision, "recall": recall, "correct_count": correct_count, "total_count": total_count } final_output.append({"metrics": metrics}) logging.info("Saving results to %s", data_args.out_file) with open(data_args.out_file, "w") as f: json.dump(final_output, f, indent=2) logging.info(f"Results saved successfully.") logging.info(f"准确率: {accuracy:.4f} ({correct_count}/{total_count})") logging.info(f"召回率: {recall:.4f}") logging.info(f"精确率: {precision:.4f}") if __name__ == "__main__": main()