"""Chat-client wrappers and answer extractors for a molecule language model.

The module offers:

* ``canonicalize_smiles`` - RDKit-based SMILES canonicalisation.
* ``Agent`` / ``AgentMolLM`` - thin clients for an OpenAI-compatible
  chat-completions endpoint.
* ``Extractor`` and its per-task helpers - pull the final answer out of a
  model response.
"""

import os
import re
import json

from openai import OpenAI
from tqdm import tqdm
from rdkit import Chem
from rdkit import RDLogger

# Turn off RDKit's "rdApp.*" log channels for the whole process.
RDLogger.DisableLog('rdApp.*')


def canonicalize_smiles(smiles):
    """Return the RDKit canonical SMILES string for *smiles*.

    NOTE(review): ``Chem.MolFromSmiles`` is documented to return ``None``
    when it cannot parse its input, and that value is handed straight to
    ``Chem.MolToSmiles`` here, so an unparsable string is expected to
    surface as an exception - confirm callers guard against that.
    """
    mol = Chem.MolFromSmiles(smiles)
    return Chem.MolToSmiles(mol, canonical=True)


class Agent:
    """Single-turn chat client that queries the ``"general"`` model."""

    def __init__(self, url="http://localhost:8000/v1"):
        """Create the OpenAI client pointed at *url*.

        Args:
            url: Base URL of the OpenAI-compatible server.
        """
        # "0" is a placeholder key; presumably the local server does not
        # validate it - TODO confirm.
        self.client = OpenAI(api_key="0", base_url=url)
        print("[Client Init] done")

    def generate(self, query):
        """Send *query* as one user message and return the reply text.

        Args:
            query: The prompt string.

        Returns:
            The ``message.content`` of the first returned choice.
        """
        messages = [{"role": "user", "content": query}]
        result = self.client.chat.completions.create(
            messages=messages,
            model="general",
        )
        return result.choices[0].message.content


class AgentMolLM(Agent):
    """Chat client that queries the ``"mol_lm"`` model with sampling options."""

    def __init__(self, url="http://localhost:8000/v1"):
        """Create the client; see ``Agent.__init__``."""
        super().__init__(url=url)

    def generate(self, query, temperature=0.01, top_p=0.01, n_resp=1,
                 add_no_think=False):
        """Send *query* to ``"mol_lm"`` and return one or several replies.

        Args:
            query: The prompt string.
            temperature: Sampling temperature forwarded to the API.
            top_p: Nucleus-sampling value forwarded to the API.
            n_resp: Number of completions requested (API parameter ``n``).
            add_no_think: When true, ``" /no_think"`` is appended to the
                prompt (presumably a directive understood by the served
                model - confirm).

        Returns:
            A string when ``n_resp`` is 1 or less; otherwise a list holding
            the content of up to ``n_resp`` returned choices.
        """
        if add_no_think:
            query = query + " /no_think"
        messages = [{"role": "user", "content": query}]
        result = self.client.chat.completions.create(
            messages=messages,
            model="mol_lm",
            temperature=temperature,
            top_p=top_p,
            n=n_resp,
        )
        if n_resp > 1:
            # Iterate over what the server actually returned instead of
            # indexing 0..n_resp-1, so a short reply cannot raise IndexError.
            return [choice.message.content
                    for choice in result.choices[:n_resp]]
        return result.choices[0].message.content


class ExtractorSMILES:
    """Regex-based extractor for SMILES answers."""

    def __init__(self):
        # NOTE(review): this pattern has no delimiters around the group, so
        # the first match is always the empty string. Presumably surrounding
        # tag text was lost at some point - confirm the intended pattern.
        self.pattern = r"(.*?)"

    def extract(self, text):
        """Return every pattern match in *text* with ``;`` replaced by ``.``."""
        return [
            smi.replace(";", ".")
            for smi in re.findall(self.pattern, text, re.DOTALL)
        ]


class ExtractorFormula:
    """Regex-based extractor for molecular-formula answers."""

    def __init__(self):
        # NOTE(review): delimiter-less pattern; the first match is always the
        # empty string. Confirm the intended pattern.
        self.pattern = r"(.*?)"

    def extract(self, text):
        """Return every pattern match found in *text*."""
        return re.findall(self.pattern, text, re.DOTALL)


class ExtractorIUPAC:
    """Regex-based extractor for IUPAC-name answers."""

    def __init__(self):
        # NOTE(review): delimiter-less pattern; the first match is always the
        # empty string. Confirm the intended pattern.
        self.pattern = r"(.*?)"

    def extract(self, text):
        """Return every pattern match found in *text*."""
        return re.findall(self.pattern, text, re.DOTALL)


class ExtractorValue:
    """Extractor for decimal numbers."""

    def __init__(self):
        # Optional sign, optional integer part, mandatory ".digits".
        # NOTE(review): numbers written without a decimal point (e.g. "3")
        # are not matched - confirm that is intended.
        self.pattern = r"[+-]?\d*\.\d+"

    def extract(self, text):
        """Return every decimal-number substring found in *text*."""
        return re.findall(self.pattern, text)


class ExtractorYN:
    """Extractor that maps a response to ``"yes"``, ``"no"`` or ``"N/A"``."""

    def __init__(self):
        pass

    def extract(self, text):
        """Classify *text* by case-insensitive substring search.

        ``"yes"`` is checked first, so it wins when both words occur.

        NOTE(review): this is a plain substring test, so e.g. "unknown"
        counts as "no" - confirm that is acceptable.
        """
        lowered = text.lower()
        if "yes" in lowered:
            return ["yes"]
        if "no" in lowered:
            return ["no"]
        return ["N/A"]


class ExtractorText:
    """Pass-through extractor that returns the text unchanged."""

    def __init__(self):
        pass

    def extract(self, text):
        """Return *text* wrapped in a one-element list."""
        return [text]


class Extractor:
    """Facade that picks a task-specific extractor and returns one answer."""

    # Maps each recognised task_type to its extractor class; anything else
    # falls back to ExtractorText.
    _EXTRACTORS = {
        "smiles": ExtractorSMILES,
        "formula": ExtractorFormula,
        "iupac": ExtractorIUPAC,
        "value": ExtractorValue,
        "bool": ExtractorYN,
    }

    def __init__(self, task_type="smiles"):
        """Select the extractor for *task_type*.

        Args:
            task_type: One of ``"smiles"``, ``"formula"``, ``"iupac"``,
                ``"value"``, ``"bool"``; any other value selects the
                pass-through text extractor.
        """
        self.extractor = self._EXTRACTORS.get(task_type, ExtractorText)()

    def extract(self, text):
        """Return the first extracted answer from the last line of *text*.

        Trailing whitespace is dropped before the last line is selected, so
        a response that ends with a newline is still read from its final
        non-blank line rather than from an empty one.

        Args:
            text: Full model response.

        Returns:
            The first match, stripped of surrounding whitespace, or ``""``
            when the extractor finds nothing.
        """
        last_line = text.rstrip().split("\n")[-1]
        matches = self.extractor.extract(text=last_line)
        return matches[0].strip() if matches else ""


if __name__ == "__main__":
    agent = AgentMolLM()
    QUERY = "你是谁"  # "Who are you?"
    ADD_NO_THINK = True
    # Single-response alternative with the default sampling settings:
    # resp = agent.generate(QUERY, add_no_think=ADD_NO_THINK)
    resp = agent.generate(QUERY, add_no_think=ADD_NO_THINK,
                          temperature=0.6, top_p=0.95, n_resp=2)
    print(resp)