# CASKP / app.py — Flask inference service for the CASKP kcat predictor.
# (Hugging Face Space file header: author KangjieXu, commit de091d1, verified)
import os
import re
import io
import pandas as pd
from datetime import datetime
import torch
from flask import Flask, request, jsonify, render_template, Response
from transformers import EsmTokenizer
from huggingface_hub import HfApi, hf_hub_download
from werkzeug.utils import secure_filename
# 导入你的 CASKP 模型文件 (确保文件名是 model.py)
from model import FullKcatPredictor
# --- 1. Initialization and configuration ---
app = Flask(__name__)

# Prefer GPU when one is visible; otherwise fall back to CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# === CASKP configuration ===
CONFIG = {
    "REPO_ID": "KangjieXu/CASKP-model",
    "MODEL_FILENAME": "caskp_final_model.pt",
    "ESM_MODEL_NAME": "facebook/esm2_t33_650M_UR50D",
    # Must match the training-time dimensionality; 1 if a single scalar feature.
    "STRUCT_DIM": 1,
    "D_MODEL": 256,
    "D_MULTISCALE": 128,
    "NUM_HEADS": 8,
    "USE_AMSFF": True,
    # Usage-log dataset repo (may be shared or a dedicated CASKP-usage-logs repo).
    "DATASET_REPO_ID": "KangjieXu/CASKP-usage-logs"
}

HF_TOKEN = os.getenv("HF_TOKEN")

# Hub API client for log uploads; stays None when no token is configured or
# client construction fails (logging is then silently disabled).
api = None
if HF_TOKEN:
    try:
        api = HfApi(token=HF_TOKEN)
        print("Hugging Face Hub API 客户端初始化成功。")
    except Exception as e:
        print(f"HF API 初始化失败: {e}")
# --- 2. Model loading (runs once at import time) ---
# Download the checkpoint from the Hub; abort startup if it cannot be fetched.
print(f"正在从Hub下载权重: {CONFIG['REPO_ID']}...")
try:
    model_path = hf_hub_download(repo_id=CONFIG['REPO_ID'], filename=CONFIG['MODEL_FILENAME'])
except Exception as e:
    raise RuntimeError(f"从Hub下载模型权重失败: {e}")
# Build the network skeleton from CONFIG; dropout is disabled for inference.
print("正在初始化模型结构...")
model = FullKcatPredictor(
    esm_model_name=CONFIG["ESM_MODEL_NAME"],
    struct_dim=CONFIG["STRUCT_DIM"],
    d_model=CONFIG["D_MODEL"],
    d_multiscale=CONFIG["D_MULTISCALE"],
    num_heads=CONFIG["NUM_HEADS"],
    dropout=0.0,
    use_amsff=CONFIG["USE_AMSFF"]
).to(DEVICE)
print(f"正在加载模型权重...")
# NOTE(review): torch.load without weights_only=True unpickles arbitrary objects;
# acceptable only because the checkpoint comes from the project's own repo.
state_dict = torch.load(model_path, map_location=DEVICE)
# strict=False tolerates missing/unexpected keys — presumably intentional for
# partial checkpoints, but it can silently mask weight mismatches; TODO confirm.
model.load_state_dict(state_dict, strict=False)
model.eval()  # inference mode: fixes dropout/batch-norm behavior
print(f"正在加载 Tokenizer...")
tokenizer = EsmTokenizer.from_pretrained(CONFIG["ESM_MODEL_NAME"])
print(f"CASKP 模型加载成功,运行在 {DEVICE} 设备上。")
# --- 3. 辅助函数 ---
def clean_protein_sequence(sequence):
    """Normalize a raw protein sequence cell into a bare residue string.

    Strips FASTA header lines (anything from '>' to end of line) and removes
    all whitespace. Missing values (None/NaN) yield "".
    """
    if pd.isna(sequence):
        return ""
    # '.' does not cross '\n', so r'>.*' strips every header line — including a
    # final header with no trailing newline, which the old r'>.*\n' missed.
    sequence = re.sub(r'>.*', '', str(sequence))
    return "".join(sequence.split())
def parse_structure_features(feat_str, target_dim):
    """Parse a comma-separated feature string into a fixed-length float list.

    Args:
        feat_str: raw cell value; comma-separated numbers, or NaN/None.
        target_dim: required output length.

    Returns:
        list[float] of exactly `target_dim` values: truncated if too long,
        zero-padded if too short, all zeros when missing or unparseable.
    """
    try:
        if pd.isna(feat_str):
            vals = []
        else:
            # Skip empty tokens (e.g. a trailing comma in "1.0,") instead of
            # letting float('') raise and zero out the whole row.
            vals = [float(tok) for tok in str(feat_str).split(',') if tok.strip()]
    except ValueError:
        vals = []  # unparseable content -> zero vector
    # Align to the required dimensionality.
    if len(vals) > target_dim:
        vals = vals[:target_dim]
    elif len(vals) < target_dim:
        vals += [0.0] * (target_dim - len(vals))
    return vals
# --- 4. Flask 路由 ---
@app.route('/')
def home():
    """Serve the front-end upload page."""
    template_name = 'index.html'
    return render_template(template_name)
