# ITFormer / inference.py
# a12354's picture
# Add files using upload-large-folder tool
# f48983a verified
# Raw
# History Blame Contribute Delete
# 15 kB
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Inference script for Time Language Model (TLM).
Performs inference on the test set and saves results and evaluation metrics.
"""
import os
import yaml
import json
import warnings
warnings.filterwarnings("ignore", category=FutureWarning, message=".*TRANSFORMERS_CACHE.*")
import argparse
import random
import logging
from transformers.utils import logging as transformers_logging
import numpy as np
import torch
from transformers import AutoTokenizer, AutoProcessor
from dataset.dataset import TsQaDataset, DataCollator
from models.TimeLanguageModel import TLM
from datetime import datetime
from tqdm import tqdm
from torch.utils.data import DataLoader
from utils.metrics import open_question_metrics, closed_question_metrics
from typing import List, Dict, Any
from accelerate import Accelerator
def set_seed(seed=42):
    """Seed the random number generators used by this script.

    The same ``seed`` is handed, in order, to Python's ``random`` module,
    NumPy's global generator, and PyTorch's CPU and CUDA seeding functions.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)
def save_results(results, output_dir, config_name):
    """Write the inference results to a timestamped JSON file.

    The file is named ``inference_results_<config_name>_<YYYYmmdd_HHMMSS>.json``
    and is placed in ``output_dir``, which is created when missing. The
    destination path is printed once the file has been written.
    """
    os.makedirs(output_dir, exist_ok=True)
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    target = os.path.join(
        output_dir, f"inference_results_{config_name}_{stamp}.json"
    )
    with open(target, 'w', encoding='utf-8') as handle:
        # ensure_ascii=False keeps non-ASCII text readable in the output file.
        json.dump(results, handle, indent=4, ensure_ascii=False)
    print(f"\n✅ Inference results saved to: {target}")
def count_model_parameters(model):
    """Report parameter counts, in millions, for the model and its parts.

    Reads ``model.llm_model``, ``model.itformer`` and ``model.ts_encoder``;
    a component that is ``None`` contributes zero. Returns a dict with the
    keys ``'llm'``, ``'itformer'``, ``'ts_encoder'`` and ``'total'``.
    """
    def millions(module):
        n_params = 0
        if module is not None:
            for param in module.parameters():
                n_params += param.numel()
        return n_params / 1e6

    components = (
        ('llm', model.llm_model),
        ('itformer', model.itformer),
        ('ts_encoder', model.ts_encoder),
        ('total', model),
    )
    return {label: millions(part) for label, part in components}
