# NOTE(review): removed non-Python viewer artifacts that preceded the code
# (file-size banner, commit hash, and a line-number gutter). They were
# extraction residue and would break the module at import time.
import json
import logging
import os
import re
from dataclasses import dataclass, field
from typing import Optional
import torch
from swift.llm import InferEngine, InferRequest, PtEngine, RequestConfig, get_template
from transformers import HfArgumentParser
from transformers import Qwen2_5OmniProcessor
from dataset.dataset2 import AudioDataset
@dataclass
class TestArguments:
    """
    Arguments controlling model/adapter loading, test-data location, and
    inference output for the evaluation run.

    Note: the UPPER_CASE class attributes below are intentionally left
    without annotations so the dataclass machinery does not treat them as
    fields; they only provide shared default values for the real fields.
    """
    MODEL_PATH = "/root/autodl-tmp/Qwen2.5-Omni-7B"  # base model path
    LORA_PATH = "/root/autodl-tmp/output_7B_Lora/v2-20250608-171618/checkpoint-324"  # LoRA adapter checkpoint
    DATA_FILE = "/root/ms-swift/silence_overlaps/test"  # test data directory
    OUTPUT_DIR = "omini_inference_7B_overlap5sVal_SFT_allset.json"  # inference results file

    model_path: Optional[str] = field(default=MODEL_PATH, metadata={"help": "base model dir"})
    lora_path: Optional[str] = field(default=LORA_PATH, metadata={"help": "lora model dir"})
    out_file: Optional[str] = field(default=OUTPUT_DIR, metadata={"help": "output file for test"})
    data_dir: Optional[str] = field(default=DATA_FILE, metadata={"help": "test data directory"})
    DEVICE: Optional[str] = field(default="cuda:0", metadata={"help": "device to use"})
    force: Optional[bool] = field(default=False, metadata={"help": "force test"})
    batch_size: Optional[int] = field(default=2, metadata={"help": "Batch size for processing"})

    def __post_init__(self):
        """Validate required paths.

        Raises:
            ValueError: if ``model_path`` or ``data_dir`` is ``None``.
        """
        # Original message said "config path" while checking model_path, and
        # dropped the verb ("should not none") — fixed to match the check.
        if self.model_path is None:
            raise ValueError("model path should not be None")
        if self.data_dir is None:
            raise ValueError("data directory should not be None")
def get_prompt_templates():
    """Build the fixed instruction prompt used for overlap analysis.

    Returns:
        str: the instruction text sent with every inference request.
    """
    header = "You are an expert at analyzing overlapping speech in conversations. Please analyze the speech dialogue and focus specifically on:\n"
    task = "Please summarize if any overlaps exceed the 3-second threshold."
    return header + task
def extract_overall_score(output_str):
    """Extract the integer from an ``<overall score>X</overall score>`` tag.

    Args:
        output_str: raw model output text.

    Returns:
        int | None: the score if the tag is present, otherwise ``None``.
    """
    match = re.search(r"<overall score>(\d+)</overall score>", output_str)
    # A string matched by \d+ always converts cleanly to int, so the
    # original try/except ValueError around int() was dead code — removed.
    return int(match.group(1)) if match else None
