import time
from fastapi import FastAPI, HTTPException, Query
from pydantic import BaseModel
import subprocess
from typing import List
from fragment_processor import fragmentize_molecule
from fastapi.responses import StreamingResponse
import torch
import pandas as pd
from rdkit import Chem
from rdkit.Chem import Descriptors, QED, Draw
from io import BytesIO
from PIL import Image
import io
from generate import GenerateRunner
from dataset import Dataset
from combine_mol import connect_constVar_try
import sascorer
# FastAPI application instance; all routes below are registered on it.
app = FastAPI()
class Fragment(BaseModel):
    """One fragment record returned by the /fragmentize endpoint.

    Rows come directly from fragmentize_molecule's DataFrame; field meanings
    are inferred from their names — verify against fragment_processor.
    """
    variable_smiles: str    # SMILES of the variable part — presumed from name
    constant_smiles: str    # SMILES of the constant (scaffold) part — presumed from name
    record_id: str          # identifier of the source record
    normalized_smiles: str  # normalized SMILES of the parent molecule — presumed from name
    attachment_order: int   # ordering of the attachment point(s) — presumed from name
class FragmentResponse(BaseModel):
    """Response body of /fragmentize: the list of fragments cut from one molecule."""
    fragments: List[Fragment]
class GenerateRequest(BaseModel):
    """Request body of /generate; fields are forwarded to run_generate_runner."""
    constSmiles: str  # constant (scaffold) fragment SMILES
    varSmiles: str    # starting variable fragment SMILES
    mainCls: str      # main class label consumed by the model — semantics defined by GenerateRunner
    minorCls: str     # minor class label — semantics defined by GenerateRunner
    deltaValue: str   # property delta, string-encoded — semantics defined by GenerateRunner
    num: int          # number of molecules to sample
class MoleculeOutput(BaseModel):
    """One generated molecule with its descriptors (values rounded to 1 decimal
    by calculate_descriptors)."""
    smile: str    # SMILES of the assembled molecule
    molwt: float  # molecular weight (Descriptors.MolWt)
    tpsa: float   # topological polar surface area (Descriptors.TPSA)
    slogp: float  # logP (Descriptors.MolLogP)
    sa: float     # synthetic accessibility score (sascorer)
    qed: float    # QED drug-likeness score
class Options:
    """Lightweight namespace: exposes arbitrary keyword arguments as attributes."""

    def __init__(self, **entries):
        # Bind every keyword argument onto the instance, one attribute each.
        for key, value in entries.items():
            setattr(self, key, value)
