"""Utilities for SMILES autoencoder encoding/decoding, MolT5 text encoding,
sentence splitting, image cropping, absl logging setup, and a regex-based
SMILES tokenizer."""
import re

import numpy as np
import torch
from absl import logging
from PIL import Image
from rdkit import RDLogger
from torch.distributions.categorical import Categorical

# Silence RDKit's per-molecule warning/error spam.
RDLogger.DisableLog('rdApp.*')


@torch.no_grad()
def AE_SMILES_encoder(sm, ae_model):
    """Encode a batch of SMILES strings into latent vectors with ``ae_model``.

    :param sm: list of SMILES strings; a literal leading "[CLS]" is stripped
        because the tokenizer adds its own CLS token.
    :param ae_model: autoencoder exposing ``tokenizer``, ``text_encoder2`` and
        optionally ``encode_prefix`` / ``output_dim``.
    :return: latent tensor. When the encoder emits concatenated
        (mean, log-variance) pairs, a sample is drawn via the VAE
        reparameterization trick instead.
    """
    if sm[0][:5] == "[CLS]":
        sm = [s[5:] for s in sm]
    text_input = ae_model.tokenizer(sm).to(ae_model.device)
    text_input_ids = text_input
    # Token id 0 is padding; mask it out of attention.
    text_attention_mask = torch.where(text_input_ids == 0, 0, 1).to(text_input.device)
    if hasattr(ae_model.text_encoder2, 'bert'):
        output = ae_model.text_encoder2.bert(text_input_ids,
                                             attention_mask=text_attention_mask,
                                             return_dict=True,
                                             mode='text').last_hidden_state
    else:
        output = ae_model.text_encoder2(text_input_ids,
                                        attention_mask=text_attention_mask,
                                        return_dict=True).last_hidden_state
    if hasattr(ae_model, 'encode_prefix'):
        output = ae_model.encode_prefix(output)
    # A doubled feature dimension signals (mean, logvar): sample a latent.
    if ae_model.output_dim * 2 == output.size(-1):
        mean, logvar = torch.chunk(output, 2, dim=-1)
        logvar = torch.clamp(logvar, -30.0, 20.0)  # keep exp() numerically safe
        std = torch.exp(0.5 * logvar)
        output = mean + std * torch.randn_like(mean)
    return output


@torch.no_grad()
def generate(model, image_embeds, text, stochastic=True, prop_att_mask=None, k=None):
    """Predict the next token(s) for a partially decoded sequence.

    :param model: model exposing ``text_encoder`` (a decoder-capable encoder).
    :param image_embeds: conditioning embeddings cross-attended by the decoder.
    :param text: current token-id prefix; id 0 is treated as padding.
    :param stochastic: sample from the softmax instead of taking argmax/top-k.
    :param prop_att_mask: optional attention mask over ``image_embeds``
        (defaults to all-ones).
    :param k: when given, return the top/sampled ``k`` candidates per row as
        ``(log_probs, token_ids)``; otherwise return a single next token of
        shape batch*1.
    """
    text_atts = torch.where(text == 0, 0, 1)
    if prop_att_mask is None:
        prop_att_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image_embeds.device)
    token_output = model.text_encoder(text,
                                      attention_mask=text_atts,
                                      encoder_hidden_states=image_embeds,
                                      encoder_attention_mask=prop_att_mask,
                                      return_dict=True,
                                      is_decoder=True,
                                      return_logits=True,
                                      )[:, -1, :]  # batch*300: logits at the last position
    if k:
        p = torch.softmax(token_output, dim=-1)
        if stochastic:
            # Sample k distinct tokens per row; report their log-probabilities.
            output = torch.multinomial(p, num_samples=k, replacement=False)
            return torch.log(torch.stack([p[i][output[i]] for i in range(output.size(0))])), output
        else:
            output = torch.topk(p, k=k, dim=-1)  # batch*k
            return torch.log(output.values), output.indices
    if stochastic:
        p = torch.softmax(token_output, dim=-1)
        m = Categorical(p)
        token_output = m.sample()
    else:
        token_output = torch.argmax(token_output, dim=-1)
    return token_output.unsqueeze(1)  # batch*1