def main():
    """Run LoRA-adapted Qwen2.5-Omni inference over an audio test set.

    Loads the base model with the LoRA adapter, batches the dataset through
    the swift ``PtEngine``, extracts the predicted ``<overall score>`` from
    each response, and writes per-sample results plus aggregate metrics to
    ``data_args.out_file`` as JSON.
    """
    parser = HfArgumentParser(TestArguments)
    data_args = parser.parse_args_into_dataclasses()[0]
    logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
    logging.info("Starting inference with arguments: %s", data_args)

    # Skip regeneration when a non-empty output file already exists, unless --force.
    if not data_args.force and os.path.exists(data_args.out_file) and os.path.getsize(data_args.out_file) > 0:
        logging.info("The %s exists. Do not regenerate it.", data_args.out_file)
        return

    # Pick the requested CUDA device, falling back to CPU when CUDA is absent.
    device = torch.device(data_args.DEVICE if torch.cuda.is_available() else "cpu")
    logging.info("Using device: %s", device)

    # Load the audio/text processor for the base model.
    logging.info("Loading processor...")
    processor = Qwen2_5OmniProcessor.from_pretrained(data_args.model_path)

    # Initialize the inference engine with the LoRA adapter applied.
    logging.info("Initializing inference engine...")
    engine = PtEngine(data_args.model_path, adapters=[data_args.lora_path])
    engine.processor = processor
    template = get_template(engine.model.model_meta.template, processor, default_system="You are a helpful assistant.")
    engine.default_template = template
    template.processor = processor

    # Initialize the dataset. NOTE(review): the original also built the prompt
    # via get_prompt_templates() but never used it — each sample supplies its
    # own "prompt"; the unused local was removed.
    logging.info("Initializing dataset from %s", data_args.data_dir)
    dataset = AudioDataset(data_args.data_dir)
    logging.info("Dataset loaded successfully with %d samples", len(dataset))

    all_outputs = []
    batch_size = data_args.batch_size
    total_batches = (len(dataset) + batch_size - 1) // batch_size
    logging.info("Starting batch processing with batch size %d, total batches: %d", batch_size, total_batches)

    # The sampling configuration is identical for every request, so build it
    # once instead of re-creating it per sample inside the loop (original
    # rebuilt it on every iteration).
    request_config = RequestConfig(
        max_tokens=512,
        temperature=0,
        do_sample=False,
        num_beams=1
    )

    for i in range(0, len(dataset), batch_size):
        current_batch = i // batch_size + 1
        logging.info("Processing batch %d/%d", current_batch, total_batches)
        batch_data = [dataset[j] for j in range(i, min(i + batch_size, len(dataset)))]
        # Process each sample in the batch sequentially.
        batch_outputs = []
        for bd in batch_data:
            infer_request = InferRequest(
                messages=bd["prompt"],
                audios=[bd["audio"]]
            )
            resp_list = engine.infer([infer_request], request_config)
            batch_outputs.append(resp_list[0].choices[0].message.content)
        all_outputs.extend(batch_outputs)
        logging.info("Completed batch %d/%d", current_batch, total_batches)

    final_output = []
    correct_count = 0
    total_count = 0
    true_positive = 0
    false_positive = 0
    false_negative = 0
    for input_example, model_output in zip(dataset, all_outputs):
        pred_score = extract_overall_score(model_output)
        # NOTE(review): pred_score is an int while the dataset's "solution"
        # type is not visible here — confirm gt_score is an int, otherwise
        # the equality below can never be true.
        gt_score = input_example.get("solution", None)
        final_output.append({
            "id": input_example.get("id", None),
            "gt_score": gt_score,
            "model_output": model_output,
            "predicted_score": pred_score
        })
        if pred_score is not None and gt_score is not None:
            total_count += 1
            # NOTE(review): counting every exact match as TP and every mismatch
            # as both FP and FN makes precision == recall == accuracy by
            # construction. If genuine precision/recall are wanted, a positive
            # class must be defined; left as-is to preserve output semantics.
            if pred_score == gt_score:
                correct_count += 1
                true_positive += 1
            else:
                false_positive += 1
                false_negative += 1

    accuracy = correct_count / total_count if total_count > 0 else 0
    precision = true_positive / (true_positive + false_positive) if (true_positive + false_positive) > 0 else 0
    recall = true_positive / (true_positive + false_negative) if (true_positive + false_negative) > 0 else 0

    # Append the aggregate metrics as the final JSON entry.
    final_output.append({"metrics": {
        "accuracy": accuracy,
        "precision": precision,
        "recall": recall,
        "correct_count": correct_count,
        "total_count": total_count
    }})

    logging.info("Saving results to %s", data_args.out_file)
    with open(data_args.out_file, "w") as f:
        json.dump(final_output, f, indent=2)
    logging.info("Results saved successfully.")
    logging.info(f"准确率: {accuracy:.4f} ({correct_count}/{total_count})")
    logging.info(f"召回率: {recall:.4f}")
    logging.info(f"精确率: {precision:.4f}")
# Script entry point: run the full inference + evaluation pipeline.
# (Removed a stray trailing "|" extraction artifact that broke the syntax.)
if __name__ == "__main__":
    main()