# Source: BioMedGPT-Mol/evaluation/utils/inference_utils.py
# Last commit by leofansq: "update for evaluation" (3824ea0, verified)
import os
import re
import json
from openai import OpenAI
from tqdm import tqdm
from rdkit import Chem
from rdkit import RDLogger
RDLogger.DisableLog('rdApp.*')
def canonicalize_smiles(smiles):
    """Return the RDKit canonical SMILES string for ``smiles``.

    Args:
        smiles: SMILES string to parse with ``Chem.MolFromSmiles``.

    Returns:
        The output of ``Chem.MolToSmiles(..., canonical=True)`` for the
        parsed molecule.

    Note:
        The parse result is handed to ``Chem.MolToSmiles`` without any
        validity check, so whatever RDKit does for an unparsable input
        (presumably an exception on a ``None`` molecule -- confirm against
        the RDKit docs) propagates to the caller.
    """
    return Chem.MolToSmiles(Chem.MolFromSmiles(smiles), canonical=True)
class Agent:
    """Single-turn chat client built on the ``OpenAI`` client class.

    Every query is sent as one user message to the model named
    ``"general"``.
    """

    def __init__(self, url="http://localhost:8000/v1"):
        """Create the underlying client.

        Args:
            url: Base URL passed to ``OpenAI`` as ``base_url``.
        """
        # The API key is the literal "0"; presumably the target server does
        # not validate it -- confirm for non-local deployments.
        self.client = OpenAI(base_url=url, api_key="0")
        print("[Client Init] done")

    def generate(self, query):
        """Send ``query`` as a user message and return the reply text.

        Args:
            query: Prompt text placed in the single user message.

        Returns:
            The ``message.content`` of the first returned choice.
        """
        user_turn = {"role": "user", "content": query}
        completion = self.client.chat.completions.create(
            model="general",
            messages=[user_turn],
        )
        first_choice = completion.choices[0]
        return first_choice.message.content
class AgentMolLM(Agent):
    """``Agent`` variant that queries the ``"mol_lm"`` model.

    Adds sampling controls and an optional ``" /no_think"`` suffix on the
    prompt.
    """

    def __init__(self, url="http://localhost:8000/v1"):
        """Initialise the client via the parent class.

        Args:
            url: Base URL forwarded to ``Agent.__init__``.
        """
        super().__init__(url=url)

    def generate(self, query, temperature=0.01, top_p=0.01, n_resp=1, add_no_think=False):
        """Send ``query`` to ``"mol_lm"`` and return one or several replies.

        Args:
            query: Prompt text for the single user message.
            temperature: Sampling temperature forwarded to the API.
            top_p: Nucleus-sampling value forwarded to the API.
            n_resp: Number of completions requested (sent as ``n``).
            add_no_think: When true, ``" /no_think"`` is appended to the
                prompt before sending.

        Returns:
            A list with the first ``n_resp`` reply texts when
            ``n_resp > 1``; otherwise the text of the first reply as a
            plain string.
        """
        if add_no_think:
            query = query + " /no_think"

        completion = self.client.chat.completions.create(
            model="mol_lm",
            messages=[{"role": "user", "content": query}],
            temperature=temperature,
            top_p=top_p,
            n=n_resp,
        )

        if n_resp <= 1:
            return completion.choices[0].message.content
        return [completion.choices[idx].message.content for idx in range(n_resp)]
class ExtractorSMILES:
    """Pull SMILES strings out of ``<SMILES>...</SMILES>`` tags."""

    def __init__(self):
        """Store the non-greedy tag pattern used by ``extract``."""
        self.pattern = r"<SMILES>(.*?)</SMILES>"

    def extract(self, text):
        """Return every tagged SMILES in ``text``.

        Args:
            text: Text to scan; tag contents may span several lines.

        Returns:
            The tag contents in order of appearance, with each ``;``
            replaced by ``.``.
        """
        tagged = re.compile(self.pattern, re.DOTALL).findall(text)
        return [candidate.replace(";", ".") for candidate in tagged]
class ExtractorFormula:
    """Pull molecular formulas out of ``<MOLFORMULA>...</MOLFORMULA>`` tags."""

    def __init__(self):
        """Store the non-greedy tag pattern used by ``extract``."""
        self.pattern = r"<MOLFORMULA>(.*?)</MOLFORMULA>"

    def extract(self, text):
        """Return the contents of every formula tag in ``text``, in order.

        Args:
            text: Text to scan; tag contents may span several lines.
        """
        tag_regex = re.compile(self.pattern, flags=re.DOTALL)
        return tag_regex.findall(text)
