import io
import os
import re
from datetime import datetime, timezone

import pandas as pd
import torch
from flask import Flask, request, jsonify, render_template, Response
from huggingface_hub import HfApi, hf_hub_download
from transformers import EsmTokenizer
from werkzeug.utils import secure_filename
|
|
| |
| from model import FullKcatPredictor |
|
|
| |
# Flask application and compute device selection.
app = Flask(__name__)
# Prefer GPU when available; the model and all input tensors go to this device.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Static configuration: Hub repo locations plus the architecture
# hyperparameters, which must match the trained checkpoint exactly.
CONFIG = {
    "REPO_ID": "KangjieXu/CASKP-model",               # weights repo on the HF Hub
    "MODEL_FILENAME": "caskp_final_model.pt",
    "ESM_MODEL_NAME": "facebook/esm2_t33_650M_UR50D",  # ESM-2 backbone + tokenizer

    # Model hyperparameters (see FullKcatPredictor constructor below).
    "STRUCT_DIM": 1,
    "D_MODEL": 256,
    "D_MULTISCALE": 128,
    "NUM_HEADS": 8,
    "USE_AMSFF": True,

    # Dataset repo that receives opt-in usage logs from /predict.
    "DATASET_REPO_ID": "KangjieXu/CASKP-usage-logs"
}

# Hub token comes from the environment; without it, usage logging is disabled.
HF_TOKEN = os.getenv("HF_TOKEN")

# Best-effort HF API client initialization: `api` is left as None when no
# token is configured or the client cannot be created, and the /predict
# handler silently skips logging in that case.
if HF_TOKEN:
    try:
        api = HfApi(token=HF_TOKEN)
        print("Hugging Face Hub API 客户端初始化成功。")
    except Exception as e:
        print(f"HF API 初始化失败: {e}")
        api = None
else:
    api = None
|
|
| |
# --- Model bootstrap (runs once at import time) ---
# Download the trained checkpoint from the Hugging Face Hub; abort startup
# with a clear error if the download fails.
print(f"正在从Hub下载权重: {CONFIG['REPO_ID']}...")
try:
    model_path = hf_hub_download(repo_id=CONFIG['REPO_ID'], filename=CONFIG['MODEL_FILENAME'])
except Exception as e:
    raise RuntimeError(f"从Hub下载模型权重失败: {e}")

# Build the network with the same hyperparameters used in training;
# dropout is 0.0 because this server only runs inference.
print("正在初始化模型结构...")
model = FullKcatPredictor(
    esm_model_name=CONFIG["ESM_MODEL_NAME"],
    struct_dim=CONFIG["STRUCT_DIM"],
    d_model=CONFIG["D_MODEL"],
    d_multiscale=CONFIG["D_MULTISCALE"],
    num_heads=CONFIG["NUM_HEADS"],
    dropout=0.0,
    use_amsff=CONFIG["USE_AMSFF"]
).to(DEVICE)

# strict=False tolerates missing/unexpected keys — presumably because the
# checkpoint omits the frozen ESM backbone weights. NOTE(review): confirm
# nothing essential is silently left uninitialized.
print(f"正在加载模型权重...")
state_dict = torch.load(model_path, map_location=DEVICE)
model.load_state_dict(state_dict, strict=False)
model.eval()

# The tokenizer must match the ESM model the predictor was trained with.
print(f"正在加载 Tokenizer...")
tokenizer = EsmTokenizer.from_pretrained(CONFIG["ESM_MODEL_NAME"])
print(f"CASKP 模型加载成功,运行在 {DEVICE} 设备上。")
|
|
| |
def clean_protein_sequence(sequence):
    """Normalize a raw protein-sequence cell into a bare residue string.

    Strips FASTA header lines (everything from '>' to end of line) and all
    whitespace between residues. NaN/None inputs yield the empty string.
    """
    if pd.isna(sequence):
        return ""
    # '>' to end-of-line: unlike the previous r'>.*\n', this also strips a
    # header that sits on the last line without a trailing newline.
    sequence = re.sub(r'>.*', '', str(sequence))
    # str.split() with no argument drops all whitespace (spaces, tabs, newlines).
    return "".join(sequence.split())
|
|
def parse_structure_features(feat_str, target_dim):
    """Parse a comma-separated feature string into exactly `target_dim` floats.

    Missing (NaN) or unparseable input yields an all-zero vector; short
    vectors are zero-padded on the right and long ones are truncated.
    """
    if pd.isna(feat_str):
        return [0.0] * target_dim

    try:
        parsed = [float(token) for token in str(feat_str).split(',')]
    except ValueError:
        # Any non-numeric token invalidates the whole vector.
        return [0.0] * target_dim

    # Pad with zeros up to target_dim, then clip to exactly target_dim.
    padded = parsed + [0.0] * max(0, target_dim - len(parsed))
    return padded[:target_dim]
|
|
| |
@app.route('/')
def home():
    """Serve the single-page front end (templates/index.html)."""
    return render_template('index.html')
