LLM-fastAPI / main.py
Songyou's picture
Update main.py
26cbe4e verified
raw
history blame
5.14 kB
from fastapi import FastAPI, HTTPException, Query
from pydantic import BaseModel
import subprocess
from typing import List
from fragment_processor import fragmentize_molecule
import torch
import pandas as pd
from rdkit import Chem
from rdkit.Chem import Descriptors, QED
from generate import GenerateRunner
from dataset import Dataset
import sascorer
# FastAPI application object; the /fragmentize and /generate routes in this
# module are registered on it through the @app.get / @app.post decorators.
app = FastAPI()
class Fragment(BaseModel):
    """One fragmentation record returned by ``GET /fragmentize``.

    Instances are built from the rows of ``fragmentize_molecule(...)`` output
    (via ``to_dict(orient="records")``), so every row must provide these
    fields. The per-field meanings below are inferred from the names only and
    should be confirmed against ``fragment_processor``.
    """
    variable_smiles: str  # presumably SMILES of the variable (replaceable) part — TODO confirm
    constant_smiles: str  # presumably SMILES of the constant (kept) part — TODO confirm
    record_id: str  # record identifier; its format is not visible in this file
    normalized_smiles: str  # presumably the normalized input SMILES — TODO confirm
    attachment_order: int  # presumably count/order of attachment points — TODO confirm
class FragmentResponse(BaseModel):
    """Response body of ``GET /fragmentize``: the list of fragment records."""
    fragments: List[Fragment]  # one entry per row produced by fragmentize_molecule
class GenerateRequest(BaseModel):
    """JSON body of ``POST /generate``.

    The fields are passed positionally to ``run_generate_runner``; all but
    ``num`` end up as columns of the one-row DataFrame fed to the model.
    """
    constSmiles: str  # -> const_smiles, stored in the "constantSMILES" column
    varSmiles: str  # -> var_smiles, stored in the "fromVarSMILES" column
    mainCls: str  # -> main_cls, stored in the "main_cls" column
    minorCls: str  # -> minor_cls, stored in the "minor_cls" column
    deltaValue: str  # -> delta_value, stored in the "Delta_Value" column (sent as text)
    num: int  # -> num_samples: how many SMILES to sample; also used as the batch size
class MoleculeOutput(BaseModel):
    """One generated molecule plus the descriptors from ``calculate_descriptors``."""
    smile: str  # generated SMILES string
    molwt: float  # molecular weight (RDKit Descriptors.MolWt)
    tpsa: float  # topological polar surface area (RDKit Descriptors.TPSA)
    slogp: float  # logP estimate (RDKit Descriptors.MolLogP)
    sa: float  # score from sascorer.calculateScore (synthetic accessibility)
    qed: float  # drug-likeness score (RDKit QED.qed)
class Options:
    """Attribute-style container built from keyword arguments.

    ``Options(a=1).a == 1``. ``run_generate_runner`` uses it to turn a plain
    configuration mapping into the object handed to ``GenerateRunner``.
    """

    def __init__(self, **entries):
        self.__dict__.update(entries)

    def __repr__(self):
        # The options get printed for debugging; show the actual settings
        # instead of the default "<Options object at 0x...>".
        fields = ", ".join(f"{key}={value!r}" for key, value in self.__dict__.items())
        return f"{type(self).__name__}({fields})"
def calculate_descriptors(smiles):
    """Compute physicochemical descriptors for a single SMILES string.

    Args:
        smiles: SMILES string to parse with RDKit.

    Returns:
        A dict with keys ``molwt`` (Descriptors.MolWt), ``tpsa``
        (Descriptors.TPSA), ``slogp`` (Descriptors.MolLogP), ``sa``
        (sascorer.calculateScore) and ``qed`` (QED.qed), or ``None`` when the
        SMILES cannot be parsed or yields a molecule with no atoms.
    """
    mol = Chem.MolFromSmiles(smiles)
    # MolFromSmiles returns None for unparsable input, but an empty string
    # parses to a zero-atom molecule. The SA scorer averages over the
    # molecule's fingerprint bits and (in the reference sascorer) fails on
    # such a molecule, which would abort the whole batch, so treat it as
    # unusable too.
    if mol is None or mol.GetNumAtoms() == 0:
        return None
    # A TPSA of 0 is a legitimate value (e.g. pure hydrocarbons) and is
    # returned as-is.
    return {
        "molwt": Descriptors.MolWt(mol),
        "tpsa": Descriptors.TPSA(mol),
        "slogp": Descriptors.MolLogP(mol),
        "sa": sascorer.calculateScore(mol),
        "qed": QED.qed(mol),
    }