class ExtractorIUPAC:
    """Pull IUPAC names out of ``<IUPAC>...</IUPAC>`` tags."""

    def __init__(self):
        """Store the non-greedy tag pattern used by ``extract``."""
        self.pattern = r"<IUPAC>(.*?)</IUPAC>"

    def extract(self, text):
        """Return the contents of every IUPAC tag in ``text``, in order.

        Args:
            text: Text to scan; tag contents may span several lines.
        """
        tag_regex = re.compile(self.pattern, flags=re.DOTALL)
        return tag_regex.findall(text)
class ExtractorValue:
    """Pull numeric literals out of free text.

    Decimal literals (digits containing a ``.``) are preferred. Plain
    integers are only considered when the text contains no decimal literal,
    so an answer such as ``"42"`` is no longer lost while texts that already
    produced a result keep producing exactly the same one.
    """

    # Fallback used only when no decimal literal is present.
    _INTEGER_PATTERN = r"[+-]?\d+"

    def __init__(self):
        """Store the primary (decimal) pattern used by ``extract``."""
        self.pattern = r"[+-]?\d*\.\d+"

    def extract(self, text):
        """Return the numeric literals found in ``text`` as strings.

        Args:
            text: Text to scan.

        Returns:
            All decimal literals in order of appearance; if there are none,
            all signed integer literals instead. Empty list when the text
            contains no digits at all.
        """
        decimals = re.findall(self.pattern, text)
        if decimals:
            return decimals
        # No decimal point anywhere: accept integer answers rather than
        # reporting that nothing was found.
        return re.findall(self._INTEGER_PATTERN, text)
class ExtractorYN:
    """Classify a free-text answer as ``"yes"``, ``"no"`` or ``"N/A"``.

    Matching is case-insensitive and on whole words only, so words that
    merely contain the letters (``"know"``, ``"not"``, ``"cannot"``,
    ``"eyes"``) are not mistaken for an answer.
    """

    _YES_WORD = re.compile(r"\byes\b")
    _NO_WORD = re.compile(r"\bno\b")

    def __init__(self):
        """No configuration is needed."""

    def extract(self, text):
        """Return a one-element list holding the detected answer.

        Args:
            text: Model reply to classify.

        Returns:
            ``["yes"]`` if the word "yes" occurs anywhere, else ``["no"]``
            if the word "no" occurs, else ``["N/A"]``. "yes" takes
            precedence when both words are present.
        """
        lowered = text.lower()
        if self._YES_WORD.search(lowered):
            return ["yes"]
        if self._NO_WORD.search(lowered):
            return ["no"]
        return ["N/A"]
class ExtractorText:
    """Pass-through extractor for free-text answers."""

    def __init__(self):
        """No configuration is needed."""

    def extract(self, text):
        """Return ``text`` unchanged, wrapped in a one-element list.

        Args:
            text: Value to wrap.
        """
        wrapped = [text]
        return wrapped
class Extractor:
    """Extract the final answer from a model reply for a given task type.

    Delegates the actual parsing to one of the task-specific extractor
    classes and returns only the first hit.
    """

    def __init__(self, task_type="smiles"):
        """Select the task-specific extractor.

        Args:
            task_type: One of ``"smiles"``, ``"formula"``, ``"iupac"``,
                ``"value"`` or ``"bool"``. Any other value selects the
                pass-through ``ExtractorText``.
        """
        extractor_classes = {
            "smiles": ExtractorSMILES,
            "formula": ExtractorFormula,
            "iupac": ExtractorIUPAC,
            "value": ExtractorValue,
            "bool": ExtractorYN,
        }
        self.extractor = extractor_classes.get(task_type, ExtractorText)()

    def extract(self, text):
        """Return the first extracted answer, stripped of whitespace.

        Args:
            text: Raw model reply, optionally containing reasoning that
                ends with a ``</think>`` tag. ``None`` is treated as an
                empty reply.

        Returns:
            The first item produced by the selected extractor with
            surrounding whitespace removed, or ``""`` when the extractor
            finds nothing.
        """
        if text is None:
            text = ""
        # Keep only what follows the last closing think tag. The tag is
        # matched without a trailing newline so replies such as
        # "</think>answer" or "</think>\r\n..." also have their reasoning
        # removed; the final strip() discards any leftover leading newline.
        answer_part = text.rpartition("</think>")[2]
        found = self.extractor.extract(text=answer_part)
        return found[0].strip() if found else ""
if __name__ == "__main__":
    # Manual smoke test: send one fixed prompt to the "mol_lm" endpoint at
    # the default URL and print what comes back.
    QUERY = "你是谁"  # Chinese for "Who are you?"
    ADD_NO_THINK = True

    agent = AgentMolLM()
    # For a single reply with the default sampling settings, call
    # agent.generate(QUERY, add_no_think=ADD_NO_THINK) instead.
    resp = agent.generate(
        QUERY,
        temperature=0.6,
        top_p=0.95,
        n_resp=2,
        add_no_think=ADD_NO_THINK,
    )
    print(resp)