def main_inference(args):
    """Run the inference pipeline on the test set and save the outputs.

    Steps, in order: seed the RNGs, merge the YAML config into ``args``, build
    the model from ``args.model_checkpoint``, build the test dataloader,
    generate a prediction for every batch, gather the per-process results,
    and (main process only) compute, print and save results and metrics.

    Args:
        args: Namespace produced by the CLI parser. It must provide
            ``config``, ``output_dir``, ``model_checkpoint``, ``seed``,
            ``batch_size``, ``num_workers`` and ``max_new_tokens``. Every key
            of the YAML file is copied onto it, overriding a CLI value of the
            same name; ``prefix_num``, ``ts_path_test`` and ``qa_path_test``
            are read after the merge, so they are expected to come from the
            YAML file. ``args.llm_model_path`` is set by this function.
    """
    accelerator = Accelerator()
    is_main = accelerator.is_main_process

    # The local main process shows progress; the other processes stay silent.
    if accelerator.is_local_main_process:
        transformers_logging.set_verbosity_info()
    else:
        transformers_logging.set_verbosity_error()
        logging.disable(logging.CRITICAL)

    set_seed(args.seed)

    if is_main:
        print("🚀 Starting inference process...")
        print(f"🔧 Using config file: {args.config}")
        print(f"💾 Results will be saved to: {args.output_dir}")

    # An empty YAML file parses to None; treat it as an empty mapping so the
    # merge and the config.get(...) lookups below do not crash.
    with open(args.config, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f) or {}
    vars(args).update(config)

    if is_main:
        print("\n⚙️ Loading model from checkpoint...")

    from models.TimeLanguageModel import TLMConfig

    # Pick the LLM backbone from the size tag found in the checkpoint path;
    # fall back to 7B when no known tag is present.
    llm_model_path = 'LLM/Qwen2.5-7B-Instruct'
    for size_tag in ('0.5B', '3B', '7B'):
        if size_tag in args.model_checkpoint:
            llm_model_path = f'LLM/Qwen2.5-{size_tag}-Instruct'
            break
    # Stored on args because compute_metrics_from_results reads it later.
    args.llm_model_path = llm_model_path
    if is_main:
        print(f"🔗 Using LLM model: {llm_model_path}")

    tlm_config = TLMConfig(
        llm_model_path=llm_model_path,
        freeze_ts_model=True,
        ts_pad_num=args.prefix_num,
        # Fields passed explicitly for the ITFormer adapter ...
        prefix_num=args.prefix_num,
        it_d_model=config.get('it_d_model', 896),
        it_n_heads=config.get('it_n_heads', 16),
        it_layers=config.get('it_layers', 2),
        it_dropout=config.get('it_dropout', 0.1),
        adapter_type=config.get('adapter_type', 'itformer'),
        # ... and for the time-series encoder.
        dropout=config.get('dropout', 0.1),
        d_model=config.get('d_model', 512),
        n_heads=config.get('n_heads', 8),
        e_layers=config.get('e_layers', 4),
        patch_len=config.get('patch_len', 60),
        stride=config.get('stride', 60),
        input_len=config.get('input_len', 600),
    )

    model = TLM.from_pretrained(
        args.model_checkpoint, config=tlm_config, ts_config=tlm_config
    )

    if is_main:
        param_counts = count_model_parameters(model)
        print("\n🔢 Model parameter counts (in M):")
        print(f" LLM: {param_counts['llm']:.2f}M")
        print(f" ITFormer: {param_counts['itformer']:.2f}M")
        print(f" TS_Encoder: {param_counts['ts_encoder']:.2f}M")
        print(f" Total: {param_counts['total']:.2f}M")

    model = accelerator.prepare(model)
    model.eval()
    # prepare() may return a wrapper (e.g. DistributedDataParallel) that does
    # not expose the model's own attributes, so custom attributes and
    # generate() are always accessed through the unwrapped model.
    raw_model = accelerator.unwrap_model(model)

    if is_main:
        print("✅ Model loaded successfully!")
        print("\n📊 Preparing test dataset...")

    # Prefer a tokenizer saved with the checkpoint; otherwise reuse the one
    # attached to the model.
    if os.path.exists(os.path.join(args.model_checkpoint, "tokenizer.json")):
        if is_main:
            print("Loading tokenizer from checkpoint")
        tokenizer = AutoTokenizer.from_pretrained(args.model_checkpoint)
    else:
        if is_main:
            print("Loading tokenizer from original model")
        tokenizer = raw_model.tokenizer

    # Add the placeholder token if the vocabulary lacks it, and grow the
    # embedding table to match the new vocabulary size.
    if '<|image_pad|>' not in tokenizer.get_vocab():
        if is_main:
            print("Adding <|image_pad|> token to tokenizer")
        tokenizer.add_tokens(['<|image_pad|>'])
        raw_model.llm_model.resize_token_embeddings(len(tokenizer))
    tokenizer.padding_side = 'left'

    # Minimal attribute bag for the dataset. A module-level type is used
    # instead of a function-local class so the object stays picklable if the
    # dataset keeps a reference and DataLoader workers are spawned.
    from types import SimpleNamespace
    tlmconfig = SimpleNamespace(ts_pad_num=args.prefix_num)

    test_dataset = TsQaDataset(
        args.ts_path_test,
        args.qa_path_test,
        tokenizer,
        tokenizer,  # The tokenizer doubles as the processor argument.
        tlmconfig,
    )
    data_collator = DataCollator(tokenizer=tokenizer)
    test_loader = DataLoader(
        test_dataset,
        batch_size=args.batch_size,
        collate_fn=data_collator,
        num_workers=args.num_workers,
    )
    test_loader = accelerator.prepare(test_loader)

    if is_main:
        print(f"📁 Test set size: {len(test_dataset)} samples")
        print(f"🔢 Batch size: {args.batch_size}, Total batches: {len(test_loader)}")
        print("\n🔍 Starting test set inference...")

    results = []
    with torch.no_grad():
        # Only the main process shows a progress bar.
        if is_main:
            batch_iterator = tqdm(test_loader, desc="Inference progress")
        else:
            batch_iterator = test_loader

        for batch in batch_iterator:
            # Batches come from the accelerator-prepared dataloader, so no
            # manual device placement is done here.
            generated_ids = raw_model.generate(
                input_ids=batch['input_ids'],
                query_ids=batch['query_ids'],
                ts_values=batch['ts_values'],
                stage=batch['stage'],
                attention_mask=batch['attention_mask'],
                max_new_tokens=args.max_new_tokens,
                do_sample=False,
                eos_token_id=tokenizer.eos_token_id,
                pad_token_id=tokenizer.pad_token_id,
                use_cache=True,
                # Optimization parameters, consistent with Time-QA
                num_beams=1,  # Greedy search, fastest
                temperature=1.0,  # Avoid extra computation
                top_p=None,  # Disable nucleus sampling
                top_k=None,  # Disable top-k sampling
                repetition_penalty=1.0,  # Disable repetition penalty
                length_penalty=1.0,  # Disable length penalty
                no_repeat_ngram_size=0,  # Disable n-gram repetition check
                output_scores=False,  # Do not output scores
                output_attentions=False,  # Do not output attention
                output_hidden_states=False,  # Do not output hidden states
                return_dict_in_generate=False,  # Simplify return format
            )
            batch_predictions = tokenizer.batch_decode(
                generated_ids, skip_special_tokens=True
            )
            batch_labels = tokenizer.batch_decode(
                batch['labels'], skip_special_tokens=True
            )

            for i, (decoded, label) in enumerate(zip(batch_predictions, batch_labels)):
                # Keep only the text after the last 'assistant\n' marker.
                prediction = decoded.split('assistant\n')[-1]
                results.append({
                    "index": batch['index'][i].item(),
                    "stage": batch['stage'][i].item(),
                    "input": tokenizer.decode(
                        batch['input_ids'][i], skip_special_tokens=True
                    ),
                    "prediction": prediction,
                    "label": label,
                    "is_correct": prediction.strip() == label.strip(),
                })

    if is_main:
        print("\n📊 Calculating evaluation metrics...")

    # Collect the per-process result lists. When no process group exists
    # (single-process run) the local list is already complete; gathering and
    # flattening are skipped together so a None placeholder is never iterated.
    import torch.distributed as dist
    if dist.is_available() and dist.is_initialized():
        gathered = [None] * accelerator.num_processes
        dist.all_gather_object(gathered, results)
        results = [item for shard in gathered for item in shard]

    if is_main:
        # Drop duplicates (distributed sampling can repeat samples), keeping
        # the first record seen for each index, then order by index.
        first_by_index = {}
        for record in results:
            first_by_index.setdefault(record['index'], record)
        results = [first_by_index[idx] for idx in sorted(first_by_index)]

        metrics = compute_metrics_from_results(results, args)
        print_metrics(metrics)
        config_base = os.path.basename(args.config).split('.yaml')[0]
        save_results(results, args.output_dir, config_base)
        save_metrics(metrics, args.output_dir, config_base)
