# Spaces status (page header captured together with this file): Runtime error
| import gradio as gr | |
| from fastapi import FastAPI, HTTPException | |
| from starlette.staticfiles import StaticFiles | |
| import uvicorn | |
| import logging | |
| from pydantic import BaseModel | |
| import pandas as pd | |
| import time | |
| import requests | |
| import json | |
| from typing import List, Dict, Any, Optional, Tuple | |
| from fastapi.responses import RedirectResponse | |
# Logging setup: DEBUG level, records formatted as "time - level - message".
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.DEBUG,
)
logger = logging.getLogger(__name__)

# Remote API endpoints, all rooted at the same base URL.
API_BASE_URL = "https://songyou-llm-fastapi.hf.space"
FRAGMENT_ENDPOINT = API_BASE_URL + "/fragmentize"
GENERATE_ENDPOINT = API_BASE_URL + "/generate"
# Load parameters from the configuration file.
# The rest of this module reads the keys "Objective", "Metrics", "BoolOrSeq",
# "ImprovementAnticipationBool" and "ImprovementAnticipationSeq" from `params`.
try:
    # Read as UTF-8 explicitly instead of relying on the platform-default
    # encoding, which differs between systems.
    with open('param.json', 'r', encoding='utf-8') as f:
        params = json.load(f)
except Exception as e:
    # The application cannot work without its configuration: log and abort startup.
    logger.error(f"Error loading parameter configuration: {str(e)}")
    raise
else:
    logger.info("Successfully loaded parameter configuration")
# Data models
class SmilesData(BaseModel):
    """Request body carrying a SMILES string sent by the front-end."""
    # SMILES text; read as `data.smiles` by the `update_smiles` handler.
    smiles: str
class GenerateRequest(BaseModel):
    """Request payload for the generate endpoint.

    Instances are built in `generate_analogs` and sent as the JSON body of
    the POST request to GENERATE_ENDPOINT.
    """
    # SMILES of the constant fragment.
    constSmiles: str
    # SMILES of the variable fragment.
    varSmiles: str
    # Selected objective; `generate_analogs` sends "" when the value is "None".
    mainCls: str
    # Selected metric; `generate_analogs` sends "" when the value is "None".
    minorCls: str
    # Delta value chosen for the metric.
    deltaValue: str
    targetName: str = "target1"  # default value
    # Number of analogs requested.
    num: int
# Helper functions for metric handling
def get_metrics_for_objective(objective: str) -> List[str]:
    """Return the metric choices for `objective`, always led by "None".

    Only ["None"] is returned when the objective is "None" or has no entry
    in params["Metrics"].
    """
    if objective == "None":
        return ["None"]
    metrics_by_objective = params["Metrics"]
    if objective not in metrics_by_objective:
        return ["None"]
    return ["None"] + metrics_by_objective[objective]
def get_metric_full_name(objective: str, metric: str) -> str:
    """
    Build the full metric name for an objective/metric pair.

    For "general physical properties" the metric is returned unchanged;
    for every other objective it is returned formatted as a string.
    """
    if objective != "general physical properties":
        return f"{metric}"
    return metric
def get_metric_type(metric_name: str) -> str:
    """
    Look up the type of a metric in params["BoolOrSeq"].

    Returns the mapped value, or '' when the metric has no entry.
    The result is also written to the debug log.
    """
    type_by_metric = params["BoolOrSeq"]
    resolved_type = type_by_metric.get(metric_name, "")
    logger.debug(f"Metric type for {metric_name}: {resolved_type}")
    return resolved_type
def get_delta_choices(metric_type: str) -> List[str]:
    """Return the delta-value choices configured for `metric_type`.

    "bool" and "seq" map to their respective lists in `params`; any other
    value yields an empty list.
    """
    if metric_type == "seq":
        return params["ImprovementAnticipationSeq"]
    if metric_type == "bool":
        return params["ImprovementAnticipationBool"]
    return []