@torch.no_grad()
def AE_SMILES_decoder(pv, model, stochastic=False, k=2, max_length=150):
    """Decode latent vectors ``pv`` back into SMILES strings.

    Greedy decoding when ``k == 1``; otherwise a beam search of width ``k``
    run one latent at a time.

    :param pv: batch of latent vectors (optionally passed through
        ``model.decode_prefix`` first).
    :param model: model exposing ``tokenizer``, ``text_encoder``, ``device``.
    :param stochastic: sample beam expansions instead of taking top-k.
    :param k: beam width.
    :param max_length: maximum number of decoding steps.
    :return: list of decoded SMILES strings, one per latent ("" when the beam
        search produced no finished hypothesis).
    :raises ValueError: if ``model.tokenizer`` is None.
    """
    if hasattr(model, 'decode_prefix'):
        pv = model.decode_prefix(pv)
    tokenizer = model.tokenizer
    if tokenizer is None:
        raise ValueError('Tokenizer is not defined')
    model.eval()  # inference only
    candidate = []
    if k == 1:
        # Greedy decoding over the whole batch at once.
        text_input = torch.tensor([tokenizer.cls_token_id]).expand(pv.size(0), 1).to(model.device)  # batch*1
        for _ in range(max_length):
            output = generate(model, pv, text_input, stochastic=False)
            # All rows emitted padding (id 0): everything has terminated.
            if output.sum() == 0:
                break
            text_input = torch.cat([text_input, output], dim=-1)
        for i in range(text_input.size(0)):
            sentence = text_input[i]
            cdd = tokenizer.decode(sentence)[0]  # newtkn
            candidate.append(cdd)
    else:
        # Beam search, one latent vector at a time.
        for prop_embeds in pv:
            prop_embeds = prop_embeds.unsqueeze(0)
            product_input = torch.tensor([tokenizer.cls_token_id]).expand(1, 1).to(model.device)
            # Seed the k beams from the first expansion of [CLS].
            values, indices = generate(model, prop_embeds, product_input, stochastic=stochastic, k=k)
            product_input = torch.cat([torch.tensor([tokenizer.cls_token_id]).expand(k, 1).to(model.device),
                                       indices.squeeze(0).unsqueeze(-1)], dim=-1)
            current_p = values.squeeze(0)  # running log-prob per beam
            final_output = []
            for _ in range(max_length):
                values, indices = generate(model, prop_embeds, product_input, stochastic=stochastic, k=k)
                k2_p = current_p[:, None] + values  # k*k candidate scores
                product_input_k2 = torch.cat([product_input.unsqueeze(1).repeat(1, k, 1),
                                              indices.unsqueeze(-1)], dim=-1)
                if tokenizer.sep_token_id in indices:
                    # Harvest finished hypotheses and knock them out of the pool.
                    ends = (indices == tokenizer.sep_token_id).nonzero(as_tuple=False)
                    for e in ends:
                        p = k2_p[e[0], e[1]].cpu().item()
                        final_output.append((p, product_input_k2[e[0], e[1]]))
                        k2_p[e[0], e[1]] = -1e5
                    # NOTE(review): k ** 1 == k; looks like a tunable stop budget.
                    if len(final_output) >= k ** 1:
                        break
                # Keep the k best of the k*k expansions.
                current_p, i = torch.topk(k2_p.flatten(), k)
                next_indices = torch.from_numpy(np.array(np.unravel_index(i.cpu().numpy(), k2_p.shape))).T
                product_input = torch.stack([product_input_k2[i[0], i[1]] for i in next_indices], dim=0)
            candidate_k = []
            final_output = sorted(final_output, key=lambda x: x[0], reverse=True)[:k]
            for p, sentence in final_output:
                cdd = tokenizer.decode(sentence[:-1])[0]  # newtkn
                candidate_k.append(cdd)
            if not candidate_k:
                candidate.append("")
            else:
                candidate.append(candidate_k[0])
    return candidate


@torch.no_grad()
def molT5_encoder(descriptions, molt5, molt5_tokenizer, description_length, device):
    """Encode text descriptions with the MolT5 encoder.

    :return: ``(last_hidden_state, attention_mask)`` for the padded batch.
    """
    tokenized = molt5_tokenizer(descriptions, padding='max_length', truncation=True,
                                max_length=description_length, return_tensors="pt").to(device)
    encoder_outputs = molt5.encoder(input_ids=tokenized.input_ids,
                                    attention_mask=tokenized.attention_mask,
                                    return_dict=True).last_hidden_state
    return encoder_outputs, tokenized.attention_mask


def get_validity(smiles):
    """Return (validity, uniqueness) ratios for a list of SMILES strings.

    validity = valid molecules / total inputs; uniqueness = distinct canonical
    SMILES / valid molecules. Empty strings are skipped; RDKit parse failures
    are ignored. Returns (0., 0.) when nothing parses.

    Fix: the original returned a bare scalar on the non-empty path but a
    2-tuple on the empty path, and never used the uniqueness it computed.
    """
    from rdkit import Chem
    v = []
    for l in smiles:
        try:
            if l == "":
                continue
            s = Chem.MolToSmiles(Chem.MolFromSmiles(l), isomericSmiles=False)
            v.append(s)
        except Exception:
            continue
    u = list(set(v))
    if len(u) == 0:
        return 0., 0.
    return len(v) / len(smiles), len(u) / len(v)


