Spaces:
Runtime error
Runtime error
Upload app.py
#3
by YichaoLiu - opened
app.py
ADDED
|
@@ -0,0 +1,706 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
from fastapi import FastAPI, HTTPException
|
| 3 |
+
from starlette.staticfiles import StaticFiles
|
| 4 |
+
import uvicorn
|
| 5 |
+
import logging
|
| 6 |
+
from pydantic import BaseModel
|
| 7 |
+
import pandas as pd
|
| 8 |
+
import time
|
| 9 |
+
import requests
|
| 10 |
+
import json
|
| 11 |
+
from typing import List, Dict, Any, Optional, Tuple
|
| 12 |
+
|
| 13 |
+
# Set up logging configuration
|
| 14 |
+
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
|
| 15 |
+
logger = logging.getLogger(__name__)
|
| 16 |
+
|
| 17 |
+
# API configurations
|
| 18 |
+
API_BASE_URL = "https://songyou-llm-fastapi.hf.space"
|
| 19 |
+
FRAGMENT_ENDPOINT = f"{API_BASE_URL}/fragmentize"
|
| 20 |
+
GENERATE_ENDPOINT = f"{API_BASE_URL}/generate"
|
| 21 |
+
|
| 22 |
+
# Load parameters from configuration file
|
| 23 |
+
try:
|
| 24 |
+
with open('param.json', 'r') as f:
|
| 25 |
+
params = json.load(f)
|
| 26 |
+
logger.info("Successfully loaded parameter configuration")
|
| 27 |
+
except Exception as e:
|
| 28 |
+
logger.error(f"Error loading parameter configuration: {str(e)}")
|
| 29 |
+
raise
|
| 30 |
+
|
| 31 |
+
# Data models
|
| 32 |
+
class SmilesData(BaseModel):
|
| 33 |
+
"""Model for SMILES data received from frontend"""
|
| 34 |
+
smiles: str
|
| 35 |
+
|
| 36 |
+
class GenerateRequest(BaseModel):
|
| 37 |
+
"""Request model for generate endpoint with updated fields"""
|
| 38 |
+
constSmiles: str
|
| 39 |
+
varSmiles: str
|
| 40 |
+
mainCls: str
|
| 41 |
+
minorCls: str
|
| 42 |
+
deltaValue: str
|
| 43 |
+
targetName: str = "target1" # default value
|
| 44 |
+
num: int
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
# Helper functions for metric handling
|
| 55 |
+
def get_metrics_for_objective(objective: str) -> List[str]:
|
| 56 |
+
"""Get the corresponding metrics for a given objective"""
|
| 57 |
+
if objective == "None" or objective not in params["Metrics"]:
|
| 58 |
+
return ["None"]
|
| 59 |
+
return ["None"] + params["Metrics"].get(objective, [])
|
| 60 |
+
|
| 61 |
+
def get_metric_full_name(objective: str, metric: str) -> str:
|
| 62 |
+
"""
|
| 63 |
+
Constructs the full metric name based on objective and metric.
|
| 64 |
+
For general physical properties, returns just the metric name.
|
| 65 |
+
For others, returns the metric name as is.
|
| 66 |
+
"""
|
| 67 |
+
if objective == "general physical properties":
|
| 68 |
+
return metric
|
| 69 |
+
return f"{metric}"
|
| 70 |
+
|
| 71 |
+
def get_metric_type(metric_name: str) -> str:
|
| 72 |
+
"""
|
| 73 |
+
Determines if a metric is boolean or sequential based on the BoolOrSeq mapping.
|
| 74 |
+
Returns 'bool', 'seq', or '' if not found.
|
| 75 |
+
"""
|
| 76 |
+
metric_type = params["BoolOrSeq"].get(metric_name, "")
|
| 77 |
+
logger.debug(f"Metric type for {metric_name}: {metric_type}")
|
| 78 |
+
return metric_type
|
| 79 |
+
|
| 80 |
+
def get_delta_choices(metric_type: str) -> List[str]:
|
| 81 |
+
"""Returns the appropriate choices for delta value based on metric type."""
|
| 82 |
+
if metric_type == "bool":
|
| 83 |
+
return params["ImprovementAnticipationBool"]
|
| 84 |
+
elif metric_type == "seq":
|
| 85 |
+
return params["ImprovementAnticipationSeq"]
|
| 86 |
+
return []
|
| 87 |
+
|
| 88 |
+
def validate_metric_combination(objective: str, metric: str) -> bool:
|
| 89 |
+
"""
|
| 90 |
+
Validates if the objective-metric combination is valid.
|
| 91 |
+
Returns True if valid, False otherwise.
|
| 92 |
+
"""
|
| 93 |
+
if objective == "None" or metric == "None":
|
| 94 |
+
logger.debug(f"Invalid objective or metric: {objective} - {metric}")
|
| 95 |
+
return False
|
| 96 |
+
if objective not in params["Metrics"]:
|
| 97 |
+
logger.debug(f"Objective not found in metrics: {objective}")
|
| 98 |
+
return False
|
| 99 |
+
if metric not in params["Metrics"].get(objective, []):
|
| 100 |
+
logger.debug(f"Metric not found in objective: {metric}")
|
| 101 |
+
return False
|
| 102 |
+
logger.debug(f"Valid metric combination: {objective} - {metric}")
|
| 103 |
+
return True
|
| 104 |
+
|
| 105 |
+
def handle_generate_analogs(
|
| 106 |
+
main_cls: str,
|
| 107 |
+
minor_cls: str,
|
| 108 |
+
number: int,
|
| 109 |
+
bool_delta_val: str,
|
| 110 |
+
seq_delta_val: str,
|
| 111 |
+
const_smiles: str,
|
| 112 |
+
var_smiles: str,
|
| 113 |
+
metric_type: str
|
| 114 |
+
) -> pd.DataFrame:
|
| 115 |
+
"""
|
| 116 |
+
Handles the generation of analogs with appropriate delta value selection and error handling.
|
| 117 |
+
This function serves as the bridge between the UI and the generate_analogs API call.
|
| 118 |
+
|
| 119 |
+
Args:
|
| 120 |
+
main_cls (str): The main objective classification
|
| 121 |
+
minor_cls (str): The specific metric
|
| 122 |
+
number (int): Number of analogs to generate
|
| 123 |
+
bool_delta_val (str): Selected delta value for boolean metrics
|
| 124 |
+
seq_delta_val (str): Selected delta value for sequential metrics
|
| 125 |
+
const_smiles (str): Constant fragment SMILES
|
| 126 |
+
var_smiles (str): Variable fragment SMILES
|
| 127 |
+
metric_type (str): Type of metric ('bool' or 'seq')
|
| 128 |
+
|
| 129 |
+
Returns:
|
| 130 |
+
pd.DataFrame: DataFrame containing the generated analogs and their properties
|
| 131 |
+
"""
|
| 132 |
+
try:
|
| 133 |
+
# Input validation
|
| 134 |
+
if not all([main_cls, minor_cls, const_smiles, var_smiles]):
|
| 135 |
+
logger.error("Missing required inputs")
|
| 136 |
+
return pd.DataFrame()
|
| 137 |
+
|
| 138 |
+
if not validate_metric_combination(main_cls, minor_cls):
|
| 139 |
+
logger.error(f"Invalid metric combination: {main_cls} - {minor_cls}")
|
| 140 |
+
return pd.DataFrame()
|
| 141 |
+
|
| 142 |
+
# Select appropriate delta value based on metric type
|
| 143 |
+
if metric_type not in ["bool", "seq"]:
|
| 144 |
+
logger.error(f"Invalid metric type: {metric_type}")
|
| 145 |
+
return pd.DataFrame()
|
| 146 |
+
|
| 147 |
+
delta_value = bool_delta_val if metric_type == "bool" else seq_delta_val
|
| 148 |
+
|
| 149 |
+
# Generate analogs using the API
|
| 150 |
+
analogs_data = generate_analogs(
|
| 151 |
+
main_cls=main_cls,
|
| 152 |
+
minor_cls=minor_cls,
|
| 153 |
+
number=number,
|
| 154 |
+
delta_value=delta_value,
|
| 155 |
+
const_smiles=const_smiles,
|
| 156 |
+
var_smiles=var_smiles
|
| 157 |
+
)
|
| 158 |
+
|
| 159 |
+
if not analogs_data:
|
| 160 |
+
logger.warning("No analogs generated")
|
| 161 |
+
return pd.DataFrame()
|
| 162 |
+
|
| 163 |
+
return update_output_table(analogs_data)
|
| 164 |
+
|
| 165 |
+
except Exception as e:
|
| 166 |
+
logger.error(f"Error in handle_generate_analogs: {str(e)}")
|
| 167 |
+
return pd.DataFrame()
|
| 168 |
+
|
| 169 |
+
# Update the fragment_molecule function to handle the new response format
|
| 170 |
+
def fragment_molecule(smiles: str) -> Tuple[str, str, str]:
|
| 171 |
+
"""
|
| 172 |
+
Call the fragment API endpoint to get molecule fragments
|
| 173 |
+
Returns: List of fragments with their details
|
| 174 |
+
"""
|
| 175 |
+
try:
|
| 176 |
+
logger.info(f"Calling fragment API with SMILES: {smiles}")
|
| 177 |
+
response = requests.get(f"{FRAGMENT_ENDPOINT}?smiles={smiles}")
|
| 178 |
+
response.raise_for_status()
|
| 179 |
+
data = response.json()
|
| 180 |
+
logger.info(f"Fragment API response: {data}")
|
| 181 |
+
|
| 182 |
+
# Return empty values if no fragments found
|
| 183 |
+
if not data.get("fragments"):
|
| 184 |
+
return "", "", ""
|
| 185 |
+
|
| 186 |
+
# Return the first fragment by default
|
| 187 |
+
first_fragment = data["fragments"][0]
|
| 188 |
+
return (
|
| 189 |
+
first_fragment.get("constant_smiles", ""),
|
| 190 |
+
first_fragment.get("variable_smiles", ""),
|
| 191 |
+
str(first_fragment.get("attachment_order", ""))
|
| 192 |
+
)
|
| 193 |
+
except Exception as e:
|
| 194 |
+
logger.error(f"Fragment API call failed: {str(e)}")
|
| 195 |
+
return "", "", ""
|
| 196 |
+
|
| 197 |
+
def generate_analogs(
|
| 198 |
+
main_cls: str,
|
| 199 |
+
minor_cls: str,
|
| 200 |
+
number: int,
|
| 201 |
+
delta_value: str,
|
| 202 |
+
const_smiles: str,
|
| 203 |
+
var_smiles: str
|
| 204 |
+
) -> List[Dict[str, Any]]:
|
| 205 |
+
"""
|
| 206 |
+
Generate molecule analogs using the generate API endpoint with improved error handling
|
| 207 |
+
and validation.
|
| 208 |
+
"""
|
| 209 |
+
try:
|
| 210 |
+
# Validate inputs
|
| 211 |
+
if not all([const_smiles, var_smiles, main_cls, minor_cls, delta_value]):
|
| 212 |
+
logger.error("Missing required inputs for generate_analogs")
|
| 213 |
+
return []
|
| 214 |
+
|
| 215 |
+
# Create API request
|
| 216 |
+
payload = GenerateRequest(
|
| 217 |
+
constSmiles=const_smiles,
|
| 218 |
+
varSmiles=var_smiles,
|
| 219 |
+
mainCls=main_cls if main_cls != "None" else "",
|
| 220 |
+
minorCls=minor_cls if minor_cls != "None" else "",
|
| 221 |
+
deltaValue=delta_value,
|
| 222 |
+
num=int(number)
|
| 223 |
+
)
|
| 224 |
+
|
| 225 |
+
logger.info(f"Calling generate API with payload: {payload.dict()}")
|
| 226 |
+
|
| 227 |
+
# Make API request
|
| 228 |
+
response = requests.post(
|
| 229 |
+
GENERATE_ENDPOINT,
|
| 230 |
+
headers={'Content-Type': 'application/json'},
|
| 231 |
+
json=payload.dict(),
|
| 232 |
+
timeout=30
|
| 233 |
+
)
|
| 234 |
+
|
| 235 |
+
response.raise_for_status()
|
| 236 |
+
results = response.json()
|
| 237 |
+
|
| 238 |
+
if not isinstance(results, list):
|
| 239 |
+
logger.error(f"Unexpected response format: {results}")
|
| 240 |
+
return []
|
| 241 |
+
|
| 242 |
+
logger.info(f"Successfully generated {len(results)} analogs")
|
| 243 |
+
return results
|
| 244 |
+
|
| 245 |
+
except requests.exceptions.Timeout:
|
| 246 |
+
logger.error("Generate API request timed out")
|
| 247 |
+
return []
|
| 248 |
+
except requests.exceptions.RequestException as e:
|
| 249 |
+
logger.error(f"Generate API request failed: {str(e)}")
|
| 250 |
+
return []
|
| 251 |
+
except Exception as e:
|
| 252 |
+
logger.error(f"Unexpected error in generate_analogs: {str(e)}")
|
| 253 |
+
return []
|
| 254 |
+
|
| 255 |
+
def update_output_table(data: List[Dict[str, Any]]) -> pd.DataFrame:
|
| 256 |
+
"""Convert API response data to pandas DataFrame for display"""
|
| 257 |
+
try:
|
| 258 |
+
df = pd.DataFrame(data)
|
| 259 |
+
return df
|
| 260 |
+
except Exception as e:
|
| 261 |
+
logger.error(f"Error creating DataFrame: {str(e)}")
|
| 262 |
+
return pd.DataFrame()
|
| 263 |
+
|
| 264 |
+
def save_to_csv(data: pd.DataFrame, selected_only: bool = False) -> Optional[str]:
|
| 265 |
+
"""Save data to CSV file"""
|
| 266 |
+
try:
|
| 267 |
+
filename = f"molecule_analogs_{int(time.time())}.csv"
|
| 268 |
+
data.to_csv(filename, index=False)
|
| 269 |
+
return filename
|
| 270 |
+
except Exception as e:
|
| 271 |
+
logger.error(f"Error saving to CSV: {str(e)}")
|
| 272 |
+
return None
|
| 273 |
+
|
| 274 |
+
# FastAPI app initialization
|
| 275 |
+
app = FastAPI()
|
| 276 |
+
|
| 277 |
+
# Mount Ketcher static files
|
| 278 |
+
app.mount("/ketcher", StaticFiles(directory="ketcher"), name="ketcher")
|
| 279 |
+
|
| 280 |
+
@app.post("/update_smiles")
|
| 281 |
+
async def update_smiles(data: SmilesData):
|
| 282 |
+
"""Endpoint to receive SMILES data from frontend"""
|
| 283 |
+
try:
|
| 284 |
+
logger.info(f"Received SMILES from front-end: {data.smiles}")
|
| 285 |
+
return {"status": "ok", "received_smiles": data.smiles}
|
| 286 |
+
except Exception as e:
|
| 287 |
+
logger.error(f"Error processing SMILES update: {str(e)}")
|
| 288 |
+
raise HTTPException(status_code=500, detail=str(e))
|
| 289 |
+
|
| 290 |
+
# Ketcher interface HTML template
|
| 291 |
+
KETCHER_HTML = r'''
|
| 292 |
+
<iframe id="ifKetcher" src="/ketcher/index.html" width="100%" height="600px" style="border: 1px solid #ccc;"></iframe>
|
| 293 |
+
|
| 294 |
+
<script>
|
| 295 |
+
console.log("[Front-end] Ketcher-Gradio integration script loaded.");
|
| 296 |
+
|
| 297 |
+
let ketcher = null;
|
| 298 |
+
let lastSmiles = '';
|
| 299 |
+
|
| 300 |
+
function findSmilesInput() {
|
| 301 |
+
const inputContainer = document.getElementById('combined_smiles_input');
|
| 302 |
+
if (!inputContainer) {
|
| 303 |
+
console.warn("[Front-end] combined_smiles_input element not found.");
|
| 304 |
+
return null;
|
| 305 |
+
}
|
| 306 |
+
const input = inputContainer.querySelector('input[type="text"]');
|
| 307 |
+
return input;
|
| 308 |
+
}
|
| 309 |
+
|
| 310 |
+
function updateGradioInput(smiles) {
|
| 311 |
+
const input = findSmilesInput();
|
| 312 |
+
if (input && input.value !== smiles) {
|
| 313 |
+
input.value = smiles;
|
| 314 |
+
input.dispatchEvent(new Event('input', { bubbles: true }));
|
| 315 |
+
console.log("[Front-end] Updated Gradio input with SMILES:", smiles);
|
| 316 |
+
}
|
| 317 |
+
}
|
| 318 |
+
|
| 319 |
+
async function handleKetcherChange() {
|
| 320 |
+
console.log("[Front-end] handleKetcherChange called, retrieving SMILES...");
|
| 321 |
+
try {
|
| 322 |
+
const smiles = await ketcher.getSmiles({ arom: false });
|
| 323 |
+
console.log("[Front-end] SMILES retrieved from Ketcher:", smiles);
|
| 324 |
+
if (smiles !== lastSmiles) {
|
| 325 |
+
lastSmiles = smiles;
|
| 326 |
+
updateGradioInput(smiles);
|
| 327 |
+
|
| 328 |
+
fetch('/update_smiles', {
|
| 329 |
+
method: 'POST',
|
| 330 |
+
headers: {'Content-Type': 'application/json'},
|
| 331 |
+
body: JSON.stringify({smiles: smiles})
|
| 332 |
+
})
|
| 333 |
+
.then(res => res.json())
|
| 334 |
+
.then(data => {
|
| 335 |
+
console.log("[Front-end] Backend response:", data);
|
| 336 |
+
})
|
| 337 |
+
.catch(err => console.error("[Front-end] Error sending SMILES to backend:", err));
|
| 338 |
+
}
|
| 339 |
+
} catch (err) {
|
| 340 |
+
console.error("[Front-end] Error getting SMILES from Ketcher:", err);
|
| 341 |
+
}
|
| 342 |
+
}
|
| 343 |
+
|
| 344 |
+
function initKetcher() {
|
| 345 |
+
console.log("[Front-end] initKetcher started.");
|
| 346 |
+
const iframe = document.getElementById('ifKetcher');
|
| 347 |
+
if (!iframe) {
|
| 348 |
+
console.error("[Front-end] iframe not found.");
|
| 349 |
+
setTimeout(initKetcher, 500);
|
| 350 |
+
return;
|
| 351 |
+
}
|
| 352 |
+
|
| 353 |
+
const ketcherWindow = iframe.contentWindow;
|
| 354 |
+
if (!ketcherWindow || !ketcherWindow.ketcher) {
|
| 355 |
+
console.log("[Front-end] ketcher not yet available in iframe, retrying...");
|
| 356 |
+
setTimeout(initKetcher, 500);
|
| 357 |
+
return;
|
| 358 |
+
}
|
| 359 |
+
|
| 360 |
+
ketcher = ketcherWindow.ketcher;
|
| 361 |
+
console.log("[Front-end] Ketcher instance acquired:", ketcher);
|
| 362 |
+
|
| 363 |
+
ketcher.setMolecule('C').then(() => {
|
| 364 |
+
console.log("[Front-end] Initial molecule set to 'C'.");
|
| 365 |
+
});
|
| 366 |
+
|
| 367 |
+
const editor = ketcher.editor;
|
| 368 |
+
console.log("[Front-end] Editor object:", editor);
|
| 369 |
+
|
| 370 |
+
let eventBound = false;
|
| 371 |
+
if (editor && typeof editor.subscribe === 'function') {
|
| 372 |
+
console.log("[Front-end] Using editor.subscribe('change', ...)");
|
| 373 |
+
editor.subscribe('change', handleKetcherChange);
|
| 374 |
+
eventBound = true;
|
| 375 |
+
}
|
| 376 |
+
|
| 377 |
+
if (!eventBound) {
|
| 378 |
+
console.error("[Front-end] No suitable event binding found. Check Ketcher version and event API.");
|
| 379 |
+
}
|
| 380 |
+
}
|
| 381 |
+
|
| 382 |
+
document.getElementById('ifKetcher').addEventListener('load', () => {
|
| 383 |
+
console.log("[Front-end] iframe loaded. Initializing Ketcher in 1s...");
|
| 384 |
+
setTimeout(initKetcher, 1000);
|
| 385 |
+
});
|
| 386 |
+
</script>
|
| 387 |
+
'''
|
| 388 |
+
|
| 389 |
+
def create_combined_interface():
|
| 390 |
+
"""
|
| 391 |
+
Creates the main Gradio interface combining Ketcher, molecule fragmentation,
|
| 392 |
+
and analog generation functionalities with fragment selection.
|
| 393 |
+
"""
|
| 394 |
+
with gr.Blocks(theme=gr.themes.Default()) as demo:
|
| 395 |
+
gr.Markdown("# Fragment Optimization Tools with Ketcher")
|
| 396 |
+
|
| 397 |
+
# Main layout with two columns
|
| 398 |
+
with gr.Row():
|
| 399 |
+
# Left column - Ketcher editor
|
| 400 |
+
with gr.Column(scale=2):
|
| 401 |
+
gr.HTML(KETCHER_HTML)
|
| 402 |
+
|
| 403 |
+
# Right column - Controls and inputs
|
| 404 |
+
with gr.Column(scale=1):
|
| 405 |
+
# SMILES Input section
|
| 406 |
+
with gr.Group():
|
| 407 |
+
gr.Markdown("### Input SMILES (From Ketcher)")
|
| 408 |
+
combined_smiles_input = gr.Textbox(
|
| 409 |
+
label="",
|
| 410 |
+
value="C",
|
| 411 |
+
placeholder="SMILES from Ketcher will appear here",
|
| 412 |
+
elem_id="combined_smiles_input"
|
| 413 |
+
)
|
| 414 |
+
with gr.Row():
|
| 415 |
+
get_ketcher_smiles_btn = gr.Button("Get SMILES from Ketcher", variant="primary")
|
| 416 |
+
fragment_btn = gr.Button("Find Fragments", variant="secondary")
|
| 417 |
+
|
| 418 |
+
# Fragment Selection section
|
| 419 |
+
# Fragment Selection section
|
| 420 |
+
# Fragment Selection section
|
| 421 |
+
with gr.Group():
|
| 422 |
+
gr.Markdown("### Available Fragments")
|
| 423 |
+
gr.Markdown("""
|
| 424 |
+
Select a fragmentation pattern:
|
| 425 |
+
- Variable Fragment: Part that will be modified
|
| 426 |
+
- Constant Fragment: Part that remains unchanged
|
| 427 |
+
- Order: Attachment point pattern between fragments
|
| 428 |
+
""")
|
| 429 |
+
fragments_table = gr.Dataframe(
|
| 430 |
+
headers=["Variable Fragment", "Constant Fragment", "Order"],
|
| 431 |
+
type="array",
|
| 432 |
+
interactive=True,
|
| 433 |
+
label="Click a row to select fragmentation pattern",
|
| 434 |
+
# Remove the invalid parameters
|
| 435 |
+
wrap=True, # Allow text wrapping for long SMILES strings
|
| 436 |
+
row_count=10 # Show 10 rows at a time
|
| 437 |
+
)
|
| 438 |
+
|
| 439 |
+
# Selected Fragment Display
|
| 440 |
+
with gr.Group():
|
| 441 |
+
gr.Markdown("### Selected Fragment")
|
| 442 |
+
with gr.Row():
|
| 443 |
+
constant_frag_input = gr.Textbox(
|
| 444 |
+
label="Constant Fragment",
|
| 445 |
+
placeholder="SMILES of constant fragment",
|
| 446 |
+
interactive=True
|
| 447 |
+
)
|
| 448 |
+
variable_frag_input = gr.Textbox(
|
| 449 |
+
label="Variable Fragment",
|
| 450 |
+
placeholder="SMILES of variable fragment",
|
| 451 |
+
interactive=True
|
| 452 |
+
)
|
| 453 |
+
attach_order_input = gr.Textbox(
|
| 454 |
+
label="Attachment Order",
|
| 455 |
+
placeholder="Attachment Order",
|
| 456 |
+
interactive=True
|
| 457 |
+
)
|
| 458 |
+
|
| 459 |
+
# Analog generation section
|
| 460 |
+
with gr.Group():
|
| 461 |
+
gr.Markdown("### Generate Analogs")
|
| 462 |
+
current_metric_type = gr.State("")
|
| 463 |
+
|
| 464 |
+
with gr.Row():
|
| 465 |
+
main_cls_dropdown = gr.Dropdown(
|
| 466 |
+
label="Objective",
|
| 467 |
+
choices=["None"] + params["Objective"],
|
| 468 |
+
value="None"
|
| 469 |
+
)
|
| 470 |
+
minor_cls_dropdown = gr.Dropdown(
|
| 471 |
+
label="Metrics",
|
| 472 |
+
choices=["None"],
|
| 473 |
+
value="None"
|
| 474 |
+
)
|
| 475 |
+
number_input = gr.Number(
|
| 476 |
+
label="Number of Analogs",
|
| 477 |
+
value=3,
|
| 478 |
+
step=1,
|
| 479 |
+
minimum=1,
|
| 480 |
+
maximum=10
|
| 481 |
+
)
|
| 482 |
+
|
| 483 |
+
with gr.Row():
|
| 484 |
+
bool_delta = gr.Dropdown(
|
| 485 |
+
choices=params["ImprovementAnticipationBool"],
|
| 486 |
+
label="Target Direction (Boolean)",
|
| 487 |
+
value="0-1",
|
| 488 |
+
visible=False,
|
| 489 |
+
info="Select desired change direction"
|
| 490 |
+
)
|
| 491 |
+
seq_delta = gr.Dropdown(
|
| 492 |
+
choices=params["ImprovementAnticipationSeq"],
|
| 493 |
+
label="Target Range (Sequential)",
|
| 494 |
+
value="(-0.5, 0.0]",
|
| 495 |
+
visible=False,
|
| 496 |
+
info="Select desired value range"
|
| 497 |
+
)
|
| 498 |
+
|
| 499 |
+
generate_analogs_btn = gr.Button("Generate Analogs", variant="primary")
|
| 500 |
+
|
| 501 |
+
# Results section
|
| 502 |
+
with gr.Row():
|
| 503 |
+
with gr.Column():
|
| 504 |
+
selected_columns = gr.CheckboxGroup(
|
| 505 |
+
["smile", "molWt", "tpsa", "slogp", "sa", "qed"],
|
| 506 |
+
value=["smile", "molWt", "tpsa", "slogp"],
|
| 507 |
+
label="Select Columns to Display"
|
| 508 |
+
)
|
| 509 |
+
|
| 510 |
+
output_table = gr.Dataframe(
|
| 511 |
+
headers=["smile", "molWt", "tpsa", "slogp", "sa", "qed"],
|
| 512 |
+
label="Generated Analogs"
|
| 513 |
+
)
|
| 514 |
+
|
| 515 |
+
with gr.Row():
|
| 516 |
+
download_all_btn = gr.Button("Download All Results", variant="secondary")
|
| 517 |
+
download_selected_btn = gr.Button("Download Selected Results", variant="secondary")
|
| 518 |
+
|
| 519 |
+
# Helper functions for fragment handling
|
| 520 |
+
def process_fragments_response(response_data):
|
| 521 |
+
"""Process the API response into table format"""
|
| 522 |
+
try:
|
| 523 |
+
fragments = response_data.get("fragments", [])
|
| 524 |
+
return [[
|
| 525 |
+
fragment.get("variable_smiles", ""),
|
| 526 |
+
fragment.get("constant_smiles", ""),
|
| 527 |
+
str(fragment.get("attachment_order", ""))
|
| 528 |
+
] for fragment in fragments]
|
| 529 |
+
except Exception as e:
|
| 530 |
+
logger.error(f"Error processing fragments: {str(e)}")
|
| 531 |
+
return []
|
| 532 |
+
|
| 533 |
+
def get_fragments(smiles: str):
|
| 534 |
+
"""
|
| 535 |
+
Get and process fragments from API by calling the fragmentize endpoint.
|
| 536 |
+
Handles multiple fragmentation patterns returned by the API.
|
| 537 |
+
|
| 538 |
+
Args:
|
| 539 |
+
smiles (str): Input SMILES string to fragmentize
|
| 540 |
+
|
| 541 |
+
Returns:
|
| 542 |
+
list: A list of rows where each row represents a possible fragmentation pattern
|
| 543 |
+
"""
|
| 544 |
+
try:
|
| 545 |
+
# URL encode the SMILES string to handle special characters
|
| 546 |
+
encoded_smiles = requests.utils.quote(smiles)
|
| 547 |
+
url = f"{FRAGMENT_ENDPOINT}?smiles={encoded_smiles}"
|
| 548 |
+
logger.info(f"Calling fragmentize API with URL: {url}")
|
| 549 |
+
|
| 550 |
+
response = requests.get(url)
|
| 551 |
+
response.raise_for_status()
|
| 552 |
+
data = response.json()
|
| 553 |
+
|
| 554 |
+
# Process fragments from the response
|
| 555 |
+
fragments = data.get('fragments', [])
|
| 556 |
+
logger.info(f"Found {len(fragments)} possible fragmentations")
|
| 557 |
+
|
| 558 |
+
# Convert each fragment into a table row format
|
| 559 |
+
processed_fragments = []
|
| 560 |
+
for fragment in fragments:
|
| 561 |
+
processed_fragments.append([
|
| 562 |
+
fragment.get('variable_smiles', ''),
|
| 563 |
+
fragment.get('constant_smiles', ''),
|
| 564 |
+
str(fragment.get('attachment_order', ''))
|
| 565 |
+
])
|
| 566 |
+
|
| 567 |
+
return processed_fragments
|
| 568 |
+
|
| 569 |
+
except Exception as e:
|
| 570 |
+
logger.error(f"Error processing fragments: {str(e)}")
|
| 571 |
+
return []
|
| 572 |
+
|
| 573 |
+
def update_selected_fragment(evt: gr.SelectData, fragments_data):
|
| 574 |
+
"""Update fragment fields when table row is selected"""
|
| 575 |
+
try:
|
| 576 |
+
if not fragments_data or evt.index[0] >= len(fragments_data):
|
| 577 |
+
logger.warning("No valid fragment selected")
|
| 578 |
+
return ["", "", ""]
|
| 579 |
+
|
| 580 |
+
selected = fragments_data[evt.index[0]]
|
| 581 |
+
logger.info(f"Selected fragment pattern {evt.index[0]}: var={selected[0]}, const={selected[1]}, order={selected[2]}")
|
| 582 |
+
return [selected[1], selected[0], selected[2]]
|
| 583 |
+
|
| 584 |
+
except Exception as e:
|
| 585 |
+
logger.error(f"Error updating selected fragment: {str(e)}")
|
| 586 |
+
return ["", "", ""]
|
| 587 |
+
|
| 588 |
+
def update_delta_inputs(objective: str, metric: str) -> dict:
|
| 589 |
+
"""
|
| 590 |
+
Updates the visibility and options of delta inputs based on metric type.
|
| 591 |
+
Shows boolean or sequential delta input based on the metric's type.
|
| 592 |
+
|
| 593 |
+
Args:
|
| 594 |
+
objective (str): The selected objective
|
| 595 |
+
metric (str): The selected metric
|
| 596 |
+
|
| 597 |
+
Returns:
|
| 598 |
+
dict: Updates for both delta inputs and the current metric type
|
| 599 |
+
"""
|
| 600 |
+
if not validate_metric_combination(objective, metric):
|
| 601 |
+
return {
|
| 602 |
+
bool_delta: gr.update(visible=False),
|
| 603 |
+
seq_delta: gr.update(visible=False),
|
| 604 |
+
current_metric_type: ""
|
| 605 |
+
}
|
| 606 |
+
|
| 607 |
+
metric_name = get_metric_full_name(objective, metric)
|
| 608 |
+
metric_type = get_metric_type(metric_name)
|
| 609 |
+
|
| 610 |
+
return {
|
| 611 |
+
bool_delta: gr.update(visible=metric_type == "bool"),
|
| 612 |
+
seq_delta: gr.update(visible=metric_type == "seq"),
|
| 613 |
+
current_metric_type: metric_type
|
| 614 |
+
}
|
| 615 |
+
|
| 616 |
+
def update_metrics_dropdown(objective: str) -> dict:
|
| 617 |
+
"""
|
| 618 |
+
Updates the metrics dropdown based on the selected objective.
|
| 619 |
+
Uses the get_metrics_for_objective helper function to get valid metrics for the chosen objective.
|
| 620 |
+
|
| 621 |
+
Args:
|
| 622 |
+
objective (str): The selected objective from the main dropdown
|
| 623 |
+
|
| 624 |
+
Returns:
|
| 625 |
+
dict: A Gradio update object containing the new dropdown configuration
|
| 626 |
+
"""
|
| 627 |
+
metrics = get_metrics_for_objective(objective)
|
| 628 |
+
return gr.Dropdown(choices=metrics, value="None")
|
| 629 |
+
|
| 630 |
+
# Event handlers
|
| 631 |
+
get_ketcher_smiles_btn.click(
|
| 632 |
+
fn=None,
|
| 633 |
+
inputs=None,
|
| 634 |
+
outputs=combined_smiles_input,
|
| 635 |
+
js="async () => { const iframe = document.getElementById('ifKetcher'); if(iframe && iframe.contentWindow && iframe.contentWindow.ketcher) { const smiles = await iframe.contentWindow.ketcher.getSmiles(); return smiles; } else { console.error('Ketcher not ready'); return ''; } }"
|
| 636 |
+
)
|
| 637 |
+
|
| 638 |
+
# Fragment processing handlers
|
| 639 |
+
fragment_btn.click(
|
| 640 |
+
fn=get_fragments,
|
| 641 |
+
inputs=[combined_smiles_input],
|
| 642 |
+
outputs=[fragments_table]
|
| 643 |
+
)
|
| 644 |
+
|
| 645 |
+
fragments_table.select(
|
| 646 |
+
fn=update_selected_fragment,
|
| 647 |
+
inputs=[fragments_table],
|
| 648 |
+
outputs=[constant_frag_input, variable_frag_input, attach_order_input]
|
| 649 |
+
)
|
| 650 |
+
|
| 651 |
+
# Metric selection handlers
|
| 652 |
+
main_cls_dropdown.change(
|
| 653 |
+
fn=update_metrics_dropdown,
|
| 654 |
+
inputs=[main_cls_dropdown],
|
| 655 |
+
outputs=[minor_cls_dropdown]
|
| 656 |
+
)
|
| 657 |
+
|
| 658 |
+
main_cls_dropdown.change(
|
| 659 |
+
fn=update_delta_inputs,
|
| 660 |
+
inputs=[main_cls_dropdown, minor_cls_dropdown],
|
| 661 |
+
outputs=[bool_delta, seq_delta, current_metric_type]
|
| 662 |
+
)
|
| 663 |
+
|
| 664 |
+
minor_cls_dropdown.change(
|
| 665 |
+
fn=update_delta_inputs,
|
| 666 |
+
inputs=[main_cls_dropdown, minor_cls_dropdown],
|
| 667 |
+
outputs=[bool_delta, seq_delta, current_metric_type]
|
| 668 |
+
)
|
| 669 |
+
|
| 670 |
+
# Analog generation handler
|
| 671 |
+
generate_analogs_btn.click(
|
| 672 |
+
fn=handle_generate_analogs,
|
| 673 |
+
inputs=[
|
| 674 |
+
main_cls_dropdown,
|
| 675 |
+
minor_cls_dropdown,
|
| 676 |
+
number_input,
|
| 677 |
+
bool_delta,
|
| 678 |
+
seq_delta,
|
| 679 |
+
constant_frag_input,
|
| 680 |
+
variable_frag_input,
|
| 681 |
+
current_metric_type
|
| 682 |
+
],
|
| 683 |
+
outputs=[output_table]
|
| 684 |
+
)
|
| 685 |
+
|
| 686 |
+
# Download handlers
|
| 687 |
+
download_all_btn.click(
|
| 688 |
+
lambda df: save_to_csv(df, False),
|
| 689 |
+
inputs=[output_table],
|
| 690 |
+
outputs=[gr.File(label="Download CSV")]
|
| 691 |
+
)
|
| 692 |
+
|
| 693 |
+
download_selected_btn.click(
|
| 694 |
+
lambda df, cols: save_to_csv(df[cols], True),
|
| 695 |
+
inputs=[output_table, selected_columns],
|
| 696 |
+
outputs=[gr.File(label="Download CSV")]
|
| 697 |
+
)
|
| 698 |
+
|
| 699 |
+
return demo
|
| 700 |
+
|
| 701 |
+
# Mount the Gradio app
|
| 702 |
+
combined_demo = create_combined_interface()
|
| 703 |
+
app = gr.mount_gradio_app(app, combined_demo, path="/")
|
| 704 |
+
|
| 705 |
+
if __name__ == "__main__":
|
| 706 |
+
uvicorn.run(app, host="127.0.0.1", port=7890)
|