import time from fastapi import FastAPI, HTTPException, Query from pydantic import BaseModel import subprocess from typing import List from fragment_processor import fragmentize_molecule from fastapi.responses import StreamingResponse import torch import pandas as pd from rdkit import Chem from rdkit.Chem import Descriptors, QED, Draw from io import BytesIO from PIL import Image import io from generate import GenerateRunner from dataset import Dataset from combine_mol import connect_constVar_try import sascorer app = FastAPI() class Fragment(BaseModel): variable_smiles: str constant_smiles: str record_id: str normalized_smiles: str attachment_order: int class FragmentResponse(BaseModel): fragments: List[Fragment] class GenerateRequest(BaseModel): constSmiles: str varSmiles: str mainCls: str minorCls: str deltaValue: str num: int class MoleculeOutput(BaseModel): smile: str molwt: float tpsa: float slogp: float sa: float qed: float class Options: def __init__(self, **entries): self.__dict__.update(entries) def calculate_descriptors(smiles): mol = Chem.MolFromSmiles(smiles) if mol is None: return None # molwt = Descriptors.MolWt(mol) # tpsa = Descriptors.TPSA(mol) # slogp = Descriptors.MolLogP(mol) # sa = sascorer.calculateScore(mol) # qed = QED.qed(mol) molwt = round(Descriptors.MolWt(mol), 1) # 保留 1 位小数 tpsa = round(Descriptors.TPSA(mol), 1) slogp = round(Descriptors.MolLogP(mol), 1) sa = round(sascorer.calculateScore(mol), 1) qed = round(QED.qed(mol), 1) # 检查除法前是否为 0 if tpsa == 0: print("Warning: TPSA is zero, skipping division.") some_ratio = None else: some_ratio = molwt / tpsa # 安全的除法操作 return {"molwt": molwt, "tpsa": tpsa, "slogp": slogp, "sa": sa, "qed": qed} def run_generate_runner(const_smiles, var_smiles, main_cls, minor_cls, delta_value, num_samples): # 初始化生成器的配置选项 opt = { 'batch_size': num_samples, 'data_path' : './', 'decode_type' : 'multinomial', 'dev_no' : 0, 'epoch' : 20, 'model_choice' : 'transformer', 'model_path' : './raw_pretrain_frag/checkpoint', 'num_samples' : 50, 'overwrite' : True, # 'save_directory' : './demo_gen', 'test_file_name' : 'test_cut', 'vocab_path' : './' } # 将 opt 字典转换为 Options 对象 opt = Options(**opt) print("--------------opt---------------") print(opt) runner = GenerateRunner(opt) # 创建数据 data = { "constantSMILES": const_smiles, "fromVarSMILES": var_smiles, "main_cls": main_cls, "minor_cls": minor_cls, "Delta_Value": delta_value } # 创建 Dataset 实例 test_data = pd.DataFrame([data]) dataset = Dataset(test_data, vocabulary=runner.vocab, tokenizer=runner.tokenizer, prediction_mode=True) # 生成 SMILES dataloader = torch.utils.data.DataLoader(dataset, batch_size=num_samples, shuffle=False, collate_fn=Dataset.collate_fn) result = [] for batch in dataloader: src, source_length, _, src_mask, _, _, df = batch src = src.to(runner.device) src_mask = src_mask.to(runner.device) source_length = source_length.to(runner.device) # 将 source_length 也移到同一设备 smiles_list = runner.sample( model_choice="transformer", model=runner.model, src=src, src_mask=src_mask, source_length=source_length, decode_type="multinomial", num_samples=num_samples ) # 计算每个 SMILES 的化学性质 for smiles_group in smiles_list: for smile in smiles_group: # smiles_group 是一个子列表 #链接新分子 newsmile=connect_constVar_try(const_smiles,smile) descriptors = calculate_descriptors(newsmile) if descriptors: result.append({ "smile": newsmile, "molwt": descriptors['molwt'], "tpsa": descriptors['tpsa'], "slogp": descriptors['slogp'], "sa": descriptors['sa'], "qed": descriptors['qed'] }) return result @app.get("/smiles2img") async def smiles2img(smiles: str = Query(..., description="SMILES string of the molecule")): # 生成分子对象 print("---开始生成分子图像--") print(smiles) mol = Chem.MolFromSmiles(smiles) if mol is None: return {"error": "Invalid SMILES string"} # 创建绘图对象 drawer = Draw.MolDraw2DCairo(200, 200) drawer.SetFontSize(1.0) drawer.DrawMolecule(mol) drawer.FinishDrawing() # 将绘制的图像转换为PIL图像对象 pil_image = Image.open(io.BytesIO(drawer.GetDrawingText())) # 创建字节流 strIO = BytesIO() pil_image.save(strIO, "PNG") strIO.seek(0) # 返回图像作为流 return StreamingResponse(strIO, media_type="image/png") @app.get("/fragmentize", response_model=FragmentResponse) async def fragmentize(smiles: str = Query(..., description="SMILES string of the molecule")): try: fragment_df = fragmentize_molecule(smiles) fragments = fragment_df.to_dict(orient="records") return FragmentResponse(fragments=fragments) except Exception as e: raise HTTPException(status_code=500, detail=f"发生错误: {str(e)}") @app.post("/generate", response_model=List[MoleculeOutput]) async def generate_molecules(request: GenerateRequest): start_time = time.time() # 记录请求接受的时间 try: # 调用 SMILES 生成逻辑 print("/generate请求开始时间:", start_time) print("--------------/generate start---------------") result = run_generate_runner(request.constSmiles, request.varSmiles, request.mainCls, request.minorCls, request.deltaValue, request.num) end_time = time.time() # 记录生成结束的时间 duration = end_time - start_time # 计算用时 print("/generate请求结束时间:", end_time) print(f"请求处理用时: {duration:.2f}秒,本次处理分子数量为 {request.num}") return result except Exception as e: # 捕获异常并记录详细的错误信息,包括堆栈追踪 error_message = f"Error occurred: {str(e)}" print(error_message) # 打印到控制台,或者使用 logging 模块记录到日志文件 raise HTTPException(status_code=500, detail=f"Error occurred: {str(e)}")