File size: 6,861 Bytes
79cdf7b
2df9869
d6fdf05
2df9869
 
 
14e32f4
2df9869
 
 
14e32f4
 
 
 
2df9869
 
f41438b
2df9869
3541985
 
d38f766
2df9869
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d7d65ed
 
 
 
 
 
 
 
 
 
2df9869
 
 
 
 
 
 
 
 
 
 
 
 
17b0e42
 
 
 
 
 
 
 
 
26cbe4e
17b0e42
 
2df9869
 
 
 
17b0e42
 
2df9869
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f41438b
 
 
2df9869
 
f41438b
2df9869
 
 
 
 
 
 
7a02397
2df9869
36f8047
98b1ab2
14e32f4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17b0e42
2df9869
 
 
 
 
 
 
98b1ab2
2df9869
 
79cdf7b
2df9869
 
79cdf7b
825e5af
2df9869
79cdf7b
 
413d237
79cdf7b
 
2df9869
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
import time
from fastapi import FastAPI, HTTPException, Query
from pydantic import BaseModel
import subprocess
from typing import List
from fragment_processor import fragmentize_molecule
from fastapi.responses import StreamingResponse
import torch
import pandas as pd
from rdkit import Chem
from rdkit.Chem import Descriptors, QED, Draw
from io import BytesIO
from PIL import Image
import io
from generate import GenerateRunner
from dataset import Dataset
from combine_mol import connect_constVar_try
import sascorer
# FastAPI application instance; the route decorators further down in this
# module (@app.get / @app.post) register their handlers on it.
app = FastAPI()


class Fragment(BaseModel):
    """One fragment record returned by the ``/fragmentize`` endpoint.

    Instances are built from the rows of the table produced by
    ``fragmentize_molecule`` (converted with ``to_dict(orient="records")``),
    so each row must supply every field declared here.
    """
    # NOTE(review): the per-field meanings below are inferred from the field
    # names only; confirm against fragment_processor.fragmentize_molecule.
    variable_smiles: str  # presumably the SMILES of the variable part
    constant_smiles: str  # presumably the SMILES of the constant part
    record_id: str
    normalized_smiles: str
    attachment_order: int

class FragmentResponse(BaseModel):
    """Response body of the ``/fragmentize`` endpoint: a list of fragments."""
    # One entry per row of the fragment table returned for the query molecule.
    fragments: List[Fragment]



class GenerateRequest(BaseModel):
    """Request body of the ``/generate`` endpoint.

    The fields are forwarded positionally to ``run_generate_runner`` in the
    order they are declared here.
    """
    # Constant part of the molecule; stored in the "constantSMILES" column and
    # later joined with every sampled fragment.
    constSmiles: str
    # Starting variable fragment; stored in the "fromVarSMILES" column.
    varSmiles: str
    # Stored in the "main_cls" column of the model input.
    mainCls: str
    # Stored in the "minor_cls" column of the model input.
    minorCls: str
    # Stored in the "Delta_Value" column of the model input (kept as a string).
    deltaValue: str
    # Used both as the DataLoader batch size and as the number of samples
    # requested from the generator.
    num: int

class MoleculeOutput(BaseModel):
    """One generated molecule with its descriptors, as returned by ``/generate``.

    The numeric fields are the values produced by ``calculate_descriptors``,
    which rounds each of them to one decimal place.
    """
    smile: str    # SMILES of the joined molecule (result of connect_constVar_try)
    molwt: float  # Descriptors.MolWt
    tpsa: float   # Descriptors.TPSA
    slogp: float  # Descriptors.MolLogP
    sa: float     # sascorer.calculateScore
    qed: float    # QED.qed