def validate_metric_combination(objective: str, metric: str) -> bool:
    """
    Check whether `metric` is a configured metric of `objective`.

    Returns False when either value is "None", when the objective is not a
    key of params["Metrics"], or when the metric is not listed under that
    objective; True otherwise. Each outcome is written to the debug log.
    """
    if "None" in (objective, metric):
        logger.debug(f"Invalid objective or metric: {objective} - {metric}")
        return False

    known_metrics = params["Metrics"]
    if objective not in known_metrics:
        logger.debug(f"Objective not found in metrics: {objective}")
        return False

    if metric not in known_metrics[objective]:
        logger.debug(f"Metric not found in objective: {metric}")
        return False

    logger.debug(f"Valid metric combination: {objective} - {metric}")
    return True
def handle_generate_analogs(
    main_cls: str,
    minor_cls: str,
    number: int,
    bool_delta_val: str,
    seq_delta_val: str,
    const_smiles: str,
    var_smiles: str,
    metric_type: str
) -> pd.DataFrame:
    """
    UI-facing entry point for analog generation.

    Validates the inputs, picks the delta value that matches the metric type,
    calls `generate_analogs` and converts its result into a table.

    Args:
        main_cls (str): The main objective classification
        minor_cls (str): The specific metric
        number (int): Number of analogs to generate
        bool_delta_val (str): Selected delta value for boolean metrics
        seq_delta_val (str): Selected delta value for sequential metrics
        const_smiles (str): Constant fragment SMILES
        var_smiles (str): Variable fragment SMILES
        metric_type (str): Type of metric ('bool' or 'seq')

    Returns:
        pd.DataFrame: The generated analogs, or an empty DataFrame when the
        inputs are invalid, nothing was generated, or an error occurred.
    """
    try:
        # All four of these must be non-empty before anything else is checked.
        required_inputs = (main_cls, minor_cls, const_smiles, var_smiles)
        if any(not value for value in required_inputs):
            logger.error("Missing required inputs")
            return pd.DataFrame()

        if not validate_metric_combination(main_cls, minor_cls):
            logger.error(f"Invalid metric combination: {main_cls} - {minor_cls}")
            return pd.DataFrame()

        # The delta value comes from whichever input matches the metric type.
        if metric_type == "bool":
            delta_value = bool_delta_val
        elif metric_type == "seq":
            delta_value = seq_delta_val
        else:
            logger.error(f"Invalid metric type: {metric_type}")
            return pd.DataFrame()

        analogs_data = generate_analogs(
            main_cls=main_cls,
            minor_cls=minor_cls,
            number=number,
            delta_value=delta_value,
            const_smiles=const_smiles,
            var_smiles=var_smiles
        )
        if analogs_data:
            return update_output_table(analogs_data)

        logger.warning("No analogs generated")
        return pd.DataFrame()
    except Exception as exc:
        logger.error(f"Error in handle_generate_analogs: {str(exc)}")
        return pd.DataFrame()
# Fragment lookup returning only the first fragmentation of the response
def fragment_molecule(smiles: str) -> Tuple[str, str, str]:
    """
    Call the fragment API endpoint and return the first fragmentation.

    Args:
        smiles (str): SMILES string of the molecule to fragmentize

    Returns:
        Tuple[str, str, str]: (constant_smiles, variable_smiles,
        attachment_order) of the first fragment in the response, with the
        attachment order converted to a string. Three empty strings are
        returned when the response has no fragments or the call fails.
    """
    try:
        logger.info(f"Calling fragment API with SMILES: {smiles}")
        # Let requests build the query string so the SMILES is percent-encoded;
        # raw characters such as '#', '+' or '%' would otherwise truncate or
        # corrupt the query. The timeout matches the one used for the
        # generate endpoint so a stalled server cannot block the caller forever.
        response = requests.get(
            FRAGMENT_ENDPOINT,
            params={"smiles": smiles},
            timeout=30
        )
        response.raise_for_status()
        data = response.json()
        logger.info(f"Fragment API response: {data}")

        # Return empty values if no fragments found
        if not data.get("fragments"):
            return "", "", ""

        # Return the first fragment by default
        first_fragment = data["fragments"][0]
        return (
            first_fragment.get("constant_smiles", ""),
            first_fragment.get("variable_smiles", ""),
            str(first_fragment.get("attachment_order", ""))
        )
    except Exception as e:
        logger.error(f"Fragment API call failed: {str(e)}")
        return "", "", ""
