import json
import logging
import os
import re
from dataclasses import dataclass, field
from typing import Optional

import torch
from transformers import HfArgumentParser, Qwen2_5OmniProcessor

from swift.llm import InferRequest, PtEngine, RequestConfig, get_template

from dataset.dataset2 import AudioDataset
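
# Batch-inference script: runs a LoRA-adapted Qwen2.5-Omni model over an audio
# dataset via ms-swift's PtEngine, extracts each <overall score> tag from the
# responses, and scores the predictions against the dataset's ground truth.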


@dataclass
class TestArguments:
    """
    Arguments pertaining to the data and model paths used for inference and evaluation.
    """

    MODEL_PATH = "/root/autodl-tmp/Qwen2.5-Omni-7B"
    LORA_PATH = "/root/autodl-tmp/output_7B_Lora/v2-20250608-171618/checkpoint-324"
    DATA_FILE = "/root/ms-swift/silence_overlaps/test"
    OUTPUT_DIR = "omini_inference_7B_overlap5sVal_SFT_allset.json"

    model_path: Optional[str] = field(default=MODEL_PATH, metadata={"help": "base model dir"})
    lora_path: Optional[str] = field(default=LORA_PATH, metadata={"help": "LoRA adapter dir"})
    out_file: Optional[str] = field(default=OUTPUT_DIR, metadata={"help": "output file for test results"})
    data_dir: Optional[str] = field(default=DATA_FILE, metadata={"help": "test data directory"})
    DEVICE: Optional[str] = field(default="cuda:0", metadata={"help": "device to use"})
    force: Optional[bool] = field(default=False, metadata={"help": "regenerate the output file even if it exists"})
    batch_size: Optional[int] = field(default=2, metadata={"help": "batch size for processing"})

    def __post_init__(self):
        if self.model_path is None:
            raise ValueError("model path should not be None")
        if self.data_dir is None:
            raise ValueError("data directory should not be None")
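

# A hypothetical invocation (flag names mirror the dataclass fields above;
# boolean parsing follows HfArgumentParser's conventions):
#   python <this_script>.py --model_path /root/autodl-tmp/Qwen2.5-Omni-7B --batch_size 4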


def get_prompt_templates():
    prompt_template = (
        "You are an expert at analyzing overlapping speech in conversations. Please analyze the speech dialogue and focus specifically on:\n"
        "Please summarize if any overlaps exceed the 3-second threshold."
    )
    return prompt_template


def extract_overall_score(output_str):
    """Extract the integer from an <overall score>X</overall score> tag in the model output."""
    score_pattern = r"<overall score>(\d+)</overall score>"
    match = re.search(score_pattern, output_str)
    if match:
        try:
            return int(match.group(1))
        except ValueError:
            pass
    return None
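
# For example, extract_overall_score("... <overall score>3</overall score>")
# returns 3; when no tag is present the function returns None.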


def main():
    parser = HfArgumentParser(TestArguments)
    data_args = parser.parse_args_into_dataclasses()[0]
    logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
    logging.info("Starting inference with arguments: %s", data_args)

    # Skip regeneration if a non-empty output file already exists, unless forced.
    if not data_args.force and os.path.exists(data_args.out_file) and os.path.getsize(data_args.out_file) > 0:
        logging.info(f"{data_args.out_file} already exists; skipping (use --force to regenerate).")
        return

    device = torch.device(data_args.DEVICE if torch.cuda.is_available() else "cpu")
    logging.info(f"Using device: {device}")

    logging.info("Loading processor...")
    processor = Qwen2_5OmniProcessor.from_pretrained(data_args.model_path)
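
    # PtEngine loads the base model and applies the LoRA adapter; the processor
    # loaded above is attached to both the engine and its chat template so audio
    # inputs are prepared consistently. Note that `device` above only feeds the
    # log line; PtEngine manages model placement internally.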
    logging.info("Initializing inference engine...")
    engine = PtEngine(data_args.model_path, adapters=[data_args.lora_path])
    engine.processor = processor
    template = get_template(engine.model.model_meta.template, processor, default_system="You are a helpful assistant.")
    engine.default_template = template
    template.processor = processor

    logging.info("Initializing dataset from %s", data_args.data_dir)
    dataset = AudioDataset(data_args.data_dir)
    logging.info(f"Dataset loaded successfully with {len(dataset)} samples")
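
    # Note: the instruction string below is built but never attached to the
    # requests; each dataset sample already carries its own chat messages in
    # its "prompt" field.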
    prompt_template = get_prompt_templates()

    all_outputs = []
    batch_size = data_args.batch_size
    total_batches = (len(dataset) + batch_size - 1) // batch_size
    logging.info(f"Starting batch processing with batch size {batch_size}, total batches: {total_batches}")

    # Greedy decoding: temperature 0, sampling disabled, single beam. The config
    # is identical for every request, so build it once outside the loop.
    request_config = RequestConfig(
        max_tokens=512,
        temperature=0,
        do_sample=False,
        num_beams=1
    )

    for i in range(0, len(dataset), batch_size):
        current_batch = i // batch_size + 1
        logging.info(f"Processing batch {current_batch}/{total_batches}")

        batch_data = [dataset[j] for j in range(i, min(i + batch_size, len(dataset)))]

        # engine.infer accepts a list of requests, so submit the whole batch at
        # once rather than one request at a time.
        infer_requests = [
            InferRequest(messages=bd["prompt"], audios=[bd["audio"]])
            for bd in batch_data
        ]
        resp_list = engine.infer(infer_requests, request_config)
        all_outputs.extend(resp.choices[0].message.content for resp in resp_list)

        logging.info(f"Completed batch {current_batch}/{total_batches}")
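
    # Evaluation: compare each predicted <overall score> against the dataset's
    # "solution" field and accumulate agreement statistics.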
    final_output = []
    correct_count = 0
    total_count = 0
    true_positive = 0
    false_positive = 0
    false_negative = 0

    for input_example, model_output in zip(dataset, all_outputs):
        pred_score = extract_overall_score(model_output)
        gt_score = input_example.get("solution", None)

        result = {
            "id": input_example.get("id", None),
            "gt_score": gt_score,
            "model_output": model_output,
            "predicted_score": pred_score
        }
        final_output.append(result)

        # Only score examples where both a prediction and a ground truth exist.
        if pred_score is not None and gt_score is not None:
            total_count += 1
            if pred_score == gt_score:
                correct_count += 1
                true_positive += 1
            else:
                false_positive += 1
                false_negative += 1
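
    # With the counting above, every exact match is a true positive and every
    # mismatch is both a false positive and a false negative, so the precision
    # and recall below reduce to the same value as the accuracy.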
    accuracy = correct_count / total_count if total_count > 0 else 0
    precision = true_positive / (true_positive + false_positive) if (true_positive + false_positive) > 0 else 0
    recall = true_positive / (true_positive + false_negative) if (true_positive + false_negative) > 0 else 0

    metrics = {
        "accuracy": accuracy,
        "precision": precision,
        "recall": recall,
        "correct_count": correct_count,
        "total_count": total_count
    }
    final_output.append({"metrics": metrics})

    logging.info("Saving results to %s", data_args.out_file)
    with open(data_args.out_file, "w") as f:
        json.dump(final_output, f, indent=2)

    logging.info("Results saved successfully.")
    logging.info(f"Accuracy: {accuracy:.4f} ({correct_count}/{total_count})")
    logging.info(f"Recall: {recall:.4f}")
    logging.info(f"Precision: {precision:.4f}")


if __name__ == "__main__":
    main()