@app.route('/predict', methods=['POST'])
def predict():
    """Batch kcat prediction endpoint.

    Expects a multipart form with:
      - 'data_file': a .csv/.xls/.xlsx table containing at least the columns
        'protein_sequence' and 'structure_features'
      - 'consent_given' (optional): 'true' enables usage logging to the Hub.

    Returns the uploaded table as a CSV attachment with two added columns
    ('predicted_log10_kcat', 'predicted_kcat_s_neg1'), or a JSON error with
    status 400 (bad input) / 500 (internal failure).
    """
    if 'data_file' not in request.files:
        return jsonify({'error': '请求中未找到文件部分。'}), 400
    file = request.files['data_file']
    if file.filename == '':
        return jsonify({'error': '未选择文件。'}), 400
    try:
        filename = secure_filename(file.filename)
        if filename.endswith('.csv'):
            df_full = pd.read_csv(file)
        elif filename.endswith(('.xls', '.xlsx')):
            df_full = pd.read_excel(file)
        else:
            return jsonify({'error': '不支持的文件类型。请上传 .csv 或 .xlsx 文件。'}), 400
        # CASKP's required columns differ from the SMILES-based variant:
        # it needs 'structure_features' rather than 'substrate_smiles'.
        required_columns = {'protein_sequence', 'structure_features'}
        if not required_columns.issubset(df_full.columns):
            return jsonify({'error': f'文件缺少必需的列。请包含: {list(required_columns)}'}), 400
        # Preprocessing: strip FASTA headers/whitespace from each sequence.
        protein_seqs = df_full['protein_sequence'].apply(clean_protein_sequence).tolist()
        # Parse the structure-feature column (comma string -> fixed-dim list).
        struct_features_list = df_full['structure_features'].apply(
            lambda x: parse_structure_features(x, CONFIG['STRUCT_DIM'])
        ).tolist()
        # Convert to a (batch, STRUCT_DIM) float tensor on the model device.
        struct_tensor = torch.tensor(struct_features_list, dtype=torch.float).to(DEVICE)
        # Tokenize; sequences beyond 512 tokens are truncated.
        inputs = tokenizer(protein_seqs, return_tensors="pt", padding=True, truncation=True, max_length=512).to(DEVICE)
        # Inference
        with torch.no_grad():
            # The CASKP model takes three inputs: token ids, mask, structure features.
            log_kcat_preds = model(inputs['input_ids'], inputs['attention_mask'], struct_tensor)
        # Flatten to a plain 1-D Python list whether the model returned (B,)
        # or (B, 1). NOTE(review): raw_preds[0] would raise IndexError on an
        # empty (zero-row) upload — consider guarding upstream.
        raw_preds = log_kcat_preds.cpu().numpy().tolist()
        if isinstance(raw_preds[0], list):
            log_kcat_list = [item for sublist in raw_preds for item in sublist]
        else:
            log_kcat_list = raw_preds
        # Model output is treated as log10(kcat); invert to kcat in s^-1.
        kcat_list = [10**val for val in log_kcat_list]
        df_full['predicted_log10_kcat'] = [f"{v:.4f}" for v in log_kcat_list]
        df_full['predicted_kcat_s_neg1'] = [f"{v:.4f}" for v in kcat_list]
        # --- Usage logging (opt-in; best effort, never fails the request) ---
        consent_given = request.form.get('consent_given') == 'true'
        if api and consent_given:
            try:
                # NOTE(review): datetime.utcnow() is deprecated since Python
                # 3.12; datetime.now(timezone.utc) is the modern equivalent.
                log_data = {
                    'timestamp_utc': [datetime.utcnow().isoformat()] * len(protein_seqs),
                    'protein_sequence': protein_seqs,
                    'structure_features': [str(x) for x in struct_features_list],  # parsed (post-cleaning) features
                    'predicted_log10_kcat': log_kcat_list
                }
                log_df = pd.DataFrame(log_data)
                log_buffer = io.StringIO()
                log_df.to_csv(log_buffer, index=False)
                log_bytes = log_buffer.getvalue().encode("utf-8")
                # Microsecond-resolution timestamp keeps filenames unique per request.
                timestamp_str = datetime.utcnow().strftime('%Y-%m-%d_%H-%M-%S-%f')
                log_filename_in_repo = f"logs/caskp_log_{timestamp_str}.csv"
                api.upload_file(
                    path_or_fileobj=log_bytes,
                    path_in_repo=log_filename_in_repo,
                    repo_id=CONFIG["DATASET_REPO_ID"],
                    repo_type="dataset",
                    commit_message=f"CASKP Log {timestamp_str}"
                )
                print(f"日志上传成功: {log_filename_in_repo}")
            except Exception as e:
                # Logging failures are deliberately non-fatal to the prediction.
                print(f"【警告】日志记录失败: {e}")
        # Stream the augmented table back as a downloadable CSV.
        output_buffer = io.BytesIO()
        output_filename = f"caskp_predictions_{os.path.splitext(filename)[0]}.csv"
        df_full.to_csv(output_buffer, index=False)
        output_buffer.seek(0)
        return Response(
            output_buffer.getvalue(),
            mimetype="text/csv",
            headers={"Content-Disposition": f"attachment; filename=\"{output_filename}\""}
        )
    except Exception as e:
        # Broad catch at the route boundary: log the traceback, report a 500.
        import traceback
        traceback.print_exc()
        return jsonify({'error': f'服务器内部错误: {str(e)}'}), 500
if __name__ == '__main__':
    # Bind on all interfaces; 7860 is the conventional Hugging Face Spaces port.
    app.run(host='0.0.0.0', port=7860)