class Options:
    """Lightweight attribute bag built from keyword arguments.

    Every keyword passed to the constructor becomes an instance attribute of
    the same name, so a plain configuration dict can be handed (via
    ``Options(**config)``) to code that expects attribute access such as
    ``opt.batch_size``.
    """

    def __init__(self, **entries):
        # Store the keywords directly as instance attributes.
        self.__dict__.update(entries)

    def __repr__(self):
        # Without an explicit repr, printing an instance only shows the
        # default "<... object at 0x...>" text, which hides the configuration.
        fields = ", ".join(
            f"{name}={value!r}" for name, value in self.__dict__.items()
        )
        return f"{type(self).__name__}({fields})"

def calculate_descriptors(smiles):
    """Compute rounded descriptors for the molecule described by ``smiles``.

    Args:
        smiles: SMILES string of the molecule to evaluate.

    Returns:
        A dict with the keys ``molwt``, ``tpsa``, ``slogp``, ``sa`` and
        ``qed``, each rounded to one decimal place, or ``None`` when
        ``smiles`` is empty/``None`` or RDKit cannot parse it.
    """
    # An empty or missing SMILES cannot describe a molecule; return early
    # instead of handing it to RDKit.
    if not smiles:
        return None

    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None

    # Every value is kept to one decimal place.
    return {
        "molwt": round(Descriptors.MolWt(mol), 1),
        "tpsa": round(Descriptors.TPSA(mol), 1),
        "slogp": round(Descriptors.MolLogP(mol), 1),
        "sa": round(sascorer.calculateScore(mol), 1),
        "qed": round(QED.qed(mol), 1),
    }

def run_generate_runner(const_smiles, var_smiles, main_cls, minor_cls, delta_value, num_samples):
    """Sample fragments for one query and return the joined molecules with descriptors.

    Args:
        const_smiles: Constant part of the molecule; also joined with every
            sampled fragment via ``connect_constVar_try``.
        var_smiles: Starting variable fragment ("fromVarSMILES" column).
        main_cls: Value for the "main_cls" column.
        minor_cls: Value for the "minor_cls" column.
        delta_value: Value for the "Delta_Value" column.
        num_samples: Used as the option ``batch_size``, the DataLoader batch
            size and the ``num_samples`` argument of ``runner.sample``.

    Returns:
        A list of dicts with the keys ``smile``, ``molwt``, ``tpsa``,
        ``slogp``, ``sa`` and ``qed``. Joined molecules for which
        ``calculate_descriptors`` returns a falsy value are left out.
    """
    # Generator configuration, exposed through attribute access.
    # NOTE(review): the option num_samples is fixed at 50 while sampling below
    # uses the num_samples argument -- confirm GenerateRunner does not rely on
    # opt.num_samples.
    opt = Options(
        batch_size=num_samples,
        data_path='./',
        decode_type='multinomial',
        dev_no=0,
        epoch=20,
        model_choice='transformer',
        model_path='./raw_pretrain_frag/checkpoint',
        num_samples=50,
        overwrite=True,
        test_file_name='test_cut',
        vocab_path='./',
    )
    print("--------------opt---------------")
    print(opt)
    # NOTE(review): a new GenerateRunner is constructed on every call.
    runner = GenerateRunner(opt)

    # The query is a single-row table wrapped in the project's Dataset.
    query_frame = pd.DataFrame([{
        "constantSMILES": const_smiles,
        "fromVarSMILES": var_smiles,
        "main_cls": main_cls,
        "minor_cls": minor_cls,
        "Delta_Value": delta_value,
    }])
    dataset = Dataset(
        query_frame,
        vocabulary=runner.vocab,
        tokenizer=runner.tokenizer,
        prediction_mode=True,
    )
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=num_samples,
        shuffle=False,
        collate_fn=Dataset.collate_fn,
    )

    descriptor_keys = ("molwt", "tpsa", "slogp", "sa", "qed")
    molecules = []

    # Each batch is a 7-tuple; only three of its items are needed here.
    for src, source_length, _, src_mask, _, _, _ in loader:
        # Move the model inputs onto the runner's device.
        src = src.to(runner.device)
        src_mask = src_mask.to(runner.device)
        source_length = source_length.to(runner.device)

        sampled_groups = runner.sample(
            model_choice="transformer",
            model=runner.model,
            src=src,
            src_mask=src_mask,
            source_length=source_length,
            decode_type="multinomial",
            num_samples=num_samples,
        )

        # sampled_groups is a list of lists; walk it as one flat sequence.
        for fragment in (item for group in sampled_groups for item in group):
            # Attach the sampled fragment to the constant part.
            joined = connect_constVar_try(const_smiles, fragment)
            descriptors = calculate_descriptors(joined)
            if not descriptors:
                continue
            record = {"smile": joined}
            record.update((key, descriptors[key]) for key in descriptor_keys)
            molecules.append(record)

    return molecules


