| | |
| | |
| |
|
| | import os |
| | import json |
| | import torch |
| | import argparse |
| | from transformers import AutoProcessor |
| | from torch.utils.data import DataLoader |
| | from tqdm import tqdm |
| | from huggingface_hub import login |
| | from datasets import load_dataset |
| | from dotenv import load_dotenv |
| |
|
| | |
| | from util import * |
| | from model import MoralEmotionVLClassifier |
| |
|
def parse_args():
    """Build and parse the command-line arguments for inference.

    Returns:
        argparse.Namespace with attributes: language, gpu, emotion,
        save_dir, data_path, max_memory.
    """
    parser = argparse.ArgumentParser(description="inference")

    # Shared keyword arguments for the mandatory string options.
    required_str = {"type": str, "required": True}

    parser.add_argument(
        '-l', '--language',
        help='type only korean or english',
        **required_str,
    )
    parser.add_argument(
        '-g', '--gpu',
        type=str,
        default='auto',
        help='type only number(ex. cuda:0 -> 0)',
    )
    parser.add_argument(
        '-e', '--emotion',
        help='type the target moral emotion(ex. Other-condemning, Other-praising, Other-suffering, Self-conscious, Non-moral emotion, Neutral)',
        **required_str,
    )
    parser.add_argument(
        '-s', '--save_dir',
        help='directory path of model and save results',
        **required_str,
    )
    parser.add_argument(
        '-d', '--data_path',
        help='data path in huggingface(ex. test/dataset_ko)',
        **required_str,
    )
    parser.add_argument(
        '-m', '--max_memory',
        type=str,
        default=None,
        help='comma-separated GPU memory limits (ex. "0:10GiB,1:20GiB,2:0GiB,3:23GiB")',
    )

    return parser.parse_args()
| |
|
def predict(model, test_dataloader, threshold):
    """Run thresholded binary inference over a dataloader.

    Args:
        model: classifier whose forward call returns an object exposing
            ``.logits`` and which carries an ``id2label`` mapping.
        test_dataloader: iterable of batch dicts; each batch holds an
            ``'ids'`` entry plus the keyword inputs the model expects.
        threshold: sigmoid-probability cutoff for the positive class.

    Returns:
        list of dicts with keys "id", "prediction", "probabilities".
    """
    results = []

    with torch.no_grad():
        for batch in tqdm(test_dataloader):
            # 'ids' is bookkeeping only — strip it before the forward pass.
            sample_ids = batch.pop('ids')
            probs = torch.sigmoid(model(**batch).logits)
            labels = (probs > threshold).long()

            results.extend(
                {
                    "id": sample_ids[idx],
                    "prediction": model.id2label[labels[idx].item()],
                    "probabilities": probs[idx].item(),
                }
                for idx in range(len(sample_ids))
            )

    return results
| |
|
if __name__ == "__main__":
    args = parse_args()
    model_name = 'Qwen/Qwen2-VL-7B-Instruct'
    # "auto" lets HF dispatch across devices; otherwise pin to one GPU index.
    device = args.gpu if args.gpu == "auto" else "cuda:" + args.gpu
    moral_emotion = args.emotion
    data_path = args.data_path
    language = args.language
    save_dir = os.path.join(args.save_dir, language)

    # NOTE(review): the checkpoint path is rooted at the bare language name,
    # not under args.save_dir, even though --save_dir is described as the
    # "directory path of model and save results" — confirm this is intentional.
    model_path = os.path.join(language, f'multimodal_moral_emotion_classifier_{language[:3]}_{moral_emotion}.pth')
    save_path = os.path.join(save_dir, moral_emotion + '_results.json')
    korean = language == 'korean'
    # Per-emotion decision thresholds come from util (star-imported constants).
    threshold = KOREAN_MORAL_EMOTION_THRESHOLDDS[moral_emotion] if korean else ENGLISH_MORAL_EMOTION_THRESHOLDDS[moral_emotion]

    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(save_dir, exist_ok=True)

    # Parse "idx:mem" pairs, e.g. "0:10GiB,1:20GiB" -> {0: "10GiB", 1: "20GiB"}.
    max_memory = {}
    if args.max_memory is not None:
        for pair in args.max_memory.split(","):
            gpu, mem = pair.split(":")
            max_memory[int(gpu)] = mem

    # Authenticate against the Hugging Face Hub using the key from .env.
    load_dotenv(override=True)
    login_key = os.getenv("HUGGINGFACE_KEY")
    login(login_key)

    test_dataset = load_dataset(data_path, split="test")

    # Single-logit binary head: index 0 -> "False", index 1 -> "True".
    binary_labels = [
        "False",
        "True"
    ]

    model = MoralEmotionVLClassifier(
        model_name,
        num_labels=1,
        device=device,
        max_memory=max_memory,
        label_names=binary_labels
    )

    model.to(device)
    # strict=False tolerates missing/unexpected keys in the saved state dict
    # (e.g. frozen backbone weights not included in the checkpoint).
    checkpoint = torch.load(model_path, map_location='cpu')
    model.load_state_dict(checkpoint, strict=False)

    processor = AutoProcessor.from_pretrained(model_name)

    # format_data / collate_fn come from util (star-imported).
    formatted_test_dataset = [format_data(sample, moral_emotion, korean=korean) for sample in test_dataset]

    test_dataloader = DataLoader(
        formatted_test_dataset,
        batch_size=128,
        shuffle=False,
        collate_fn=lambda examples: collate_fn(examples, processor, device, model.label2id, train=False),
    )

    model.eval()
    outputs = predict(model, test_dataloader, threshold)
    with open(save_path, 'w') as f:
        json.dump(outputs, f, indent=4)