def run_generate_runner(const_smiles, var_smiles, main_cls, minor_cls, delta_value, num_samples):
    """Sample molecules from the pretrained transformer and annotate them.

    A new ``GenerateRunner`` is built on every call from a fixed option set
    (checkpoint ``./raw_pretrain_frag/checkpoint``, epoch 20, multinomial
    decoding). The single input is wrapped in a one-row DataFrame, tokenized
    through ``Dataset`` in prediction mode, and ``num_samples`` SMILES are
    sampled for it. Each sampled SMILES accepted by ``calculate_descriptors``
    is returned together with its descriptors.

    Args:
        const_smiles: value for the ``constantSMILES`` column.
        var_smiles: value for the ``fromVarSMILES`` column.
        main_cls: value for the ``main_cls`` column.
        minor_cls: value for the ``minor_cls`` column.
        delta_value: value for the ``Delta_Value`` column.
        num_samples: number of SMILES to sample (also the batch size); must
            be at least 1.

    Returns:
        A list of dicts with keys ``smile``, ``molwt``, ``tpsa``, ``slogp``,
        ``sa`` and ``qed``. Samples rejected by ``calculate_descriptors`` are
        dropped, so the list may be shorter than ``num_samples``.

    Raises:
        ValueError: if ``num_samples`` is less than 1.
    """
    if num_samples < 1:
        # Fail before building the runner (which is given the checkpoint path,
        # so presumably loads the model). DataLoader would reject a
        # non-positive batch size anyway, but only after that step.
        raise ValueError(f"num_samples must be >= 1, got {num_samples}")

    # Generator configuration.
    opt = Options(
        batch_size=num_samples,
        data_path="./",
        decode_type="multinomial",
        dev_no=0,
        epoch=20,
        model_choice="transformer",
        model_path="./raw_pretrain_frag/checkpoint",
        # NOTE(review): fixed at 50 regardless of the num_samples argument; the
        # sample() call below passes num_samples explicitly. Confirm whether
        # GenerateRunner reads opt.num_samples before aligning the two.
        num_samples=50,
        overwrite=True,
        test_file_name="test_cut",
        vocab_path="./",
    )
    print("--------------opt---------------")
    # vars() shows the actual settings instead of an object address.
    print(vars(opt))
    runner = GenerateRunner(opt)

    # One-row input frame with the column names Dataset is given here.
    test_data = pd.DataFrame([{
        "constantSMILES": const_smiles,
        "fromVarSMILES": var_smiles,
        "main_cls": main_cls,
        "minor_cls": minor_cls,
        "Delta_Value": delta_value,
    }])
    dataset = Dataset(test_data, vocabulary=runner.vocab, tokenizer=runner.tokenizer, prediction_mode=True)
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=num_samples, shuffle=False, collate_fn=Dataset.collate_fn
    )

    result = []
    # Inference only: no autograd graph is needed while sampling.
    with torch.no_grad():
        for batch in dataloader:
            src, source_length, _, src_mask, _, _, _ = batch
            # All tensors passed to sample() must live on the runner's device.
            src = src.to(runner.device)
            src_mask = src_mask.to(runner.device)
            source_length = source_length.to(runner.device)
            smiles_list = runner.sample(
                model_choice=opt.model_choice,
                model=runner.model,
                src=src,
                src_mask=src_mask,
                source_length=source_length,
                decode_type=opt.decode_type,
                num_samples=num_samples,
            )
            # smiles_list is nested: each element is itself a list of SMILES.
            for smiles_group in smiles_list:
                for smile in smiles_group:
                    descriptors = calculate_descriptors(smile)
                    if descriptors is None:
                        continue  # unusable SMILES; skip it
                    result.append({
                        "smile": smile,
                        "molwt": descriptors["molwt"],
                        "tpsa": descriptors["tpsa"],
                        "slogp": descriptors["slogp"],
                        "sa": descriptors["sa"],
                        "qed": descriptors["qed"],
                    })
    return result
@app.get("/fragmentize", response_model=FragmentResponse)
async def fragmentize(smiles: str = Query(..., description="SMILES string of the molecule")):
    """Fragment a molecule with ``fragmentize_molecule``.

    Args:
        smiles: SMILES string of the molecule (required query parameter).

    Returns:
        A ``FragmentResponse`` whose ``fragments`` are the records of the table
        returned by ``fragmentize_molecule`` (``to_dict(orient="records")``),
        each validated as a ``Fragment``.

    Raises:
        HTTPException: 500 carrying the error message if fragmentation or
            response validation fails.
    """
    import logging

    # NOTE(review): fragmentize_molecule is called synchronously inside an
    # async handler, so the event loop is blocked while it runs.
    try:
        fragment_df = fragmentize_molecule(smiles)
        fragments = fragment_df.to_dict(orient="records")
        return FragmentResponse(fragments=fragments)
    except Exception as e:
        # Keep the full traceback in the server log; the client only gets the
        # message. Chain the cause so it is not lost.
        logging.getLogger(__name__).exception("fragmentize failed for %r", smiles)
        raise HTTPException(status_code=500, detail=f"发生错误: {str(e)}") from e
@app.post("/generate", response_model=List[MoleculeOutput])
async def generate_molecules(request: GenerateRequest):
    """Generate molecules for the request and return them with descriptors.

    Args:
        request: parsed ``GenerateRequest`` body; its fields are forwarded
            positionally to ``run_generate_runner``.

    Returns:
        The list produced by ``run_generate_runner``, validated against
        ``List[MoleculeOutput]``.

    Raises:
        HTTPException: 422 if ``num`` is not positive; 500 carrying the error
            message if generation fails.
    """
    import logging

    # Validate outside the try: an HTTPException raised inside it would be
    # caught by the broad handler below and re-raised as a 500.
    if request.num < 1:
        raise HTTPException(status_code=422, detail="num must be a positive integer")

    # NOTE(review): run_generate_runner is synchronous and called directly
    # from this async handler, so it blocks the event loop while it runs.
    # Moving it to a thread pool would allow concurrent runs; mind GPU
    # memory if you do.
    try:
        return run_generate_runner(
            request.constSmiles,
            request.varSmiles,
            request.mainCls,
            request.minorCls,
            request.deltaValue,
            request.num,
        )
    except Exception as e:
        # Record the full stack trace server-side; return only the message.
        logging.getLogger(__name__).exception("Error occurred: %s", e)
        raise HTTPException(status_code=500, detail=f"Error occurred: {str(e)}") from e