@app.get("/smiles2img")
async def smiles2img(smiles: str = Query(..., description="SMILES string of the molecule")):
    """Render the molecule given by ``smiles`` as a 200x200 PNG image.

    Returns a streaming ``image/png`` response, or the JSON object
    ``{"error": "Invalid SMILES string"}`` when RDKit cannot parse the input.
    """
    print("---开始生成分子图像--")
    print(smiles)

    # Parse the query string into a molecule object.
    molecule = Chem.MolFromSmiles(smiles)
    if molecule is None:
        # NOTE(review): this is returned as a normal (non-error) response body.
        return {"error": "Invalid SMILES string"}

    # Draw the molecule on a 200x200 Cairo canvas.
    canvas = Draw.MolDraw2DCairo(200, 200)
    canvas.SetFontSize(1.0)
    canvas.DrawMolecule(molecule)
    canvas.FinishDrawing()
    drawing_bytes = canvas.GetDrawingText()

    # Load the drawing with PIL and re-encode it as PNG into a fresh buffer.
    png_stream = BytesIO()
    Image.open(io.BytesIO(drawing_bytes)).save(png_stream, "PNG")
    png_stream.seek(0)  # rewind so the response streams from the start

    return StreamingResponse(png_stream, media_type="image/png")

@app.get("/fragmentize", response_model=FragmentResponse)
async def fragmentize(smiles: str = Query(..., description="SMILES string of the molecule")):
    """Fragment the molecule given by ``smiles`` and return the fragment records.

    Raises:
        HTTPException: status 500 carrying the error text when fragmentation
            or building the response fails for any reason.
    """
    try:
        # One dict per row of the fragment table.
        records = fragmentize_molecule(smiles).to_dict(orient="records")
        response = FragmentResponse(fragments=records)
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"发生错误: {str(exc)}")
    return response

@app.post("/generate", response_model=List[MoleculeOutput])
async def generate_molecules(request: GenerateRequest):
    """Generate molecules for ``request`` and return them with their descriptors.

    The request fields are forwarded to ``run_generate_runner``; start time,
    end time and elapsed time of the request are printed to the console.

    Returns:
        The list produced by ``run_generate_runner``.

    Raises:
        HTTPException: status 500 carrying the error text when generation
            fails for any reason.
    """
    start_time = time.time()  # time at which the request was accepted
    print("/generate请求开始时间:", start_time)
    print("--------------/generate start---------------")

    try:
        # NOTE(review): this is a synchronous call inside an async handler, so
        # the event loop is blocked for as long as generation takes.
        result = run_generate_runner(
            request.constSmiles,
            request.varSmiles,
            request.mainCls,
            request.minorCls,
            request.deltaValue,
            request.num,
        )
    except Exception as e:
        # Local import: only needed on the failure path.
        import traceback

        error_message = f"Error occurred: {str(e)}"
        print(error_message)
        # Print the full stack trace as well; the message alone does not show
        # where the failure happened.
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=error_message) from e

    end_time = time.time()  # time at which generation finished
    duration = end_time - start_time

    print("/generate请求结束时间:", end_time)
    print(f"请求处理用时: {duration:.2f}秒,本次处理分子数量为 {request.num}")
    return result