def generate_analogs(
    main_cls: str,
    minor_cls: str,
    number: int,
    delta_value: str,
    const_smiles: str,
    var_smiles: str
) -> List[Dict[str, Any]]:
    """
    Request molecule analogs from the generate API endpoint.

    Returns the decoded JSON list on success. An empty list is returned when
    a required input is missing, the request fails or times out, or the
    response is not a list.
    """
    try:
        # Every one of these must be truthy before a request is built.
        required = (const_smiles, var_smiles, main_cls, minor_cls, delta_value)
        if not all(required):
            logger.error("Missing required inputs for generate_analogs")
            return []

        # The literal "None" selection is sent to the API as an empty string.
        request_model = GenerateRequest(
            constSmiles=const_smiles,
            varSmiles=var_smiles,
            mainCls="" if main_cls == "None" else main_cls,
            minorCls="" if minor_cls == "None" else minor_cls,
            deltaValue=delta_value,
            num=int(number)
        )
        body = request_model.dict()
        logger.info(f"Calling generate API with payload: {body}")

        response = requests.post(
            GENERATE_ENDPOINT,
            headers={'Content-Type': 'application/json'},
            json=body,
            timeout=30
        )
        response.raise_for_status()

        results = response.json()
        if isinstance(results, list):
            logger.info(f"Successfully generated {len(results)} analogs")
            return results

        logger.error(f"Unexpected response format: {results}")
        return []
    except requests.exceptions.Timeout:
        # Must precede RequestException, of which Timeout is a subclass.
        logger.error("Generate API request timed out")
        return []
    except requests.exceptions.RequestException as req_err:
        logger.error(f"Generate API request failed: {str(req_err)}")
        return []
    except Exception as err:
        logger.error(f"Unexpected error in generate_analogs: {str(err)}")
        return []
def update_output_table(data: List[Dict[str, Any]]) -> pd.DataFrame:
    """Turn API response records into a pandas DataFrame for display.

    An empty DataFrame is returned (and the error logged) if construction fails.
    """
    try:
        table = pd.DataFrame(data)
    except Exception as err:
        logger.error(f"Error creating DataFrame: {str(err)}")
        table = pd.DataFrame()
    return table
def save_to_csv(data: pd.DataFrame, selected_only: bool = False) -> Optional[str]:
    """Save data to a uniquely named CSV file in the working directory.

    Args:
        data: Table to export (written without the index column).
        selected_only: Accepted for caller compatibility; it does not
            influence what is written.

    Returns:
        The name of the written file, or None if writing failed.
    """
    import uuid  # local import: only this function needs it

    try:
        # A second-resolution timestamp alone collides when two exports happen
        # within the same second (e.g. two users, or "all" and "selected"
        # clicked in quick succession), so one export would overwrite the
        # other. A random suffix keeps every file distinct.
        filename = f"molecule_analogs_{int(time.time())}_{uuid.uuid4().hex[:8]}.csv"
        data.to_csv(filename, index=False)
        return filename
    except Exception as e:
        logger.error(f"Error saving to CSV: {str(e)}")
        return None