|
|
@app.route('/predict', methods=['POST'])
def predict():
    """Batch kcat prediction endpoint.

    Expects a multipart upload under 'data_file': a CSV/Excel table with
    'protein_sequence' and 'structure_features' columns. Returns the same
    table with predicted log10(kcat) and kcat columns appended, as a CSV
    attachment. When the user consented and a Hub client exists, the
    request is logged (best-effort) to the configured dataset repo.
    """
    if 'data_file' not in request.files:
        return jsonify({'error': '请求中未找到文件部分。'}), 400

    file = request.files['data_file']
    if file.filename == '':
        return jsonify({'error': '未选择文件。'}), 400

    try:
        filename = secure_filename(file.filename)
        # Case-insensitive extension check (previously '.CSV'/'.XLSX' were rejected).
        ext = os.path.splitext(filename)[1].lower()
        if ext == '.csv':
            df_full = pd.read_csv(file)
        elif ext in ('.xls', '.xlsx'):
            df_full = pd.read_excel(file)
        else:
            return jsonify({'error': '不支持的文件类型。请上传 .csv 或 .xlsx 文件。'}), 400

        required_columns = {'protein_sequence', 'structure_features'}
        if not required_columns.issubset(df_full.columns):
            return jsonify({'error': f'文件缺少必需的列。请包含: {list(required_columns)}'}), 400

        # Guard: an empty table would crash the tokenizer / flatten steps below.
        if df_full.empty:
            return jsonify({'error': '文件不包含任何数据行。'}), 400

        protein_seqs = df_full['protein_sequence'].apply(clean_protein_sequence).tolist()
        struct_features_list = df_full['structure_features'].apply(
            lambda x: parse_structure_features(x, CONFIG['STRUCT_DIM'])
        ).tolist()

        struct_tensor = torch.tensor(struct_features_list, dtype=torch.float).to(DEVICE)
        # Sequences longer than 512 tokens are truncated to fit the encoder.
        inputs = tokenizer(protein_seqs, return_tensors="pt", padding=True,
                           truncation=True, max_length=512).to(DEVICE)

        with torch.no_grad():
            log_kcat_preds = model(inputs['input_ids'], inputs['attention_mask'], struct_tensor)

        # Flatten regardless of output rank ((B,), (B, 1), ...) into floats.
        log_kcat_list = log_kcat_preds.cpu().numpy().reshape(-1).tolist()

        # The model predicts log10(kcat); convert back to kcat in s^-1.
        kcat_list = [10 ** val for val in log_kcat_list]

        df_full['predicted_log10_kcat'] = [f"{v:.4f}" for v in log_kcat_list]
        df_full['predicted_kcat_s_neg1'] = [f"{v:.4f}" for v in kcat_list]

        # Opt-in usage logging; failures must never break the response.
        consent_given = request.form.get('consent_given') == 'true'
        if api and consent_given:
            try:
                # Single tz-aware timestamp (utcnow() is deprecated; the two
                # separate calls it replaced could also disagree slightly).
                now_utc = datetime.now(timezone.utc)
                log_df = pd.DataFrame({
                    'timestamp_utc': [now_utc.isoformat()] * len(protein_seqs),
                    'protein_sequence': protein_seqs,
                    'structure_features': [str(x) for x in struct_features_list],
                    'predicted_log10_kcat': log_kcat_list
                })

                log_buffer = io.StringIO()
                log_df.to_csv(log_buffer, index=False)
                log_bytes = log_buffer.getvalue().encode("utf-8")

                timestamp_str = now_utc.strftime('%Y-%m-%d_%H-%M-%S-%f')
                log_filename_in_repo = f"logs/caskp_log_{timestamp_str}.csv"

                api.upload_file(
                    path_or_fileobj=log_bytes,
                    path_in_repo=log_filename_in_repo,
                    repo_id=CONFIG["DATASET_REPO_ID"],
                    repo_type="dataset",
                    commit_message=f"CASKP Log {timestamp_str}"
                )
                print(f"日志上传成功: {log_filename_in_repo}")
            except Exception as e:
                print(f"【警告】日志记录失败: {e}")

        # Stream the augmented table back as a downloadable CSV attachment.
        output_buffer = io.BytesIO()
        output_filename = f"caskp_predictions_{os.path.splitext(filename)[0]}.csv"
        df_full.to_csv(output_buffer, index=False)
        output_buffer.seek(0)

        return Response(
            output_buffer.getvalue(),
            mimetype="text/csv",
            headers={"Content-Disposition": f"attachment; filename=\"{output_filename}\""}
        )

    except Exception as e:
        import traceback
        traceback.print_exc()
        return jsonify({'error': f'服务器内部错误: {str(e)}'}), 500
|
|
if __name__ == '__main__':
    # Listen on all interfaces; port 7860 is the Hugging Face Spaces default —
    # presumably this is deployed as a Space.
    app.run(host='0.0.0.0', port=7860)