|
|
import os |
|
|
import re |
|
|
import json |
|
|
from openai import OpenAI |
|
|
from tqdm import tqdm |
|
|
|
|
|
from rdkit import Chem |
|
|
from rdkit import RDLogger |
|
|
# Switch off RDKit log output for the "rdApp.*" log spec at import time so
# later Chem calls in this module do not write RDKit messages.
RDLogger.DisableLog("rdApp.*")
|
|
|
|
|
def canonicalize_smiles(smiles):
    """Return the canonical SMILES string RDKit produces for ``smiles``.

    The input is parsed with ``Chem.MolFromSmiles`` and the parsed molecule
    is written back out with ``Chem.MolToSmiles(..., canonical=True)``.

    Args:
        smiles: SMILES string to canonicalise.

    Returns:
        Whatever ``Chem.MolToSmiles`` returns for the parsed molecule.
    """
    # NOTE(review): the parse result is handed to MolToSmiles without any
    # check. RDKit's docs say MolFromSmiles yields None for unparsable input;
    # confirm how callers expect that case to surface before changing it.
    return Chem.MolToSmiles(Chem.MolFromSmiles(smiles), canonical=True)
|
|
|
|
|
|
|
|
class Agent():
    """Thin wrapper around an ``OpenAI`` client for one-shot chat queries."""

    def __init__(self, url="http://localhost:8000/v1"):
        """Build the client.

        Args:
            url: Base URL handed to ``OpenAI`` as ``base_url``.
        """
        # The API key is a fixed placeholder string.
        self.client = OpenAI(base_url=url, api_key="0")
        print("[Client Init] done")

    def generate(self, query):
        """Send ``query`` as a single user message to the "general" model.

        Args:
            query: Text placed in the user message's ``content`` field.

        Returns:
            The ``message.content`` of the first choice in the response.
        """
        user_turn = {"role": "user", "content": query}
        completion = self.client.chat.completions.create(
            model="general",
            messages=[user_turn],
        )
        first_choice = completion.choices[0]
        return first_choice.message.content
|
|
|
|
|
class AgentMolLM(Agent):
    """``Agent`` variant that queries the "mol_lm" model with sampling knobs."""

    def __init__(self, url="http://localhost:8000/v1"):
        """Forward ``url`` unchanged to ``Agent.__init__``."""
        super().__init__(url=url)

    def generate(self, query, temperature=0.01, top_p=0.01, n_resp=1, add_no_think=False):
        """Send ``query`` as a single user message to the "mol_lm" model.

        Args:
            query: Text for the user message.
            temperature: Passed through as ``temperature``.
            top_p: Passed through as ``top_p``.
            n_resp: Passed through as ``n`` (number of choices requested).
            add_no_think: When true, " /no_think" is appended to ``query``
                before it is sent.

        Returns:
            A list with the content of the first ``n_resp`` choices when
            ``n_resp > 1``; otherwise the content of the first choice as a
            single value (not wrapped in a list).
        """
        if add_no_think:
            query = query + " /no_think"

        completion = self.client.chat.completions.create(
            model="mol_lm",
            messages=[{"role": "user", "content": query}],
            n=n_resp,
            temperature=temperature,
            top_p=top_p,
        )

        if n_resp > 1:
            # Index by position so exactly n_resp entries are read.
            return [completion.choices[k].message.content for k in range(n_resp)]
        return completion.choices[0].message.content
|
|
|
|
|
|
|
|
class ExtractorSMILES():
    """Collects the contents of ``<SMILES>...</SMILES>`` tags from text."""

    def __init__(self):
        # Non-greedy capture so each tag pair produces its own match.
        self.pattern = r"<SMILES>(.*?)</SMILES>"

    def extract(self, text):
        """Return every tagged span in order, with each ";" replaced by ".".

        The search uses ``re.DOTALL``, so a tagged span may contain newlines.
        """
        tagged = re.finditer(self.pattern, text, flags=re.DOTALL)
        return [hit.group(1).replace(";", ".") for hit in tagged]
|
|
|
|
|
class ExtractorFormula():
    """Collects the contents of ``<MOLFORMULA>...</MOLFORMULA>`` tags."""

    def __init__(self):
        # Non-greedy capture so each tag pair produces its own match.
        self.pattern = r"<MOLFORMULA>(.*?)</MOLFORMULA>"

    def extract(self, text):
        """Return the captured text of every tag pair, in order of appearance.

        Matching is done with ``re.DOTALL``, so captures may span lines.
        """
        matcher = re.compile(self.pattern, re.DOTALL)
        return matcher.findall(text)