# Regex fragments for the rule-based sentence splitter below.
alphabets = "([A-Za-z])"
prefixes = "(Mr|St|Mrs|Ms|Dr)[.]"
suffixes = "(Inc|Ltd|Jr|Sr|Co)"
starters = "(Mr|Mrs|Ms|Dr|Prof|Capt|Cpt|Lt|He\s|She\s|It\s|They\s|Their\s|Our\s|We\s|But\s|However\s|That\s|This\s|Wherever)"
acronyms = "([A-Z][.][A-Z][.](?:[A-Z][.])?)"
websites = "[.](com|net|org|io|gov|edu|me)"
digits = "([0-9])"
multiple_dots = r'\.{2,}'


def split_into_sentences(text: str) -> list[str]:
    """
    Split the text into sentences.

    If the text contains substrings "<prd>" or "<stop>", they would lead
    to incorrect splitting because they are used as markers for splitting.

    Fix: the "<prd>"/"<stop>" sentinel markers had been stripped from this
    function, leaving no-op replacements and a ``split("")`` that raises
    ValueError; the canonical marker-based implementation is restored.

    :param text: text to be split into sentences
    :type text: str

    :return: list of sentences
    :rtype: list[str]
    """
    text = " " + text + "  "
    text = text.replace("\n", " ")
    # Protect periods that do NOT end a sentence with the <prd> marker.
    text = re.sub(prefixes, "\\1<prd>", text)
    text = re.sub(websites, "<prd>\\1", text)
    text = re.sub(digits + "[.]" + digits, "\\1<prd>\\2", text)
    text = re.sub(multiple_dots, lambda match: "<prd>" * len(match.group(0)) + "<stop>", text)
    if "Ph.D" in text:
        text = text.replace("Ph.D.", "Ph<prd>D<prd>")
    text = re.sub("\s" + alphabets + "[.] ", " \\1<prd> ", text)
    text = re.sub(acronyms + " " + starters, "\\1<stop> \\2", text)
    text = re.sub(alphabets + "[.]" + alphabets + "[.]" + alphabets + "[.]", "\\1<prd>\\2<prd>\\3<prd>", text)
    text = re.sub(alphabets + "[.]" + alphabets + "[.]", "\\1<prd>\\2<prd>", text)
    text = re.sub(" " + suffixes + "[.] " + starters, " \\1<stop> \\2", text)
    text = re.sub(" " + suffixes + "[.]", " \\1<prd>", text)
    text = re.sub(" " + alphabets + "[.]", " \\1<prd>", text)
    # Move terminal punctuation outside closing quotes before marking stops.
    if "”" in text:
        text = text.replace(".”", "”.")
    if "\"" in text:
        text = text.replace(".\"", "\".")
    if "!" in text:
        text = text.replace("!\"", "\"!")
    if "?" in text:
        text = text.replace("?\"", "\"?")
    text = text.replace(".", ".<stop>")
    text = text.replace("?", "?<stop>")
    text = text.replace("!", "!<stop>")
    text = text.replace("<prd>", ".")
    sentences = text.split("<stop>")
    sentences = [s.strip() for s in sentences]
    if sentences and not sentences[-1]:
        sentences = sentences[:-1]
    return sentences