def calculate_descriptors(smiles):
    """Compute drug-likeness descriptors for a SMILES string.

    Parameters
    ----------
    smiles : str
        Molecule as a SMILES string.

    Returns
    -------
    dict | None
        Keys ``molwt``, ``tpsa``, ``slogp``, ``sa``, ``qed``, each rounded to
        one decimal place, or ``None`` if RDKit cannot parse the SMILES.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        # Unparseable SMILES: signal the caller to skip this candidate.
        return None
    # Round to 1 decimal place for compact display in the API response.
    return {
        "molwt": round(Descriptors.MolWt(mol), 1),
        "tpsa": round(Descriptors.TPSA(mol), 1),
        "slogp": round(Descriptors.MolLogP(mol), 1),
        "sa": round(sascorer.calculateScore(mol), 1),
        "qed": round(QED.qed(mol), 1),
    }
def run_generate_runner(const_smiles: str, var_smiles: str, main_cls: str, minor_cls: str, delta_value: str, num_samples: int) -> list:
    """Sample new molecules for one (constant, variable) fragment pair.

    Builds a one-row dataset from the inputs, samples candidate variable
    fragments from the transformer model, reattaches each candidate to the
    constant part via connect_constVar_try, and returns a list of dicts with
    the new SMILES and its descriptors. Candidates whose reassembled SMILES
    fail RDKit parsing are silently skipped.
    """
    # Configuration options for the generator.
    opt = {
        'batch_size': num_samples,
        'data_path' : './',
        'decode_type' : 'multinomial',
        'dev_no' : 0,
        'epoch' : 20,
        'model_choice' : 'transformer',
        'model_path' : './raw_pretrain_frag/checkpoint',
        # NOTE(review): hard-coded 50 while runner.sample below is given the
        # num_samples argument — confirm which value GenerateRunner reads.
        'num_samples' : 50,
        'overwrite' : True,
        # 'save_directory' : './demo_gen',
        'test_file_name' : 'test_cut',
        'vocab_path' : './'
    }
    # Convert the opt dict into an attribute-style Options object.
    opt = Options(**opt)
    print("--------------opt---------------")
    print(opt)
    runner = GenerateRunner(opt)
    # Build the single input record for the model.
    data = {
        "constantSMILES": const_smiles,
        "fromVarSMILES": var_smiles,
        "main_cls": main_cls,
        "minor_cls": minor_cls,
        "Delta_Value": delta_value
    }
    # Create the Dataset instance (prediction mode: no targets).
    test_data = pd.DataFrame([data])
    dataset = Dataset(test_data, vocabulary=runner.vocab, tokenizer=runner.tokenizer, prediction_mode=True)
    # Generate SMILES.
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=num_samples, shuffle=False, collate_fn=Dataset.collate_fn)
    result = []
    for batch in dataloader:
        src, source_length, _, src_mask, _, _, df = batch
        src = src.to(runner.device)
        src_mask = src_mask.to(runner.device)
        source_length = source_length.to(runner.device)  # move source_length to the model's device too
        smiles_list = runner.sample(
            model_choice="transformer",
            model=runner.model,
            src=src,
            src_mask=src_mask,
            source_length=source_length,
            decode_type="multinomial",
            num_samples=num_samples
        )
        # Compute chemical properties for every sampled SMILES.
        for smiles_group in smiles_list:
            for smile in smiles_group:  # smiles_group is a nested list of candidates
                # Reattach the generated variable part to the constant scaffold.
                newsmile=connect_constVar_try(const_smiles,smile)
                descriptors = calculate_descriptors(newsmile)
                if descriptors:
                    result.append({
                        "smile": newsmile,
                        "molwt": descriptors['molwt'],
                        "tpsa": descriptors['tpsa'],
                        "slogp": descriptors['slogp'],
                        "sa": descriptors['sa'],
                        "qed": descriptors['qed']
                    })
    return result
@app.get("/smiles2img")
async def smiles2img(smiles: str = Query(..., description="SMILES string of the molecule")):
    """Render a molecule as a 200x200 PNG image.

    Returns the PNG as a streaming response; raises 400 for an unparseable
    SMILES (consistent with the HTTPException style of the other endpoints).
    """
    print("---开始生成分子图像--")
    print(smiles)
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        # Bad client input -> 400, instead of a 200 response with an error dict.
        raise HTTPException(status_code=400, detail="Invalid SMILES string")
    # Draw the molecule with RDKit's Cairo backend.
    drawer = Draw.MolDraw2DCairo(200, 200)
    drawer.SetFontSize(1.0)
    drawer.DrawMolecule(mol)
    drawer.FinishDrawing()
    # GetDrawingText() already yields finished PNG bytes for the Cairo drawer,
    # so stream them directly — no PIL decode/re-encode round-trip needed.
    return StreamingResponse(BytesIO(drawer.GetDrawingText()), media_type="image/png")
@app.get("/fragmentize", response_model=FragmentResponse)
async def fragmentize(smiles: str = Query(..., description="SMILES string of the molecule")):
    """Cut a molecule into fragments and return them as a FragmentResponse.

    Any failure in fragmentation or response building surfaces as HTTP 500.
    """
    try:
        frag_records = fragmentize_molecule(smiles).to_dict(orient="records")
        return FragmentResponse(fragments=frag_records)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"发生错误: {str(e)}")
@app.post("/generate", response_model=List[MoleculeOutput])
async def generate_molecules(request: GenerateRequest):
    """Generate `request.num` molecules for the given fragment pair.

    Delegates to run_generate_runner, logs wall-clock timing, and maps any
    failure to HTTP 500 with the error message in the detail field.
    """
    t_start = time.time()  # timestamp when the request was accepted
    try:
        print("/generate请求开始时间:", t_start)
        print("--------------/generate start---------------")
        generated = run_generate_runner(
            request.constSmiles,
            request.varSmiles,
            request.mainCls,
            request.minorCls,
            request.deltaValue,
            request.num,
        )
        t_end = time.time()  # timestamp when generation finished
        elapsed = t_end - t_start
        print("/generate请求结束时间:", t_end)
        print(f"请求处理用时: {elapsed:.2f}秒,本次处理分子数量为 {request.num}")
        return generated
    except Exception as e:
        # Log the failure before surfacing a 500 to the client.
        failure_note = f"Error occurred: {str(e)}"
        print(failure_note)  # printed to console; a logging handler could capture this instead
        raise HTTPException(status_code=500, detail=f"Error occurred: {str(e)}")