# BioMedGPT-Mol / evaluation / inference.py
# (uploaded by leofansq — "update for evaluation", commit 3824ea0, verified)
import os
import json
from tqdm import tqdm
from utils.inference_utils import Agent, AgentMolLM, Extractor, canonicalize_smiles
class InferenceDataset():
    """Runs model inference over a JSON dataset and extracts answers from the logs.

    Wraps an API-backed agent (MolLM-specific or generic) plus a task-aware
    extractor used to post-process the raw generations.
    """

    def __init__(self, json_file, model_name="biomedgpt_mol", url="http://localhost:8000/v1"):
        """Load the dataset and build the agent for the given model.

        Args:
            json_file: path to a JSON list of samples; each sample is expected
                to carry 'instruction', 'output' and 'metadata' keys.
            model_name: "biomedgpt_mol" selects the MolLM agent; any other
                value falls back to the generic Agent.
            url: base URL of the OpenAI-compatible inference API.
        """
        if model_name == "biomedgpt_mol":
            self.agent = AgentMolLM(url=url)
        else:
            self.agent = Agent(url=url)
        self.data = self._load_json(json_file)
        print("[Dataset] loaded")

    def _load_json(self, file_path):
        """Read and return the parsed JSON content of file_path."""
        with open(file_path, 'r') as f:
            return json.load(f)

    def _dump_json(self, content, file_path):
        """Write content to file_path as pretty-printed JSON (non-ASCII kept)."""
        with open(file_path, 'w') as f:
            json.dump(content, f, indent=4, ensure_ascii=False)

    def inference(self, save_file, temperature=0.01, top_p=0.01, n_resp=1, add_no_think=False):
        """Query the agent for every sample and save the logs to save_file.

        Checkpoints the partial log list every 1000 samples so a crash does
        not lose all progress; the final dump overwrites the checkpoint.
        """
        print("[Dataset Inference] Start...")
        logs = list()
        for sample in tqdm(self.data):
            resp = self.agent.generate(query=sample['instruction'], temperature=temperature, top_p=top_p, n_resp=n_resp, add_no_think=add_no_think)
            log = {
                "query": sample['instruction'],
                "answer": resp,
                "gt": sample['output'],
                "metadata": sample['metadata']
            }
            logs.append(log)
            # Periodic checkpoint: persist partial logs every 1000 samples.
            if len(logs) % 1000 == 0:
                self._dump_json(logs, save_file)
        self._dump_json(logs, save_file)
        print(f"[Dataset Inference] Done. Save at {save_file}.")

    def extract(self, log_file, save_file, task):
        """Post-process raw answers in log_file and save extracted results.

        For multi-candidate answers (lists), each candidate is extracted and
        canonicalized as SMILES; candidates that fail canonicalization are
        dropped (deliberate best-effort filtering). Scalar answers are only
        passed through the extractor.
        """
        print("[Log Extraction] Start...")
        logs = self._load_json(log_file)
        extractor = Extractor(task_type=task)
        results = list()
        for log in logs:
            if isinstance(log['answer'], list):
                candidates = list()
                for candidate in log['answer']:
                    candidate = extractor.extract(candidate)
                    try:
                        candidates.append(canonicalize_smiles(candidate))
                    # Narrowed from bare `except:` so KeyboardInterrupt /
                    # SystemExit are no longer swallowed; invalid SMILES
                    # candidates are still silently skipped on purpose.
                    except Exception:
                        pass
                log['extracted_answer'] = candidates
            else:
                log['extracted_answer'] = extractor.extract(log['answer'])
            log['extracted_gt'] = extractor.extract(log['gt'])
            results.append(log)
        self._dump_json(results, save_file)
        print(f"[Log Extraction] Done. Save at {save_file}.")
def inference_smolinstruct(dataset_path, save_dir="logs", model_name="biomedgpt_mol", url="http://localhost:8000/v1"):
    """Run inference + extraction on every SMolInstruct file in dataset_path.

    Outputs go under <save_dir>/<model_name>/smolinstruct/{logs,results}.
    The extractor task type is chosen per file name: molecule_captioning ->
    text, s2i -> iupac, i2f/s2f -> formula, bbbp/clintox/hiv/sider -> bool,
    esol/lipo -> value, anything else -> smiles.
    """
    save_dir = os.path.join(save_dir, model_name, "smolinstruct")
    log_dir = os.path.join(save_dir, "logs")
    res_dir = os.path.join(save_dir, "results")
    # makedirs creates intermediate directories, so no per-level checks needed.
    os.makedirs(log_dir, exist_ok=True)
    os.makedirs(res_dir, exist_ok=True)

    for file_name in os.listdir(dataset_path):
        if "smolinstruct" not in file_name:
            continue
        print(f"\n###### {file_name} ######")
        # BUG FIX: model_name was previously dropped here, so a non-default
        # model always got the MolLM agent; forward it to the workflow.
        workflow = InferenceDataset(json_file=os.path.join(dataset_path, file_name),
                                    model_name=model_name,
                                    url=url)
        workflow.inference(save_file=os.path.join(log_dir, file_name), add_no_think=True)

        # Select the extractor task type from the file name.
        if "molecule_captioning" in file_name:
            task = "text"
        elif "s2i" in file_name:
            task = "iupac"
        elif ("i2f" in file_name) or ("s2f" in file_name):
            task = "formula"
        elif any(key in file_name for key in ("bbbp", "clintox", "hiv", "sider")):
            task = "bool"
        elif ("esol" in file_name) or ("lipo" in file_name):
            task = "value"
        else:
            task = "smiles"
        workflow.extract(log_file=os.path.join(log_dir, file_name),
                         save_file=os.path.join(res_dir, file_name),
                         task=task)