def center_crop(width, height, img):
    """Center-crop ``img`` (H*W*C ndarray) to a square, then resize.

    :param width: output width in pixels
    :param height: output height in pixels
    :param img: image as a numpy array
    :return: uint8 numpy array of shape (height, width, ...)
    """
    resample = {'box': Image.BOX, 'lanczos': Image.LANCZOS}['lanczos']
    crop = np.min(img.shape[:2])
    # Take the largest centered square.
    img = img[(img.shape[0] - crop) // 2: (img.shape[0] + crop) // 2,
              (img.shape[1] - crop) // 2: (img.shape[1] + crop) // 2]
    try:
        img = Image.fromarray(img, 'RGB')
    except Exception:
        # Fall back to mode inference (e.g. grayscale / RGBA inputs).
        img = Image.fromarray(img)
    img = img.resize((width, height), resample)  # resize the center crop from [crop, crop] to [width, height]
    return np.array(img).astype(np.uint8)


def set_logger(log_level='info', fname=None):
    """Configure absl logging; optionally mirror records to a file.

    :param log_level: absl verbosity name (e.g. 'info', 'debug')
    :param fname: optional path for an additional FileHandler
    """
    import logging as _logging  # stdlib logging; absl's is imported at module level
    handler = logging.get_absl_handler()
    formatter = _logging.Formatter('%(asctime)s - %(filename)s - %(message)s')
    handler.setFormatter(formatter)
    logging.set_verbosity(log_level)
    if fname is not None:
        handler = _logging.FileHandler(fname)
        handler.setFormatter(formatter)
        logging.get_absl_logger().addHandler(handler)


def drawRoundRec(draw, color, x, y, w, h, r):
    """Draw a filled rounded rectangle on a PIL ImageDraw object.

    Four corner ellipses of radius ``r`` plus two overlapping rectangles.
    """
    drawObject = draw
    # Corner rounds.
    drawObject.ellipse((x, y, x + r, y + r), fill=color)
    drawObject.ellipse((x + w - r, y, x + w, y + r), fill=color)
    drawObject.ellipse((x, y + h - r, x + r, y + h), fill=color)
    drawObject.ellipse((x + w - r, y + h - r, x + w, y + h), fill=color)
    # Body rectangles.
    drawObject.rectangle((x + r / 2, y, x + w - (r / 2), y + h), fill=color)
    drawObject.rectangle((x, y + r / 2, x + w, y + h - (r / 2)), fill=color)


class regexTokenizer():
    """Regex-based SMILES tokenizer backed by a fixed vocabulary file.

    Token ids follow the order of lines in ``vocab_path``; sequences are
    wrapped in [CLS]/[SEP] and padded (or truncated) to ``max_len``.
    """

    def __init__(self, vocab_path='./vocab_bpe_300_sc.txt', max_len=127):
        with open(vocab_path, 'r') as f:
            x = f.readlines()
        x = [xx.replace('##', '') for xx in x]
        # Longest-token-first alternation so the regex matches greedily.
        x2 = x.copy()
        x2.sort(key=len, reverse=True)
        # strip()[:-1] drops the escaped trailing newline of each vocab line.
        pattern = "(" + "|".join(re.escape(token).strip()[:-1] for token in x2) + ")"
        self.rg = re.compile(pattern)
        self.idtotok = {cnt: i.strip() for cnt, i in enumerate(x)}
        self.vocab_size = len(self.idtotok)  # includes SOS, EOS, pad
        self.toktoid = {v: k for k, v in self.idtotok.items()}
        self.max_len = max_len
        self.cls_token_id = self.toktoid['[CLS]']
        self.sep_token_id = self.toktoid['[SEP]']
        self.pad_token_id = self.toktoid['[PAD]']

    def decode_one(self, iter):
        """Decode one 1-D id tensor to a string, stopping at the first [SEP]
        and skipping the leading [CLS]."""
        if self.sep_token_id in iter:
            iter = iter[:(iter == self.sep_token_id).nonzero(as_tuple=True)[0][0].item()]
        return "".join([self.idtotok[i.item()] for i in iter[1:]])

    def decode(self, ids: torch.tensor):
        """Decode a 1-D or 2-D id tensor; always returns a list of strings."""
        if len(ids.shape) == 1:
            return [self.decode_one(ids)]
        else:
            smiles = []
            for i in ids:
                smiles.append(self.decode_one(i))
            return smiles

    def __len__(self):
        return self.vocab_size

    def __call__(self, smis: list, truncation='max_len'):
        """Tokenize one SMILES string or a list of them into a LongTensor.

        :param truncation: 'max_len' pads every row to ``self.max_len``;
            'longest' trims columns to the longest sequence in the batch.
        :raises ValueError: on any other ``truncation`` value.
        """
        tensors = []
        lengths = []
        if type(smis) is str:
            smis = [smis]
        for i in smis:
            length, tensor = self.encode_one(i)
            tensors.append(tensor)
            lengths.append(length)
        output = torch.concat(tensors, dim=0)
        if truncation == 'max_len':
            return output
        elif truncation == 'longest':
            return output[:, :max(lengths)]
        else:
            raise ValueError('truncation should be either max_len or longest')

    def encode_one(self, smi):
        """Encode a single SMILES string; returns (true_length, 1*max_len tensor)."""
        smi = '[CLS]' + smi + '[SEP]'
        res = [self.toktoid[i] for i in self.rg.findall(smi)]
        token_length = len(res)
        if token_length < self.max_len:
            res += [self.pad_token_id] * (self.max_len - len(res))
        else:
            res = res[:self.max_len]
            # res[-1] = self.sep_token_id
        return token_length, torch.LongTensor([res])