import os import json from tqdm import tqdm from utils.inference_utils import Agent, AgentMolLM, Extractor, canonicalize_smiles class InferenceDataset(): def __init__(self, json_file, model_name="biomedgpt_mol", url="http://localhost:8000/v1"): if model_name=="biomedgpt_mol": self.agent = AgentMolLM(url=url) else: self.agent = Agent(url=url) self.data = self._load_json(json_file) print ("[Dataset] loaded") def _load_json(self, file_path): with open(file_path, 'r') as f: contents = json.load(f) return contents def _dump_json(self, content, file_path): with open(file_path, 'w') as f: json.dump(content, f, indent=4, ensure_ascii=False) def inference(self, save_file, temperature=0.01, top_p=0.01, n_resp=1, add_no_think=False): print ("[Dataset Inference] Start...") logs = list() for sample in tqdm(self.data): resp = self.agent.generate(query=sample['instruction'], temperature=temperature, top_p=top_p, n_resp=n_resp, add_no_think=add_no_think) log = { "query": sample['instruction'], "answer": resp, "gt": sample['output'], "metadata": sample['metadata'] } logs.append(log) if len(logs) % 1000 == 0: self._dump_json(logs, save_file) self._dump_json(logs, save_file) print (f"[Dataset Inference] Done. Save at {save_file}.") def extract(self, log_file, save_file, task): print ("[Log Extraction] Start...") logs = self._load_json(log_file) extractor = Extractor(task_type=task) results = list() for log in logs: if isinstance(log['answer'], list): candidates = list() for candidate in log['answer']: candidate = extractor.extract(candidate) try: candidate = canonicalize_smiles(candidate) candidates.append(candidate) except: pass log['extracted_answer'] = candidates else: log['extracted_answer'] = extractor.extract(log['answer']) log['extracted_gt'] = extractor.extract(log['gt']) results.append(log) self._dump_json(results, save_file) print (f"[Log Extraction] Done. Save at {save_file}.") def inference_smolinstruct(dataset_path, save_dir="logs", model_name="biomedgpt_mol", url="http://localhost:8000/v1"): save_dir = os.path.join(save_dir, model_name) if not os.path.exists(save_dir): os.makedirs(save_dir) save_dir = os.path.join(save_dir, "smolinstruct") if not os.path.exists(save_dir): os.makedirs(save_dir) log_dir = os.path.join(save_dir, "logs") if not os.path.exists(log_dir): os.makedirs(log_dir) res_dir = os.path.join(save_dir, "results") if not os.path.exists(res_dir): os.makedirs(res_dir) data_files = os.listdir(dataset_path) for file_name in data_files: if "smolinstruct" not in file_name: continue print (f"\n###### {file_name} ######") workflow = InferenceDataset(json_file=os.path.join(dataset_path, file_name), url=url) workflow.inference(save_file=os.path.join(log_dir, file_name), add_no_think=True) if "molecule_captioning" in file_name: workflow.extract(log_file=os.path.join(log_dir, file_name), save_file=os.path.join(res_dir, file_name), task="text") elif "s2i" in file_name: workflow.extract(log_file=os.path.join(log_dir, file_name), save_file=os.path.join(res_dir, file_name), task="iupac") elif ("i2f" in file_name) or ("s2f" in file_name): workflow.extract(log_file=os.path.join(log_dir, file_name), save_file=os.path.join(res_dir, file_name), task="formula") elif ("bbbp" in file_name) or ("clintox" in file_name) or ("hiv" in file_name) or ("sider" in file_name): workflow.extract(log_file=os.path.join(log_dir, file_name), save_file=os.path.join(res_dir, file_name), task="bool") elif ("esol" in file_name) or ("lipo" in file_name): workflow.extract(log_file=os.path.join(log_dir, file_name), save_file=os.path.join(res_dir, file_name), task="value") else: workflow.extract(log_file=os.path.join(log_dir, file_name), save_file=os.path.join(res_dir, file_name), task="smiles") def inference_openmolinst(dataset_path, save_dir="logs", model_name="biomedgpt_mol", url="http://localhost:8000/v1", add_no_think=True): save_dir = os.path.join(save_dir, model_name) if not os.path.exists(save_dir): os.makedirs(save_dir) save_dir = os.path.join(save_dir, "openmolinst") if not os.path.exists(save_dir): os.makedirs(save_dir) log_dir = os.path.join(save_dir, "logs") if not os.path.exists(log_dir): os.makedirs(log_dir) res_dir = os.path.join(save_dir, "results") if not os.path.exists(res_dir): os.makedirs(res_dir) if os.path.isfile(dataset_path): file_name = os.path.basename(dataset_path) workflow = InferenceDataset(json_file=dataset_path, url=url) workflow.inference(save_file=os.path.join(log_dir, file_name), add_no_think=add_no_think) workflow.extract(log_file=os.path.join(log_dir, file_name), save_file=os.path.join(res_dir, file_name), task="smiles") elif os.path.isdir(dataset_path): data_files = os.listdir(dataset_path) for file_name in data_files: if "openmolinst" not in file_name: continue print (f"\n###### {file_name} ######") workflow = InferenceDataset(json_file=os.path.join(dataset_path, file_name), url=url) workflow.inference(save_file=os.path.join(log_dir, file_name), add_no_think=add_no_think) workflow.extract(log_file=os.path.join(log_dir, file_name), save_file=os.path.join(res_dir, file_name), task="smiles") def inference_mumoinstruct(dataset_path, save_dir="logs", model_name="biomedgpt_mol", url="http://localhost:8000/v1"): save_dir = os.path.join(save_dir, model_name) if not os.path.exists(save_dir): os.makedirs(save_dir) save_dir = os.path.join(save_dir, "mumoinstruct") if not os.path.exists(save_dir): os.makedirs(save_dir) log_dir = os.path.join(save_dir, "logs") if not os.path.exists(log_dir): os.makedirs(log_dir) res_dir = os.path.join(save_dir, "results") if not os.path.exists(res_dir): os.makedirs(res_dir) data_files = os.listdir(dataset_path) for file_name in data_files: if "mumoinstruct" not in file_name: continue print (f"\n###### {file_name} ######") workflow = InferenceDataset(json_file=os.path.join(dataset_path, file_name), url=url) # beam search, num = 20 workflow.inference(save_file=os.path.join(log_dir, file_name), n_resp=20, add_no_think=True) workflow.extract(log_file=os.path.join(log_dir, file_name), save_file=os.path.join(res_dir, file_name), task="smiles") if __name__ == "__main__": import argparse parser = argparse.ArgumentParser(description="") parser.add_argument("--dataset_name", type=str, help="name of the dataset") parser.add_argument("--dataset_path", type=str, help="path to dataset files") parser.add_argument("--save_dir", type=str, default="logs", help="path to log files") parser.add_argument("--model_name", type=str, default="biomedgpt_mol", help="name of the model") parser.add_argument('--disable_no_think', action='store_false', help='let the model think') parser.add_argument("--url", type=str, default="http://localhost:8000/v1", help="url of the API") args = parser.parse_args() if args.dataset_name == "smolinstruct": inference_smolinstruct(dataset_path=args.dataset_path, save_dir=args.save_dir, model_name=args.model_name, url=args.url) elif args.dataset_name == "openmolinst": inference_openmolinst(dataset_path=args.dataset_path, save_dir=args.save_dir, model_name=args.model_name, url=args.url, add_no_think=args.disable_no_think) elif args.dataset_name == "mumoinstruct": inference_mumoinstruct(dataset_path=args.dataset_path, save_dir=args.save_dir, model_name=args.model_name, url=args.url)