import json
import logging
import os
import re
from dataclasses import dataclass, field
from typing import Optional

import torch
from swift.llm import InferRequest, PtEngine, RequestConfig, get_template
from transformers import HfArgumentParser, Qwen2_5OmniProcessor

from dataset.dataset2 import AudioDataset


@dataclass
class TestArguments:
"""
Arguments pertaining to what data we are going to input our model for training and eval.
"""
MODEL_PATH = "/root/autodl-tmp/Qwen2.5-Omni-7B" # 基础模型路径
LORA_PATH = "/root/autodl-tmp/output_7B_Lora/v2-20250608-171618/checkpoint-324" # LoRA 模型路径
DATA_FILE = "/root/ms-swift/silence_overlaps/test" # 测试数据文件
OUTPUT_DIR = "omini_inference_7B_overlap5sVal_SFT_allset.json" # 推理结果输出目录
model_path: Optional[str] = field(default=MODEL_PATH, metadata={"help": "base model dir"})
lora_path: Optional[str] = field(default=LORA_PATH, metadata={"help": "lora model dir"})
out_file: Optional[str] = field(default=OUTPUT_DIR, metadata={"help": "output file for test"})
data_dir: Optional[str] = field(default=DATA_FILE, metadata={"help": "test data directory"})
DEVICE: Optional[str] = field(default="cuda:0", metadata={"help": "device to use"})
force: Optional[bool] = field(default=False, metadata={"help": "force test"})
batch_size: Optional[int] = field(default=2, metadata={"help": "Batch size for processing"})
def __post_init__(self):
if self.model_path is None:
raise ValueError("config path should not none")
if self.data_dir is None:
raise ValueError("data directory should not be none")


def get_prompt_templates():
    prompt_template = (
        "You are an expert at analyzing overlapping speech in conversations. "
        "Please analyze the speech dialogue and focus specifically on the following:\n"
        "Please summarize if any overlaps exceed the 3-second threshold."
    )
    return prompt_template


def extract_overall_score(output_str):
    """Extract the integer from an <overall score>X</overall score> tag, or return None."""
    score_pattern = r"<overall score>(\d+)</overall score>"
    match = re.search(score_pattern, output_str)
    if match:
        try:
            return int(match.group(1))
        except ValueError:
            pass
    return None
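
# A quick illustration of the tag format this parser expects (hypothetical
# strings, not taken from the dataset):
#   extract_overall_score("... <overall score>1</overall score> ...")  -> 1
#   extract_overall_score("no tag present")                            -> None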


def main():
    parser = HfArgumentParser(TestArguments)
    data_args = parser.parse_args_into_dataclasses()[0]
    logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
    logging.info("Starting inference with arguments: %s", data_args)
    if not data_args.force and os.path.exists(data_args.out_file) and os.path.getsize(data_args.out_file) > 0:
        logging.info(f"{data_args.out_file} already exists; skipping regeneration.")
        return
    # Select the device. Note: this value is only logged; PtEngine handles
    # model placement itself.
    device = torch.device(data_args.DEVICE if torch.cuda.is_available() else "cpu")
    logging.info(f"Using device: {device}")
    # Initialize the audio processor.
    logging.info("Loading processor...")
    processor = Qwen2_5OmniProcessor.from_pretrained(data_args.model_path)
    # Initialize the inference engine with the LoRA adapter, and wire the
    # Qwen2.5-Omni processor into both the engine and its chat template.
    logging.info("Initializing inference engine...")
    engine = PtEngine(data_args.model_path, adapters=[data_args.lora_path])
    engine.processor = processor
    template = get_template(engine.model.model_meta.template, processor, default_system="You are a helpful assistant.")
    engine.default_template = template
    template.processor = processor
    # Initialize the dataset.
    logging.info("Initializing dataset from %s", data_args.data_dir)
    dataset = AudioDataset(data_args.data_dir)
    logging.info(f"Dataset loaded successfully with {len(dataset)} samples")
    # Get the prompt template. (The dataset samples already carry their own
    # chat messages, so this is currently kept for reference only.)
    prompt_template = get_prompt_templates()
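    # Each dataset sample is assumed (inferred from the accesses below) to be
    # a dict shaped roughly like:
    #   {"id": ..., "prompt": <chat messages>, "audio": <audio path or array>,
    #    "solution": <ground-truth score>}
    # The actual schema is defined in dataset.dataset2.AudioDataset.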
    all_outputs = []
    batch_size = data_args.batch_size
    total_batches = (len(dataset) + batch_size - 1) // batch_size
    logging.info(f"Starting batch processing with batch size {batch_size}, total batches: {total_batches}")
    # Greedy decoding configuration: temperature=0 disables sampling, so no
    # separate do_sample flag is needed.
    request_config = RequestConfig(
        max_tokens=512,
        temperature=0,
        num_beams=1,
    )
    for i in range(0, len(dataset), batch_size):
        current_batch = i // batch_size + 1
        logging.info(f"Processing batch {current_batch}/{total_batches}")
        batch_data = [dataset[j] for j in range(i, min(i + batch_size, len(dataset)))]
        # Process each sample in the batch.
        batch_outputs = []
        for bd in batch_data:
            # Build the inference request: chat messages plus the audio clip.
            infer_request = InferRequest(
                messages=bd["prompt"],
                audios=[bd["audio"]],
            )
            # Run inference and collect the generated text.
            resp_list = engine.infer([infer_request], request_config)
            response = resp_list[0].choices[0].message.content
            batch_outputs.append(response)
        all_outputs.extend(batch_outputs)
        logging.info(f"Completed batch {current_batch}/{total_batches}")
    final_output = []
    correct_count = 0
    total_count = 0
    true_positive = 0
    false_positive = 0
    false_negative = 0
    for input_example, model_output in zip(dataset, all_outputs):
        pred_score = extract_overall_score(model_output)
        gt_score = input_example.get("solution", None)
        # Normalize the ground-truth label to int where possible so that it is
        # comparable with the parsed integer prediction.
        if gt_score is not None:
            try:
                gt_score = int(gt_score)
            except (TypeError, ValueError):
                pass
        result = {
            "id": input_example.get("id", None),
            "gt_score": gt_score,
            "model_output": model_output,
            "predicted_score": pred_score,
        }
        final_output.append(result)
        if pred_score is not None and gt_score is not None:
            total_count += 1
            if pred_score == gt_score:
                correct_count += 1
            # Precision/recall assume a binary label where 1 marks the
            # positive class (an overlap exceeding the 3-second threshold).
            if pred_score == 1 and gt_score == 1:
                true_positive += 1
            elif pred_score == 1 and gt_score == 0:
                false_positive += 1
            elif pred_score == 0 and gt_score == 1:
                false_negative += 1
    accuracy = correct_count / total_count if total_count > 0 else 0
    precision = true_positive / (true_positive + false_positive) if (true_positive + false_positive) > 0 else 0
    recall = true_positive / (true_positive + false_negative) if (true_positive + false_negative) > 0 else 0
    # Append the aggregate metrics as the last entry of the output.
    metrics = {
        "accuracy": accuracy,
        "precision": precision,
        "recall": recall,
        "correct_count": correct_count,
        "total_count": total_count,
    }
    final_output.append({"metrics": metrics})
    logging.info("Saving results to %s", data_args.out_file)
    with open(data_args.out_file, "w") as f:
        json.dump(final_output, f, indent=2)
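    # The saved JSON is a list of per-sample records followed by one metrics
    # entry, e.g. (hypothetical values):
    #   [{"id": "...", "gt_score": 1, "model_output": "...", "predicted_score": 1},
    #    ...,
    #    {"metrics": {"accuracy": 0.9, "precision": 0.9, "recall": 0.9,
    #                 "correct_count": 18, "total_count": 20}}]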
logging.info(f"Results saved successfully.")
logging.info(f"准确率: {accuracy:.4f} ({correct_count}/{total_count})")
logging.info(f"召回率: {recall:.4f}")
logging.info(f"精确率: {precision:.4f}")
if __name__ == "__main__":
main()
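
# Example invocation (paths are placeholders; HfArgumentParser derives the
# CLI flags from the TestArguments field names):
#   python test_qwenOmni.py \
#       --model_path /path/to/Qwen2.5-Omni-7B \
#       --lora_path /path/to/lora/checkpoint \
#       --data_dir /path/to/test_data \
#       --out_file results.json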