|
|
|
|
|
class ExtractorIUPAC():
    """Collects the contents of ``<IUPAC>...</IUPAC>`` tags from text."""

    def __init__(self):
        # Non-greedy capture so each tag pair produces its own match.
        self.pattern = r"<IUPAC>(.*?)</IUPAC>"

    def extract(self, text):
        """Return the captured text of every tag pair, in order of appearance.

        ``re.DOTALL`` is used, so a capture may contain newlines.
        """
        return [found.group(1) for found in re.finditer(self.pattern, text, re.DOTALL)]
|
|
|
|
|
class ExtractorValue():
    """Finds decimal-number literals in text."""

    def __init__(self):
        # Optional sign, optional integer digits, then a mandatory "." and at
        # least one digit. Numbers written without a decimal point do not match.
        self.pattern = r"[+-]?\d*\.\d+"

    def extract(self, text):
        """Return every non-overlapping match of the pattern, left to right."""
        return [found.group(0) for found in re.finditer(self.pattern, text)]
|
|
|
|
|
class ExtractorYN():
    """Maps free text to a yes/no verdict.

    The previous implementation used plain substring tests, so any text that
    merely contained the letters "no" or "yes" inside another word ("not",
    "know", "unknown", "yesterday") was classified as an answer. Matching is
    now restricted to standalone tokens.
    """

    # "yes" / "no" only when not glued to another ASCII letter on either side.
    # Letter lookarounds are used instead of \b so a token directly adjacent
    # to non-ASCII text (e.g. CJK characters) is still recognised.
    _VERDICT = re.compile(r"(?<![a-z])(?:yes|no)(?![a-z])")

    def __init__(self):
        pass

    def extract(self, text):
        """Classify ``text`` as yes, no, or undetermined.

        Args:
            text: Model output to inspect; compared case-insensitively.

        Returns:
            ``["yes"]`` if a standalone "yes" occurs anywhere in the text,
            otherwise ``["no"]`` if a standalone "no" occurs, otherwise
            ``["N/A"]``. "yes" takes priority when both are present, as in
            the original substring-based check.
        """
        found = set(self._VERDICT.findall(text.lower()))
        if "yes" in found:
            return ["yes"]
        if "no" in found:
            return ["no"]
        return ["N/A"]
|
|
|
|
|
class ExtractorText():
    """Pass-through extractor: the whole text is the single result."""

    def __init__(self):
        """Nothing to configure."""

    def extract(self, text):
        """Return ``text`` unchanged, wrapped in a one-element list."""
        whole = [text]
        return whole
|
|
|
|
|
class Extractor():
    """Selects a task-specific extractor and returns its first result.

    ``task_type`` picks the delegate: "smiles", "formula", "iupac", "value"
    and "bool" map to the matching ``Extractor*`` class; any other value
    falls back to ``ExtractorText``.
    """

    def __init__(self, task_type="smiles"):
        """Instantiate the delegate extractor for ``task_type``."""
        if task_type == "smiles":
            self.extractor = ExtractorSMILES()
        elif task_type == "formula":
            self.extractor = ExtractorFormula()
        elif task_type == "iupac":
            self.extractor = ExtractorIUPAC()
        elif task_type == "value":
            self.extractor = ExtractorValue()
        elif task_type == "bool":
            self.extractor = ExtractorYN()
        else:
            self.extractor = ExtractorText()

    def extract(self, text):
        """Return the first extracted item from the answer part of ``text``.

        Everything up to and including the last ``</think>`` tag is dropped
        before extraction. The tag is matched on its own rather than as
        ``"</think>\\n"``, so a closing tag followed by a space, ``\\r\\n`` or
        nothing at all no longer leaves the reasoning section in the text
        that gets searched.

        Args:
            text: Raw model output. ``None`` is treated as empty text.

        Returns:
            The first item produced by the delegate extractor, stripped of
            surrounding whitespace, or ``""`` when the delegate finds nothing.
        """
        if text is None:
            # A response without content is handled like an empty answer
            # instead of raising AttributeError.
            text = ""
        answer = text.rpartition("</think>")[2]
        hits = self.extractor.extract(text=answer)
        return hits[0].strip() if hits else ""
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Manual smoke test: send one prompt to the mol_lm agent and print the
    # reply. The prompt literal is runtime data and is kept as-is.
    QUERY = "你是谁"
    ADD_NO_THINK = True

    agent = AgentMolLM()

    # n_resp=2 means generate() returns a list of two completions.
    resp = agent.generate(
        QUERY,
        n_resp=2,
        top_p=0.95,
        temperature=0.6,
        add_no_think=ADD_NO_THINK,
    )

    print(resp)