def compute_metrics_from_results(results: List[Dict],args) -> Dict[str, Any]:
    """Score the inference results stage by stage.

    Records are grouped by their ``'stage'`` value. Stages 1 and 4 are scored
    with ``open_question_metrics`` and stages 2 and 3 with
    ``closed_question_metrics``; a stage with no records is skipped. Each
    returned key is the metric name prefixed with ``stage<N>_``.

    The tokenizer is loaded from ``args.llm_model_path`` so that the ids
    handed to the metric functions match the main model's vocabulary.
    """
    grouped = [
        (stage, [rec for rec in results if rec['stage'] == stage])
        for stage in (1, 2, 3, 4)
    ]

    # Ids passed to the metric functions: the tokenizer's special ids plus
    # the ids of common punctuation marks (consistent with Time-QA).
    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained(args.llm_model_path)
    special_id = tokenizer.all_special_ids
    common_punctuations = [".", ",", ":", ";", "!", "?", "(", ")", "[", "]", "{", "}", "-", "_", "\"", "'"]
    special_id.extend(tokenizer.convert_tokens_to_ids(common_punctuations))

    scorer_for_stage = {
        1: open_question_metrics,
        2: closed_question_metrics,
        3: closed_question_metrics,
        4: open_question_metrics,
    }

    metrics = {}
    for stage, records in grouped:
        if not records:
            continue
        predictions = [rec['prediction'] for rec in records]
        references = [rec['label'] for rec in records]
        stage_scores = scorer_for_stage[stage](predictions, references, special_id)
        for name, value in stage_scores.items():
            metrics[f"stage{stage}_{name}"] = value
    return metrics