def inference_openmolinst(dataset_path, save_dir="logs", model_name="biomedgpt_mol", url="http://localhost:8000/v1", add_no_think=True):
    """Run inference + SMILES extraction on OpenMolInst data.

    dataset_path may be a single JSON file or a directory; in the directory
    case only files whose name contains "openmolinst" are processed. Outputs
    go under <save_dir>/<model_name>/openmolinst/{logs,results}.
    """
    save_dir = os.path.join(save_dir, model_name, "openmolinst")
    log_dir = os.path.join(save_dir, "logs")
    res_dir = os.path.join(save_dir, "results")
    # makedirs creates intermediate directories, so no per-level checks needed.
    os.makedirs(log_dir, exist_ok=True)
    os.makedirs(res_dir, exist_ok=True)

    def _process(json_file, file_name):
        # Shared single-file pipeline: inference, then SMILES extraction.
        # BUG FIX: model_name was previously dropped, so a non-default model
        # always got the MolLM agent; forward it to the workflow.
        workflow = InferenceDataset(json_file=json_file, model_name=model_name, url=url)
        workflow.inference(save_file=os.path.join(log_dir, file_name), add_no_think=add_no_think)
        workflow.extract(log_file=os.path.join(log_dir, file_name),
                         save_file=os.path.join(res_dir, file_name),
                         task="smiles")

    if os.path.isfile(dataset_path):
        _process(dataset_path, os.path.basename(dataset_path))
    elif os.path.isdir(dataset_path):
        for file_name in os.listdir(dataset_path):
            if "openmolinst" not in file_name:
                continue
            print(f"\n###### {file_name} ######")
            _process(os.path.join(dataset_path, file_name), file_name)
def inference_mumoinstruct(dataset_path, save_dir="logs", model_name="biomedgpt_mol", url="http://localhost:8000/v1"):
    """Run multi-response inference + SMILES extraction on MuMoInstruct files.

    Only files whose name contains "mumoinstruct" are processed; each query
    is answered with 20 candidate responses. Outputs go under
    <save_dir>/<model_name>/mumoinstruct/{logs,results}.
    """
    save_dir = os.path.join(save_dir, model_name, "mumoinstruct")
    log_dir = os.path.join(save_dir, "logs")
    res_dir = os.path.join(save_dir, "results")
    # makedirs creates intermediate directories, so no per-level checks needed.
    os.makedirs(log_dir, exist_ok=True)
    os.makedirs(res_dir, exist_ok=True)

    for file_name in os.listdir(dataset_path):
        if "mumoinstruct" not in file_name:
            continue
        print(f"\n###### {file_name} ######")
        # BUG FIX: model_name was previously dropped here, so a non-default
        # model always got the MolLM agent; forward it to the workflow.
        workflow = InferenceDataset(json_file=os.path.join(dataset_path, file_name),
                                    model_name=model_name,
                                    url=url)
        # beam search, num = 20
        workflow.inference(save_file=os.path.join(log_dir, file_name),
                           n_resp=20,
                           add_no_think=True)
        workflow.extract(log_file=os.path.join(log_dir, file_name),
                         save_file=os.path.join(res_dir, file_name),
                         task="smiles")
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="")
    parser.add_argument("--dataset_name", type=str, help="name of the dataset")
    parser.add_argument("--dataset_path", type=str, help="path to dataset files")
    parser.add_argument("--save_dir", type=str, default="logs", help="path to log files")
    parser.add_argument("--model_name", type=str, default="biomedgpt_mol", help="name of the model")
    # store_false => defaults to True (no-think enabled); passing the flag
    # sets it to False, i.e. lets the model think.
    parser.add_argument('--disable_no_think', action='store_false', help='let the model think')
    parser.add_argument("--url", type=str, default="http://localhost:8000/v1", help="url of the API")
    args = parser.parse_args()

    if args.dataset_name == "smolinstruct":
        inference_smolinstruct(dataset_path=args.dataset_path,
                               save_dir=args.save_dir,
                               model_name=args.model_name,
                               url=args.url)
    elif args.dataset_name == "openmolinst":
        inference_openmolinst(dataset_path=args.dataset_path,
                              save_dir=args.save_dir,
                              model_name=args.model_name,
                              url=args.url,
                              add_no_think=args.disable_no_think)
    elif args.dataset_name == "mumoinstruct":
        inference_mumoinstruct(dataset_path=args.dataset_path,
                               save_dir=args.save_dir,
                               model_name=args.model_name,
                               url=args.url)
    else:
        # Previously an unknown or missing --dataset_name exited silently
        # with no work done; fail loudly with the accepted values instead.
        parser.error(f"unknown --dataset_name: {args.dataset_name!r} "
                     "(expected smolinstruct, openmolinst, or mumoinstruct)")