import json
import logging
import os
import re
from dataclasses import dataclass, field
from typing import Optional
import torch
from swift.llm import InferEngine, InferRequest, PtEngine, RequestConfig, get_template
from transformers import HfArgumentParser
from transformers import Qwen2_5OmniProcessor
from dataset.dataset2 import AudioDataset
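# Example invocation (script name and paths are illustrative; HfArgumentParser
# derives the CLI flags from the TestArguments field names below):
#   python infer_overlap.py \
#       --model_path /root/autodl-tmp/Qwen2.5-Omni-7B \
#       --lora_path /root/autodl-tmp/output_7B_Lora/v2-20250608-171618/checkpoint-324 \
#       --data_dir /root/ms-swift/silence_overlaps/test \
#       --out_file results.json \
#       --batch_size 4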
@dataclass
class TestArguments:
"""
Arguments pertaining to what data we are going to input our model for training and eval.
"""
    MODEL_PATH = "/root/autodl-tmp/Qwen2.5-Omni-7B"  # base model path
    LORA_PATH = "/root/autodl-tmp/output_7B_Lora/v2-20250608-171618/checkpoint-324"  # LoRA adapter path
    DATA_FILE = "/root/ms-swift/silence_overlaps/test"  # test data directory
    OUT_FILE = "omini_inference_7B_overlap5sVal_SFT_allset.json"  # inference results output file
    model_path: Optional[str] = field(default=MODEL_PATH, metadata={"help": "base model dir"})
    lora_path: Optional[str] = field(default=LORA_PATH, metadata={"help": "lora model dir"})
    out_file: Optional[str] = field(default=OUT_FILE, metadata={"help": "output file for test results"})
    data_dir: Optional[str] = field(default=DATA_FILE, metadata={"help": "test data directory"})
    device: Optional[str] = field(default="cuda:0", metadata={"help": "device to use"})
force: Optional[bool] = field(default=False, metadata={"help": "force test"})
batch_size: Optional[int] = field(default=2, metadata={"help": "Batch size for processing"})
def __post_init__(self):
        if self.model_path is None:
            raise ValueError("model_path must not be None")
        if self.data_dir is None:
            raise ValueError("data_dir must not be None")
def get_prompt_template():
    prompt_template = (
        "You are an expert at analyzing overlapping speech in conversations. Please analyze the speech dialogue and focus specifically on:\n"
        "Please summarize whether any overlaps exceed the 3-second threshold."
    )
    return prompt_template
def extract_overall_score(output_str):
"""从输出中提取X"""
score_pattern = r"(\d+)"
match = re.search(score_pattern, output_str)
if match:
try:
return int(match.group(1))
except ValueError:
pass
return None
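# Illustrative behavior of extract_overall_score:
#   extract_overall_score("Total overlaps: 3")   -> 3
#   extract_overall_score("no overlaps found")   -> None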
def main():
parser = HfArgumentParser(TestArguments)
data_args = parser.parse_args_into_dataclasses()[0]
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
logging.info("Starting inference with arguments: %s", data_args)
if not data_args.force and os.path.exists(data_args.out_file) and os.path.getsize(data_args.out_file) > 0:
logging.info(f"The {data_args.out_file} exists. Do not regenerate it.")
return
    # Resolve the device (falls back to CPU when CUDA is unavailable)
    device = torch.device(data_args.device if torch.cuda.is_available() else "cpu")
logging.info(f"Using device: {device}")
    # Initialize the audio processor
logging.info("Loading processor...")
processor = Qwen2_5OmniProcessor.from_pretrained(data_args.model_path)
    # Initialize the inference engine and apply the LoRA adapter
logging.info("Initializing inference engine...")
engine = PtEngine(data_args.model_path, adapters=[data_args.lora_path])
    # Attach the processor to both the engine and its chat template so that
    # audio inputs are preprocessed correctly during encoding
    engine.processor = processor
    template = get_template(engine.model.model_meta.template, processor, default_system="You are a helpful assistant.")
    engine.default_template = template
    template.processor = processor
    # Initialize the dataset
logging.info("Initializing dataset from %s", data_args.data_dir)
dataset = AudioDataset(data_args.data_dir)
logging.info(f"Dataset loaded successfully with {len(dataset)} samples")
    # Prompt template (currently unused here; AudioDataset is assumed to embed
    # the prompt in each sample's "prompt" messages)
    prompt_template = get_prompt_template()
    all_outputs = []
    batch_size = data_args.batch_size
    total_batches = (len(dataset) + batch_size - 1) // batch_size
    # The generation config is identical for every sample, so build it once
    # outside the loop; temperature=0 with do_sample=False requests greedy decoding
    request_config = RequestConfig(
        max_tokens=512,
        temperature=0,
        do_sample=False,
        num_beams=1
    )
    logging.info(f"Starting batch processing with batch size {batch_size}, total batches: {total_batches}")
for i in range(0, len(dataset), batch_size):
current_batch = i // batch_size + 1
logging.info(f"Processing batch {current_batch}/{total_batches}")
batch_data = [dataset[j] for j in range(i, min(i + batch_size, len(dataset)))]
        # Run inference on each sample in the batch
        batch_outputs = []
        for bd in batch_data:
            # Build the inference request from the sample's messages and audio
            infer_request = InferRequest(
                messages=bd["prompt"],
                audios=[bd["audio"]]
            )
            # Run inference
resp_list = engine.infer([infer_request], request_config)
response = resp_list[0].choices[0].message.content
batch_outputs.append(response)
all_outputs.extend(batch_outputs)
logging.info(f"Completed batch {current_batch}/{total_batches}")
final_output = []
correct_count = 0
total_count = 0
true_positive = 0
false_positive = 0
false_negative = 0
for input_example, model_output in zip(dataset, all_outputs):
        pred_score = extract_overall_score(model_output)
        gt_score = input_example.get("solution", None)
        # Ground truth may be stored as a string; normalize it to an int so it
        # compares cleanly with the predicted score (defensive assumption)
        if isinstance(gt_score, str):
            gt_score = extract_overall_score(gt_score)
result = {
"id": input_example.get("id", None),
"gt_score": gt_score,
"model_output": model_output,
"predicted_score": pred_score
}
final_output.append(result)
        if pred_score is not None and gt_score is not None:
            total_count += 1
            if pred_score == gt_score:
                correct_count += 1
            # Precision/recall assume a binary task in which 1 marks the
            # positive class (overlap present); adjust if the labels differ
            if pred_score == 1 and gt_score == 1:
                true_positive += 1
            elif pred_score == 1:
                false_positive += 1
            elif gt_score == 1:
                false_negative += 1
accuracy = correct_count / total_count if total_count > 0 else 0
precision = true_positive / (true_positive + false_positive) if (true_positive + false_positive) > 0 else 0
recall = true_positive / (true_positive + false_negative) if (true_positive + false_negative) > 0 else 0
    # Append the aggregate metrics as a trailing record in the output
metrics = {
"accuracy": accuracy,
"precision": precision,
"recall": recall,
"correct_count": correct_count,
"total_count": total_count
}
final_output.append({"metrics": metrics})
logging.info("Saving results to %s", data_args.out_file)
with open(data_args.out_file, "w") as f:
json.dump(final_output, f, indent=2)
logging.info(f"Results saved successfully.")
logging.info(f"准确率: {accuracy:.4f} ({correct_count}/{total_count})")
logging.info(f"召回率: {recall:.4f}")
logging.info(f"精确率: {precision:.4f}")
if __name__ == "__main__":
main()