def print_metrics(metrics: Dict[str, Any]):
    """Print the evaluation metrics grouped by stage (1 through 4).

    Args:
        metrics: Mapping whose keys look like ``stage<N>_<metric>``. Only the
            leading ``stage<N>_`` prefix is removed for display. Values that
            accept the ``.4f`` format are printed with four decimals; any
            other value (for example a string or ``None``) is printed with
            ``str()`` instead of raising.
    """
    print("\n📈 Evaluation results:")
    for stage in (1, 2, 3, 4):
        prefix = f"stage{stage}_"
        stage_metrics = {
            key[len(prefix):]: value
            for key, value in metrics.items()
            if key.startswith(prefix)
        }
        if not stage_metrics:
            continue
        print(f"\n🔹 Stage {stage} metrics:")
        for metric, value in stage_metrics.items():
            try:
                rendered = f"{value:.4f}"
            except (TypeError, ValueError):
                # Non-numeric metric value: show it as-is rather than crash
                # after inference has already finished.
                rendered = str(value)
            print(f" {metric}: {rendered}")
def save_metrics(metrics: Dict[str, Any], output_dir: str, config_base: str):
    """Save the evaluation metrics to ``<output_dir>/<config_base>_metrics.json``.

    Args:
        metrics: JSON-serialisable mapping of metric names to values.
        output_dir: Destination directory; created if it does not exist.
        config_base: Prefix used for the metrics file name.
    """
    # Create the directory here so the function also works when it is called
    # on its own. An empty output_dir means the current directory.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    metrics_file = os.path.join(output_dir, f"{config_base}_metrics.json")
    with open(metrics_file, 'w', encoding='utf-8') as f:
        json.dump(metrics, f, indent=2, ensure_ascii=False)
    print(f"\n💾 Evaluation metrics saved to: {metrics_file}")
if __name__ == '__main__':
    # Command-line entry point: declare the options, parse them, and run the
    # inference pipeline. Each row is (flag, type, default, help text).
    option_specs = (
        ('--config', str, 'yaml/infer.yaml', 'YAML config file'),
        ('--output_dir', str, 'inference_results', 'Output directory'),
        ('--model_checkpoint', str, 'checkpoints/ITFormer-0.5B', 'Model checkpoint path'),
        ('--seed', int, 42, 'Random seed'),
        ('--batch_size', int, 12, 'Batch size'),
        ('--num_workers', int, 4, 'Number of workers'),
        ('--max_new_tokens', int, 128, 'Maximum new tokens to generate'),
    )
    parser = argparse.ArgumentParser(description='TLM Inference')
    for flag, value_type, default_value, help_text in option_specs:
        parser.add_argument(
            flag, type=value_type, default=default_value, help=help_text
        )
    args = parser.parse_args()
    main_inference(args)