# FastAPI app initialization
app = FastAPI()
# Mount Ketcher static files: the contents of the local "ketcher" directory are
# served under /ketcher (the iframe in KETCHER_HTML loads /ketcher/index.html).
app.mount("/ketcher", StaticFiles(directory="ketcher"), name="ketcher")
# Register the handler as a route: the front-end script in KETCHER_HTML POSTs
# JSON to '/update_smiles', so without this decorator the function is never
# reachable.
@app.post("/update_smiles")
async def update_smiles(data: SmilesData):
    """Endpoint to receive SMILES data from the front-end.

    Args:
        data: Request body holding the SMILES string.

    Returns:
        dict: {"status": "ok", "received_smiles": <the received SMILES>}.

    Raises:
        HTTPException: With status 500 if processing the update fails.
    """
    try:
        logger.info(f"Received SMILES from front-end: {data.smiles}")
        return {"status": "ok", "received_smiles": data.smiles}
    except Exception as e:
        logger.error(f"Error processing SMILES update: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
# Ketcher interface HTML template
#
# The markup embeds the Ketcher editor (served at /ketcher/index.html) in an
# iframe with id "ifKetcher" and ships a script that:
#   * waits for the iframe's 'load' event, then retries every 500 ms until
#     `contentWindow.ketcher` is available;
#   * sets the initial molecule to 'C';
#   * subscribes to the editor's 'change' event and, whenever the SMILES
#     differs from the last one seen, writes it into the text input inside the
#     element with id "combined_smiles_input" (dispatching an 'input' event)
#     and POSTs it as JSON to /update_smiles.
# NOTE(review): whether a <script> element embedded through gr.HTML is executed
# depends on the Gradio version in use - confirm in the deployed environment.
KETCHER_HTML = r'''
<iframe id="ifKetcher" src="/ketcher/index.html" width="100%" height="600px" style="border: 1px solid #ccc;"></iframe>
<script>
console.log("[Front-end] Ketcher-Gradio integration script loaded.");
let ketcher = null;
let lastSmiles = '';
function findSmilesInput() {
const inputContainer = document.getElementById('combined_smiles_input');
if (!inputContainer) {
console.warn("[Front-end] combined_smiles_input element not found.");
return null;
}
const input = inputContainer.querySelector('input[type="text"]');
return input;
}
function updateGradioInput(smiles) {
const input = findSmilesInput();
if (input && input.value !== smiles) {
input.value = smiles;
input.dispatchEvent(new Event('input', { bubbles: true }));
console.log("[Front-end] Updated Gradio input with SMILES:", smiles);
}
}
async function handleKetcherChange() {
console.log("[Front-end] handleKetcherChange called, retrieving SMILES...");
try {
const smiles = await ketcher.getSmiles({ arom: false });
console.log("[Front-end] SMILES retrieved from Ketcher:", smiles);
if (smiles !== lastSmiles) {
lastSmiles = smiles;
updateGradioInput(smiles);
fetch('/update_smiles', {
method: 'POST',
headers: {'Content-Type': 'application/json'},
body: JSON.stringify({smiles: smiles})
})
.then(res => res.json())
.then(data => {
console.log("[Front-end] Backend response:", data);
})
.catch(err => console.error("[Front-end] Error sending SMILES to backend:", err));
}
} catch (err) {
console.error("[Front-end] Error getting SMILES from Ketcher:", err);
}
}
function initKetcher() {
console.log("[Front-end] initKetcher started.");
const iframe = document.getElementById('ifKetcher');
if (!iframe) {
console.error("[Front-end] iframe not found.");
setTimeout(initKetcher, 500);
return;
}
const ketcherWindow = iframe.contentWindow;
if (!ketcherWindow || !ketcherWindow.ketcher) {
console.log("[Front-end] ketcher not yet available in iframe, retrying...");
setTimeout(initKetcher, 500);
return;
}
ketcher = ketcherWindow.ketcher;
console.log("[Front-end] Ketcher instance acquired:", ketcher);
ketcher.setMolecule('C').then(() => {
console.log("[Front-end] Initial molecule set to 'C'.");
});
const editor = ketcher.editor;
console.log("[Front-end] Editor object:", editor);
let eventBound = false;
if (editor && typeof editor.subscribe === 'function') {
console.log("[Front-end] Using editor.subscribe('change', ...)");
editor.subscribe('change', handleKetcherChange);
eventBound = true;
}
if (!eventBound) {
console.error("[Front-end] No suitable event binding found. Check Ketcher version and event API.");
}
}
document.getElementById('ifKetcher').addEventListener('load', () => {
console.log("[Front-end] iframe loaded. Initializing Ketcher in 1s...");
setTimeout(initKetcher, 1000);
});
</script>
'''
def create_combined_interface():
    """
    Create the main Gradio interface combining the Ketcher editor, molecule
    fragmentation with fragment selection, and analog generation.

    Returns:
        gr.Blocks: The assembled interface (not launched).
    """
    # NOTE(review): the source this was rebuilt from had lost its indentation,
    # so the exact nesting of layout containers (which widgets share a Row)
    # is a reconstruction - confirm against the intended layout.
    with gr.Blocks(theme=gr.themes.Default()) as demo:
        gr.Markdown("# Fragment Optimization Tools with Ketcher")

        # Main layout with two columns
        with gr.Row():
            # Left column - Ketcher editor
            with gr.Column(scale=2):
                gr.HTML(KETCHER_HTML)

            # Right column - Controls and inputs
            with gr.Column(scale=1):
                # SMILES Input section
                with gr.Group():
                    gr.Markdown("### Input SMILES (From Ketcher)")
                    # elem_id is what the front-end script uses to locate this input.
                    combined_smiles_input = gr.Textbox(
                        label="",
                        value="C",
                        placeholder="SMILES from Ketcher will appear here",
                        elem_id="combined_smiles_input"
                    )
                    with gr.Row():
                        get_ketcher_smiles_btn = gr.Button("Get SMILES from Ketcher", variant="primary")
                        fragment_btn = gr.Button("Find Fragments", variant="secondary")

                # Fragment Selection section
                with gr.Group():
                    gr.Markdown("### Available Fragments")
                    gr.Markdown(
                        "\n"
                        "Select a fragmentation pattern:\n"
                        "- Variable Fragment: Part that will be modified\n"
                        "- Constant Fragment: Part that remains unchanged\n"
                        "- Order: Attachment point pattern between fragments\n"
                    )
                    fragments_table = gr.Dataframe(
                        headers=["Variable Fragment", "Constant Fragment", "Order"],
                        type="array",
                        interactive=True,
                        label="Click a row to select fragmentation pattern",
                        wrap=True,  # Allow text wrapping for long SMILES strings
                        row_count=10  # Show 10 rows at a time
                    )

                # Selected Fragment Display
                with gr.Group():
                    gr.Markdown("### Selected Fragment")
                    with gr.Row():
                        constant_frag_input = gr.Textbox(
                            label="Constant Fragment",
                            placeholder="SMILES of constant fragment",
                            interactive=True
                        )
                        variable_frag_input = gr.Textbox(
                            label="Variable Fragment",
                            placeholder="SMILES of variable fragment",
                            interactive=True
                        )
                    attach_order_input = gr.Textbox(
                        label="Attachment Order",
                        placeholder="Attachment Order",
                        interactive=True
                    )

                # Analog generation section
                with gr.Group():
                    gr.Markdown("### Generate Analogs")
                    # Holds the type ('bool', 'seq' or '') of the selected metric.
                    current_metric_type = gr.State("")
                    with gr.Row():
                        main_cls_dropdown = gr.Dropdown(
                            label="Objective",
                            choices=["None"] + params["Objective"],
                            value="None"
                        )
                        minor_cls_dropdown = gr.Dropdown(
                            label="Metrics",
                            choices=["None"],
                            value="None"
                        )
                    number_input = gr.Number(
                        label="Number of Analogs",
                        value=3,
                        step=1,
                        minimum=1,
                        maximum=10
                    )
                    # Both delta inputs start hidden; update_delta_inputs shows
                    # the one matching the selected metric's type.
                    with gr.Row():
                        bool_delta = gr.Dropdown(
                            choices=params["ImprovementAnticipationBool"],
                            label="Target Direction (Boolean)",
                            value="0-1",
                            visible=False,
                            info="Select desired change direction"
                        )
                        seq_delta = gr.Dropdown(
                            choices=params["ImprovementAnticipationSeq"],
                            label="Target Range (Sequential)",
                            value="(-0.5, 0.0]",
                            visible=False,
                            info="Select desired value range"
                        )
                    generate_analogs_btn = gr.Button("Generate Analogs", variant="primary")

        # Results section
        with gr.Row():
            with gr.Column():
                selected_columns = gr.CheckboxGroup(
                    ["smile", "molWt", "tpsa", "slogp", "sa", "qed"],
                    value=["smile", "molWt", "tpsa", "slogp"],
                    label="Select Columns to Display"
                )
                output_table = gr.Dataframe(
                    headers=["smile", "molWt", "tpsa", "slogp", "sa", "qed"],
                    label="Generated Analogs"
                )
        with gr.Row():
            download_all_btn = gr.Button("Download All Results", variant="secondary")
            download_selected_btn = gr.Button("Download Selected Results", variant="secondary")

        # Helper functions for fragment handling
        def process_fragments_response(response_data):
            """Convert a fragmentize API response into table rows.

            Each row is [variable_smiles, constant_smiles, attachment_order];
            an empty list is returned if the response cannot be processed.
            """
            try:
                fragments = response_data.get("fragments", [])
                return [[
                    fragment.get("variable_smiles", ""),
                    fragment.get("constant_smiles", ""),
                    str(fragment.get("attachment_order", ""))
                ] for fragment in fragments]
            except Exception as e:
                logger.error(f"Error processing fragments: {str(e)}")
                return []

        def get_fragments(smiles: str):
            """
            Fetch all fragmentation patterns for a SMILES string from the
            fragmentize endpoint.

            Args:
                smiles (str): Input SMILES string to fragmentize

            Returns:
                list: One row per fragmentation pattern, or an empty list when
                no SMILES was given or the request failed.
            """
            if not smiles or not smiles.strip():
                logger.warning("No SMILES provided for fragmentation")
                return []
            try:
                logger.info(f"Calling fragmentize API with SMILES: {smiles}")
                # requests percent-encodes the query parameter; the timeout
                # keeps a stalled server from blocking the UI indefinitely.
                response = requests.get(
                    FRAGMENT_ENDPOINT,
                    params={"smiles": smiles},
                    timeout=30
                )
                response.raise_for_status()
                rows = process_fragments_response(response.json())
                logger.info(f"Found {len(rows)} possible fragmentations")
                return rows
            except Exception as e:
                logger.error(f"Error processing fragments: {str(e)}")
                return []

        def update_selected_fragment(evt: gr.SelectData, fragments_data):
            """Fill the fragment fields from the selected table row.

            Returns [constant, variable, order]; table rows are stored as
            [variable, constant, order], hence the swap.
            """
            try:
                row_index = evt.index[0]
                if not fragments_data or row_index >= len(fragments_data):
                    logger.warning("No valid fragment selected")
                    return ["", "", ""]
                selected = fragments_data[row_index]
                logger.info(f"Selected fragment pattern {row_index}: var={selected[0]}, const={selected[1]}, order={selected[2]}")
                return [selected[1], selected[0], selected[2]]
            except Exception as e:
                logger.error(f"Error updating selected fragment: {str(e)}")
                return ["", "", ""]

        def update_delta_inputs(objective: str, metric: str) -> dict:
            """
            Show the delta input that matches the selected metric's type.

            Args:
                objective (str): The selected objective
                metric (str): The selected metric

            Returns:
                dict: Updates for both delta inputs and the current metric
                type; both inputs are hidden for an invalid combination.
            """
            if not validate_metric_combination(objective, metric):
                return {
                    bool_delta: gr.update(visible=False),
                    seq_delta: gr.update(visible=False),
                    current_metric_type: ""
                }
            metric_name = get_metric_full_name(objective, metric)
            metric_type = get_metric_type(metric_name)
            return {
                bool_delta: gr.update(visible=metric_type == "bool"),
                seq_delta: gr.update(visible=metric_type == "seq"),
                current_metric_type: metric_type
            }

        def update_metrics_dropdown(objective: str) -> gr.Dropdown:
            """
            Rebuild the metrics dropdown for the selected objective.

            Args:
                objective (str): The selected objective from the main dropdown

            Returns:
                gr.Dropdown: Dropdown configuration with the objective's
                metrics as choices and the value reset to "None"
            """
            metrics = get_metrics_for_objective(objective)
            return gr.Dropdown(choices=metrics, value="None")

        def download_selected(results, columns):
            """Export only the selected columns that exist in the results table.

            Returns the CSV file name, or None when there is nothing to export.
            """
            if results is None:
                return None
            # Selecting a column that is absent from the table would raise
            # KeyError, so keep only the ones actually present.
            present = [col for col in (columns or []) if col in results.columns]
            if not present:
                logger.warning("None of the selected columns are present in the results")
                return None
            return save_to_csv(results[present], True)

        # Event handlers
        get_ketcher_smiles_btn.click(
            fn=None,
            inputs=None,
            outputs=combined_smiles_input,
            js="async () => { const iframe = document.getElementById('ifKetcher'); if(iframe && iframe.contentWindow && iframe.contentWindow.ketcher) { const smiles = await iframe.contentWindow.ketcher.getSmiles(); return smiles; } else { console.error('Ketcher not ready'); return ''; } }"
        )

        # Fragment processing handlers
        fragment_btn.click(
            fn=get_fragments,
            inputs=[combined_smiles_input],
            outputs=[fragments_table]
        )
        fragments_table.select(
            fn=update_selected_fragment,
            inputs=[fragments_table],
            outputs=[constant_frag_input, variable_frag_input, attach_order_input]
        )

        # Metric selection handlers.
        # The delta inputs are refreshed only after the metrics dropdown has
        # been reset, so they never reflect the previous objective's metric.
        main_cls_dropdown.change(
            fn=update_metrics_dropdown,
            inputs=[main_cls_dropdown],
            outputs=[minor_cls_dropdown]
        ).then(
            fn=update_delta_inputs,
            inputs=[main_cls_dropdown, minor_cls_dropdown],
            outputs=[bool_delta, seq_delta, current_metric_type]
        )
        minor_cls_dropdown.change(
            fn=update_delta_inputs,
            inputs=[main_cls_dropdown, minor_cls_dropdown],
            outputs=[bool_delta, seq_delta, current_metric_type]
        )

        # Analog generation handler
        generate_analogs_btn.click(
            fn=handle_generate_analogs,
            inputs=[
                main_cls_dropdown,
                minor_cls_dropdown,
                number_input,
                bool_delta,
                seq_delta,
                constant_frag_input,
                variable_frag_input,
                current_metric_type
            ],
            outputs=[output_table]
        )

        # Download handlers (each button has its own file output)
        all_results_file = gr.File(label="Download CSV")
        selected_results_file = gr.File(label="Download CSV")
        download_all_btn.click(
            lambda df: save_to_csv(df, False),
            inputs=[output_table],
            outputs=[all_results_file]
        )
        download_selected_btn.click(
            download_selected,
            inputs=[output_table, selected_columns],
            outputs=[selected_results_file]
        )
    return demo
# Mount the Gradio app at the root path of the FastAPI application.
combined_demo = create_combined_interface()
app = gr.mount_gradio_app(app, combined_demo, path="/")

if __name__ == "__main__":
    import os  # local import: only needed when run as a script

    # Host and port default to the original values but can be overridden via
    # the APP_HOST / APP_PORT environment variables, e.g. to bind to
    # 0.0.0.0 inside a container.
    host = os.environ.get("APP_HOST", "127.0.0.1")
    port = int(os.environ.get("APP_PORT", "6861"))
    uvicorn.run(app, host=host, port=port)