| | from vllm import LLM, SamplingParams |
| | import multiprocessing |
| | import time |
| | import gc |
| | import torch |
| | import pdb |
| | import sqlite3 |
| | from concurrent.futures import ThreadPoolExecutor |
| | |
def label_transform(label):
    """Map an integer NLI class id to its string label.

    0 -> 'entailment', 1 -> 'neutral', 2 -> 'contradiction';
    any other value yields None (same as the original if-chain).
    """
    return {0: 'entailment', 1: 'neutral', 2: 'contradiction'}.get(label)
# Shared decoding config for all evaluation runs: temperature 0.0 for
# deterministic (greedy) output, capped at 600 generated tokens.
sampling_params = SamplingParams(temperature=0.0,max_tokens=600, top_p=0.95)
def valid_results_collect(model_path, valid_data, task):
    """Load a vLLM model, run the NLI evaluation, and release GPU memory.

    Parameters:
        model_path: path/name of the trained model to load into vLLM.
        valid_data: sequence of dicts with 'Input' and 'Output' fields
            (consumed by nli_evaluation).
        task: unused here; kept for caller compatibility.

    Returns:
        (failed_cases, correct_cases) as produced by nli_evaluation.
    """
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()

    trained_model = LLM(model=model_path, gpu_memory_utilization=0.95)
    start_t = time.time()
    failed_cases, correct_cases = nli_evaluation(trained_model, valid_data)
    del trained_model
    end_t = time.time()
    # BUG FIX: was `start_t - end_t`, which printed a negative duration.
    print('time', end_t - start_t)
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()
    torch.cuda.synchronize()
    # Brief pause so the CUDA allocator settles before any subsequent model load.
    time.sleep(10)
    return failed_cases, correct_cases
def extract_answer_prediction_nli(predicted_output):
    """Extract an NLI label from sentences of the model output that mention 'final'.

    Splits on '.', scans only sentences containing 'final', and returns the
    first label extract_answer finds; returns None when nothing matches.
    """
    for sen in predicted_output.split('.'):
        if 'final' not in sen:
            continue
        answer = extract_answer(sen)  # hoisted: was called twice per sentence
        if answer:
            return answer
    return None
def extract_answer(text):
    """Return the first NLI label mentioned in *text* (case-insensitive).

    Precedence is 'neutral', then 'contradiction', then 'entailment'
    (matching the original check order); returns None if no label occurs.
    """
    lowered = text.lower()  # hoisted: was recomputed for each label check
    for label in ('neutral', 'contradiction', 'entailment'):
        if label in lowered:
            return label
    return None
def process_batch(data_batch, trained_model, failed_cases, correct_cases):
    """Generate predictions for one batch and sort cases into failed/correct.

    Appends to (and returns) the two accumulator lists. A case is correct
    only when the gold label appears in the prediction text AND neither of
    the other two labels does.

    Parameters:
        data_batch: list of dicts with 'Input' (prompt) and 'Output'
            (gold rationale whose tail after 'is' holds the label).
        trained_model: a vLLM LLM instance (uses module-level sampling_params).
        failed_cases, correct_cases: accumulators of
            (input, prediction, label, data) tuples.
    """
    batch_prompts = [data['Input'] for data in data_batch]
    outputs = trained_model.generate(batch_prompts, sampling_params)

    labels = ['entailment', 'contradiction', 'neutral']
    for data, output in zip(data_batch, outputs):
        predicted_output = output.outputs[0].text
        predicted_res = extract_answer_prediction_nli(predicted_output)
        label = extract_answer(data['Output'].split('is')[-1])
        print(predicted_res, label, '\n')
        if not predicted_res:
            # No 'final ... <label>' sentence found: fall back to matching
            # against the raw generated text.
            predicted_res = predicted_output
        # BUG FIX: extract_answer returns None for an unparseable gold label,
        # and `None not in str` raises TypeError. Count such rows as failed.
        if label is None:
            failed_cases.append((data['Input'], predicted_res, label, data))
            continue
        non_labels = [lbl for lbl in labels if lbl != label]
        if label not in predicted_res or any(nl in predicted_res for nl in non_labels):
            failed_cases.append((data['Input'], predicted_res, label, data))
        else:
            correct_cases.append((data['Input'], predicted_res, label, data))
    return failed_cases, correct_cases
def nli_evaluation(trained_model, valid_data):
    """Run the model over valid_data in fixed-size batches of 500.

    Returns the accumulated (failed_cases, correct_cases) lists produced
    by process_batch.
    """
    chunk = 500
    failed, correct = [], []
    # Slice each batch on the fly instead of materializing a list of batches.
    for start in range(0, len(valid_data), chunk):
        batch = valid_data[start:start + chunk]
        failed, correct = process_batch(batch, trained_model, failed, correct)
    return failed, correct