Files changed (1) hide show
  1. app.py +706 -0
app.py ADDED
@@ -0,0 +1,706 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from fastapi import FastAPI, HTTPException
3
+ from starlette.staticfiles import StaticFiles
4
+ import uvicorn
5
+ import logging
6
+ from pydantic import BaseModel
7
+ import pandas as pd
8
+ import time
9
+ import requests
10
+ import json
11
+ from typing import List, Dict, Any, Optional, Tuple
12
+
13
+ # Set up logging configuration
14
+ logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
15
+ logger = logging.getLogger(__name__)
16
+
17
+ # API configurations
18
+ API_BASE_URL = "https://songyou-llm-fastapi.hf.space"
19
+ FRAGMENT_ENDPOINT = f"{API_BASE_URL}/fragmentize"
20
+ GENERATE_ENDPOINT = f"{API_BASE_URL}/generate"
21
+
22
+ # Load parameters from configuration file
23
+ try:
24
+ with open('param.json', 'r') as f:
25
+ params = json.load(f)
26
+ logger.info("Successfully loaded parameter configuration")
27
+ except Exception as e:
28
+ logger.error(f"Error loading parameter configuration: {str(e)}")
29
+ raise
30
+
31
+ # Data models
32
+ class SmilesData(BaseModel):
33
+ """Model for SMILES data received from frontend"""
34
+ smiles: str
35
+
36
+ class GenerateRequest(BaseModel):
37
+ """Request model for generate endpoint with updated fields"""
38
+ constSmiles: str
39
+ varSmiles: str
40
+ mainCls: str
41
+ minorCls: str
42
+ deltaValue: str
43
+ targetName: str = "target1" # default value
44
+ num: int
45
+
46
+
47
+
48
+
49
+
50
+
51
+
52
+
53
+
54
+ # Helper functions for metric handling
55
+ def get_metrics_for_objective(objective: str) -> List[str]:
56
+ """Get the corresponding metrics for a given objective"""
57
+ if objective == "None" or objective not in params["Metrics"]:
58
+ return ["None"]
59
+ return ["None"] + params["Metrics"].get(objective, [])
60
+
61
+ def get_metric_full_name(objective: str, metric: str) -> str:
62
+ """
63
+ Constructs the full metric name based on objective and metric.
64
+ For general physical properties, returns just the metric name.
65
+ For others, returns the metric name as is.
66
+ """
67
+ if objective == "general physical properties":
68
+ return metric
69
+ return f"{metric}"
70
+
71
+ def get_metric_type(metric_name: str) -> str:
72
+ """
73
+ Determines if a metric is boolean or sequential based on the BoolOrSeq mapping.
74
+ Returns 'bool', 'seq', or '' if not found.
75
+ """
76
+ metric_type = params["BoolOrSeq"].get(metric_name, "")
77
+ logger.debug(f"Metric type for {metric_name}: {metric_type}")
78
+ return metric_type
79
+
80
+ def get_delta_choices(metric_type: str) -> List[str]:
81
+ """Returns the appropriate choices for delta value based on metric type."""
82
+ if metric_type == "bool":
83
+ return params["ImprovementAnticipationBool"]
84
+ elif metric_type == "seq":
85
+ return params["ImprovementAnticipationSeq"]
86
+ return []
87
+
88
+ def validate_metric_combination(objective: str, metric: str) -> bool:
89
+ """
90
+ Validates if the objective-metric combination is valid.
91
+ Returns True if valid, False otherwise.
92
+ """
93
+ if objective == "None" or metric == "None":
94
+ logger.debug(f"Invalid objective or metric: {objective} - {metric}")
95
+ return False
96
+ if objective not in params["Metrics"]:
97
+ logger.debug(f"Objective not found in metrics: {objective}")
98
+ return False
99
+ if metric not in params["Metrics"].get(objective, []):
100
+ logger.debug(f"Metric not found in objective: {metric}")
101
+ return False
102
+ logger.debug(f"Valid metric combination: {objective} - {metric}")
103
+ return True
104
+
105
+ def handle_generate_analogs(
106
+ main_cls: str,
107
+ minor_cls: str,
108
+ number: int,
109
+ bool_delta_val: str,
110
+ seq_delta_val: str,
111
+ const_smiles: str,
112
+ var_smiles: str,
113
+ metric_type: str
114
+ ) -> pd.DataFrame:
115
+ """
116
+ Handles the generation of analogs with appropriate delta value selection and error handling.
117
+ This function serves as the bridge between the UI and the generate_analogs API call.
118
+
119
+ Args:
120
+ main_cls (str): The main objective classification
121
+ minor_cls (str): The specific metric
122
+ number (int): Number of analogs to generate
123
+ bool_delta_val (str): Selected delta value for boolean metrics
124
+ seq_delta_val (str): Selected delta value for sequential metrics
125
+ const_smiles (str): Constant fragment SMILES
126
+ var_smiles (str): Variable fragment SMILES
127
+ metric_type (str): Type of metric ('bool' or 'seq')
128
+
129
+ Returns:
130
+ pd.DataFrame: DataFrame containing the generated analogs and their properties
131
+ """
132
+ try:
133
+ # Input validation
134
+ if not all([main_cls, minor_cls, const_smiles, var_smiles]):
135
+ logger.error("Missing required inputs")
136
+ return pd.DataFrame()
137
+
138
+ if not validate_metric_combination(main_cls, minor_cls):
139
+ logger.error(f"Invalid metric combination: {main_cls} - {minor_cls}")
140
+ return pd.DataFrame()
141
+
142
+ # Select appropriate delta value based on metric type
143
+ if metric_type not in ["bool", "seq"]:
144
+ logger.error(f"Invalid metric type: {metric_type}")
145
+ return pd.DataFrame()
146
+
147
+ delta_value = bool_delta_val if metric_type == "bool" else seq_delta_val
148
+
149
+ # Generate analogs using the API
150
+ analogs_data = generate_analogs(
151
+ main_cls=main_cls,
152
+ minor_cls=minor_cls,
153
+ number=number,
154
+ delta_value=delta_value,
155
+ const_smiles=const_smiles,
156
+ var_smiles=var_smiles
157
+ )
158
+
159
+ if not analogs_data:
160
+ logger.warning("No analogs generated")
161
+ return pd.DataFrame()
162
+
163
+ return update_output_table(analogs_data)
164
+
165
+ except Exception as e:
166
+ logger.error(f"Error in handle_generate_analogs: {str(e)}")
167
+ return pd.DataFrame()
168
+
169
+ # Update the fragment_molecule function to handle the new response format
170
+ def fragment_molecule(smiles: str) -> Tuple[str, str, str]:
171
+ """
172
+ Call the fragment API endpoint to get molecule fragments
173
+ Returns: List of fragments with their details
174
+ """
175
+ try:
176
+ logger.info(f"Calling fragment API with SMILES: {smiles}")
177
+ response = requests.get(f"{FRAGMENT_ENDPOINT}?smiles={smiles}")
178
+ response.raise_for_status()
179
+ data = response.json()
180
+ logger.info(f"Fragment API response: {data}")
181
+
182
+ # Return empty values if no fragments found
183
+ if not data.get("fragments"):
184
+ return "", "", ""
185
+
186
+ # Return the first fragment by default
187
+ first_fragment = data["fragments"][0]
188
+ return (
189
+ first_fragment.get("constant_smiles", ""),
190
+ first_fragment.get("variable_smiles", ""),
191
+ str(first_fragment.get("attachment_order", ""))
192
+ )
193
+ except Exception as e:
194
+ logger.error(f"Fragment API call failed: {str(e)}")
195
+ return "", "", ""
196
+
197
+ def generate_analogs(
198
+ main_cls: str,
199
+ minor_cls: str,
200
+ number: int,
201
+ delta_value: str,
202
+ const_smiles: str,
203
+ var_smiles: str
204
+ ) -> List[Dict[str, Any]]:
205
+ """
206
+ Generate molecule analogs using the generate API endpoint with improved error handling
207
+ and validation.
208
+ """
209
+ try:
210
+ # Validate inputs
211
+ if not all([const_smiles, var_smiles, main_cls, minor_cls, delta_value]):
212
+ logger.error("Missing required inputs for generate_analogs")
213
+ return []
214
+
215
+ # Create API request
216
+ payload = GenerateRequest(
217
+ constSmiles=const_smiles,
218
+ varSmiles=var_smiles,
219
+ mainCls=main_cls if main_cls != "None" else "",
220
+ minorCls=minor_cls if minor_cls != "None" else "",
221
+ deltaValue=delta_value,
222
+ num=int(number)
223
+ )
224
+
225
+ logger.info(f"Calling generate API with payload: {payload.dict()}")
226
+
227
+ # Make API request
228
+ response = requests.post(
229
+ GENERATE_ENDPOINT,
230
+ headers={'Content-Type': 'application/json'},
231
+ json=payload.dict(),
232
+ timeout=30
233
+ )
234
+
235
+ response.raise_for_status()
236
+ results = response.json()
237
+
238
+ if not isinstance(results, list):
239
+ logger.error(f"Unexpected response format: {results}")
240
+ return []
241
+
242
+ logger.info(f"Successfully generated {len(results)} analogs")
243
+ return results
244
+
245
+ except requests.exceptions.Timeout:
246
+ logger.error("Generate API request timed out")
247
+ return []
248
+ except requests.exceptions.RequestException as e:
249
+ logger.error(f"Generate API request failed: {str(e)}")
250
+ return []
251
+ except Exception as e:
252
+ logger.error(f"Unexpected error in generate_analogs: {str(e)}")
253
+ return []
254
+
255
+ def update_output_table(data: List[Dict[str, Any]]) -> pd.DataFrame:
256
+ """Convert API response data to pandas DataFrame for display"""
257
+ try:
258
+ df = pd.DataFrame(data)
259
+ return df
260
+ except Exception as e:
261
+ logger.error(f"Error creating DataFrame: {str(e)}")
262
+ return pd.DataFrame()
263
+
264
+ def save_to_csv(data: pd.DataFrame, selected_only: bool = False) -> Optional[str]:
265
+ """Save data to CSV file"""
266
+ try:
267
+ filename = f"molecule_analogs_{int(time.time())}.csv"
268
+ data.to_csv(filename, index=False)
269
+ return filename
270
+ except Exception as e:
271
+ logger.error(f"Error saving to CSV: {str(e)}")
272
+ return None
273
+
274
+ # FastAPI app initialization
275
+ app = FastAPI()
276
+
277
+ # Mount Ketcher static files
278
+ app.mount("/ketcher", StaticFiles(directory="ketcher"), name="ketcher")
279
+
280
+ @app.post("/update_smiles")
281
+ async def update_smiles(data: SmilesData):
282
+ """Endpoint to receive SMILES data from frontend"""
283
+ try:
284
+ logger.info(f"Received SMILES from front-end: {data.smiles}")
285
+ return {"status": "ok", "received_smiles": data.smiles}
286
+ except Exception as e:
287
+ logger.error(f"Error processing SMILES update: {str(e)}")
288
+ raise HTTPException(status_code=500, detail=str(e))
289
+
290
+ # Ketcher interface HTML template
291
+ KETCHER_HTML = r'''
292
+ <iframe id="ifKetcher" src="/ketcher/index.html" width="100%" height="600px" style="border: 1px solid #ccc;"></iframe>
293
+
294
+ <script>
295
+ console.log("[Front-end] Ketcher-Gradio integration script loaded.");
296
+
297
+ let ketcher = null;
298
+ let lastSmiles = '';
299
+
300
+ function findSmilesInput() {
301
+ const inputContainer = document.getElementById('combined_smiles_input');
302
+ if (!inputContainer) {
303
+ console.warn("[Front-end] combined_smiles_input element not found.");
304
+ return null;
305
+ }
306
+ const input = inputContainer.querySelector('input[type="text"]');
307
+ return input;
308
+ }
309
+
310
+ function updateGradioInput(smiles) {
311
+ const input = findSmilesInput();
312
+ if (input && input.value !== smiles) {
313
+ input.value = smiles;
314
+ input.dispatchEvent(new Event('input', { bubbles: true }));
315
+ console.log("[Front-end] Updated Gradio input with SMILES:", smiles);
316
+ }
317
+ }
318
+
319
+ async function handleKetcherChange() {
320
+ console.log("[Front-end] handleKetcherChange called, retrieving SMILES...");
321
+ try {
322
+ const smiles = await ketcher.getSmiles({ arom: false });
323
+ console.log("[Front-end] SMILES retrieved from Ketcher:", smiles);
324
+ if (smiles !== lastSmiles) {
325
+ lastSmiles = smiles;
326
+ updateGradioInput(smiles);
327
+
328
+ fetch('/update_smiles', {
329
+ method: 'POST',
330
+ headers: {'Content-Type': 'application/json'},
331
+ body: JSON.stringify({smiles: smiles})
332
+ })
333
+ .then(res => res.json())
334
+ .then(data => {
335
+ console.log("[Front-end] Backend response:", data);
336
+ })
337
+ .catch(err => console.error("[Front-end] Error sending SMILES to backend:", err));
338
+ }
339
+ } catch (err) {
340
+ console.error("[Front-end] Error getting SMILES from Ketcher:", err);
341
+ }
342
+ }
343
+
344
+ function initKetcher() {
345
+ console.log("[Front-end] initKetcher started.");
346
+ const iframe = document.getElementById('ifKetcher');
347
+ if (!iframe) {
348
+ console.error("[Front-end] iframe not found.");
349
+ setTimeout(initKetcher, 500);
350
+ return;
351
+ }
352
+
353
+ const ketcherWindow = iframe.contentWindow;
354
+ if (!ketcherWindow || !ketcherWindow.ketcher) {
355
+ console.log("[Front-end] ketcher not yet available in iframe, retrying...");
356
+ setTimeout(initKetcher, 500);
357
+ return;
358
+ }
359
+
360
+ ketcher = ketcherWindow.ketcher;
361
+ console.log("[Front-end] Ketcher instance acquired:", ketcher);
362
+
363
+ ketcher.setMolecule('C').then(() => {
364
+ console.log("[Front-end] Initial molecule set to 'C'.");
365
+ });
366
+
367
+ const editor = ketcher.editor;
368
+ console.log("[Front-end] Editor object:", editor);
369
+
370
+ let eventBound = false;
371
+ if (editor && typeof editor.subscribe === 'function') {
372
+ console.log("[Front-end] Using editor.subscribe('change', ...)");
373
+ editor.subscribe('change', handleKetcherChange);
374
+ eventBound = true;
375
+ }
376
+
377
+ if (!eventBound) {
378
+ console.error("[Front-end] No suitable event binding found. Check Ketcher version and event API.");
379
+ }
380
+ }
381
+
382
+ document.getElementById('ifKetcher').addEventListener('load', () => {
383
+ console.log("[Front-end] iframe loaded. Initializing Ketcher in 1s...");
384
+ setTimeout(initKetcher, 1000);
385
+ });
386
+ </script>
387
+ '''
388
+
389
+ def create_combined_interface():
390
+ """
391
+ Creates the main Gradio interface combining Ketcher, molecule fragmentation,
392
+ and analog generation functionalities with fragment selection.
393
+ """
394
+ with gr.Blocks(theme=gr.themes.Default()) as demo:
395
+ gr.Markdown("# Fragment Optimization Tools with Ketcher")
396
+
397
+ # Main layout with two columns
398
+ with gr.Row():
399
+ # Left column - Ketcher editor
400
+ with gr.Column(scale=2):
401
+ gr.HTML(KETCHER_HTML)
402
+
403
+ # Right column - Controls and inputs
404
+ with gr.Column(scale=1):
405
+ # SMILES Input section
406
+ with gr.Group():
407
+ gr.Markdown("### Input SMILES (From Ketcher)")
408
+ combined_smiles_input = gr.Textbox(
409
+ label="",
410
+ value="C",
411
+ placeholder="SMILES from Ketcher will appear here",
412
+ elem_id="combined_smiles_input"
413
+ )
414
+ with gr.Row():
415
+ get_ketcher_smiles_btn = gr.Button("Get SMILES from Ketcher", variant="primary")
416
+ fragment_btn = gr.Button("Find Fragments", variant="secondary")
417
+
418
+ # Fragment Selection section
419
+ # Fragment Selection section
420
+ # Fragment Selection section
421
+ with gr.Group():
422
+ gr.Markdown("### Available Fragments")
423
+ gr.Markdown("""
424
+ Select a fragmentation pattern:
425
+ - Variable Fragment: Part that will be modified
426
+ - Constant Fragment: Part that remains unchanged
427
+ - Order: Attachment point pattern between fragments
428
+ """)
429
+ fragments_table = gr.Dataframe(
430
+ headers=["Variable Fragment", "Constant Fragment", "Order"],
431
+ type="array",
432
+ interactive=True,
433
+ label="Click a row to select fragmentation pattern",
434
+ # Remove the invalid parameters
435
+ wrap=True, # Allow text wrapping for long SMILES strings
436
+ row_count=10 # Show 10 rows at a time
437
+ )
438
+
439
+ # Selected Fragment Display
440
+ with gr.Group():
441
+ gr.Markdown("### Selected Fragment")
442
+ with gr.Row():
443
+ constant_frag_input = gr.Textbox(
444
+ label="Constant Fragment",
445
+ placeholder="SMILES of constant fragment",
446
+ interactive=True
447
+ )
448
+ variable_frag_input = gr.Textbox(
449
+ label="Variable Fragment",
450
+ placeholder="SMILES of variable fragment",
451
+ interactive=True
452
+ )
453
+ attach_order_input = gr.Textbox(
454
+ label="Attachment Order",
455
+ placeholder="Attachment Order",
456
+ interactive=True
457
+ )
458
+
459
+ # Analog generation section
460
+ with gr.Group():
461
+ gr.Markdown("### Generate Analogs")
462
+ current_metric_type = gr.State("")
463
+
464
+ with gr.Row():
465
+ main_cls_dropdown = gr.Dropdown(
466
+ label="Objective",
467
+ choices=["None"] + params["Objective"],
468
+ value="None"
469
+ )
470
+ minor_cls_dropdown = gr.Dropdown(
471
+ label="Metrics",
472
+ choices=["None"],
473
+ value="None"
474
+ )
475
+ number_input = gr.Number(
476
+ label="Number of Analogs",
477
+ value=3,
478
+ step=1,
479
+ minimum=1,
480
+ maximum=10
481
+ )
482
+
483
+ with gr.Row():
484
+ bool_delta = gr.Dropdown(
485
+ choices=params["ImprovementAnticipationBool"],
486
+ label="Target Direction (Boolean)",
487
+ value="0-1",
488
+ visible=False,
489
+ info="Select desired change direction"
490
+ )
491
+ seq_delta = gr.Dropdown(
492
+ choices=params["ImprovementAnticipationSeq"],
493
+ label="Target Range (Sequential)",
494
+ value="(-0.5, 0.0]",
495
+ visible=False,
496
+ info="Select desired value range"
497
+ )
498
+
499
+ generate_analogs_btn = gr.Button("Generate Analogs", variant="primary")
500
+
501
+ # Results section
502
+ with gr.Row():
503
+ with gr.Column():
504
+ selected_columns = gr.CheckboxGroup(
505
+ ["smile", "molWt", "tpsa", "slogp", "sa", "qed"],
506
+ value=["smile", "molWt", "tpsa", "slogp"],
507
+ label="Select Columns to Display"
508
+ )
509
+
510
+ output_table = gr.Dataframe(
511
+ headers=["smile", "molWt", "tpsa", "slogp", "sa", "qed"],
512
+ label="Generated Analogs"
513
+ )
514
+
515
+ with gr.Row():
516
+ download_all_btn = gr.Button("Download All Results", variant="secondary")
517
+ download_selected_btn = gr.Button("Download Selected Results", variant="secondary")
518
+
519
+ # Helper functions for fragment handling
520
+ def process_fragments_response(response_data):
521
+ """Process the API response into table format"""
522
+ try:
523
+ fragments = response_data.get("fragments", [])
524
+ return [[
525
+ fragment.get("variable_smiles", ""),
526
+ fragment.get("constant_smiles", ""),
527
+ str(fragment.get("attachment_order", ""))
528
+ ] for fragment in fragments]
529
+ except Exception as e:
530
+ logger.error(f"Error processing fragments: {str(e)}")
531
+ return []
532
+
533
+ def get_fragments(smiles: str):
534
+ """
535
+ Get and process fragments from API by calling the fragmentize endpoint.
536
+ Handles multiple fragmentation patterns returned by the API.
537
+
538
+ Args:
539
+ smiles (str): Input SMILES string to fragmentize
540
+
541
+ Returns:
542
+ list: A list of rows where each row represents a possible fragmentation pattern
543
+ """
544
+ try:
545
+ # URL encode the SMILES string to handle special characters
546
+ encoded_smiles = requests.utils.quote(smiles)
547
+ url = f"{FRAGMENT_ENDPOINT}?smiles={encoded_smiles}"
548
+ logger.info(f"Calling fragmentize API with URL: {url}")
549
+
550
+ response = requests.get(url)
551
+ response.raise_for_status()
552
+ data = response.json()
553
+
554
+ # Process fragments from the response
555
+ fragments = data.get('fragments', [])
556
+ logger.info(f"Found {len(fragments)} possible fragmentations")
557
+
558
+ # Convert each fragment into a table row format
559
+ processed_fragments = []
560
+ for fragment in fragments:
561
+ processed_fragments.append([
562
+ fragment.get('variable_smiles', ''),
563
+ fragment.get('constant_smiles', ''),
564
+ str(fragment.get('attachment_order', ''))
565
+ ])
566
+
567
+ return processed_fragments
568
+
569
+ except Exception as e:
570
+ logger.error(f"Error processing fragments: {str(e)}")
571
+ return []
572
+
573
+ def update_selected_fragment(evt: gr.SelectData, fragments_data):
574
+ """Update fragment fields when table row is selected"""
575
+ try:
576
+ if not fragments_data or evt.index[0] >= len(fragments_data):
577
+ logger.warning("No valid fragment selected")
578
+ return ["", "", ""]
579
+
580
+ selected = fragments_data[evt.index[0]]
581
+ logger.info(f"Selected fragment pattern {evt.index[0]}: var={selected[0]}, const={selected[1]}, order={selected[2]}")
582
+ return [selected[1], selected[0], selected[2]]
583
+
584
+ except Exception as e:
585
+ logger.error(f"Error updating selected fragment: {str(e)}")
586
+ return ["", "", ""]
587
+
588
+ def update_delta_inputs(objective: str, metric: str) -> dict:
589
+ """
590
+ Updates the visibility and options of delta inputs based on metric type.
591
+ Shows boolean or sequential delta input based on the metric's type.
592
+
593
+ Args:
594
+ objective (str): The selected objective
595
+ metric (str): The selected metric
596
+
597
+ Returns:
598
+ dict: Updates for both delta inputs and the current metric type
599
+ """
600
+ if not validate_metric_combination(objective, metric):
601
+ return {
602
+ bool_delta: gr.update(visible=False),
603
+ seq_delta: gr.update(visible=False),
604
+ current_metric_type: ""
605
+ }
606
+
607
+ metric_name = get_metric_full_name(objective, metric)
608
+ metric_type = get_metric_type(metric_name)
609
+
610
+ return {
611
+ bool_delta: gr.update(visible=metric_type == "bool"),
612
+ seq_delta: gr.update(visible=metric_type == "seq"),
613
+ current_metric_type: metric_type
614
+ }
615
+
616
+ def update_metrics_dropdown(objective: str) -> dict:
617
+ """
618
+ Updates the metrics dropdown based on the selected objective.
619
+ Uses the get_metrics_for_objective helper function to get valid metrics for the chosen objective.
620
+
621
+ Args:
622
+ objective (str): The selected objective from the main dropdown
623
+
624
+ Returns:
625
+ dict: A Gradio update object containing the new dropdown configuration
626
+ """
627
+ metrics = get_metrics_for_objective(objective)
628
+ return gr.Dropdown(choices=metrics, value="None")
629
+
630
+ # Event handlers
631
+ get_ketcher_smiles_btn.click(
632
+ fn=None,
633
+ inputs=None,
634
+ outputs=combined_smiles_input,
635
+ js="async () => { const iframe = document.getElementById('ifKetcher'); if(iframe && iframe.contentWindow && iframe.contentWindow.ketcher) { const smiles = await iframe.contentWindow.ketcher.getSmiles(); return smiles; } else { console.error('Ketcher not ready'); return ''; } }"
636
+ )
637
+
638
+ # Fragment processing handlers
639
+ fragment_btn.click(
640
+ fn=get_fragments,
641
+ inputs=[combined_smiles_input],
642
+ outputs=[fragments_table]
643
+ )
644
+
645
+ fragments_table.select(
646
+ fn=update_selected_fragment,
647
+ inputs=[fragments_table],
648
+ outputs=[constant_frag_input, variable_frag_input, attach_order_input]
649
+ )
650
+
651
+ # Metric selection handlers
652
+ main_cls_dropdown.change(
653
+ fn=update_metrics_dropdown,
654
+ inputs=[main_cls_dropdown],
655
+ outputs=[minor_cls_dropdown]
656
+ )
657
+
658
+ main_cls_dropdown.change(
659
+ fn=update_delta_inputs,
660
+ inputs=[main_cls_dropdown, minor_cls_dropdown],
661
+ outputs=[bool_delta, seq_delta, current_metric_type]
662
+ )
663
+
664
+ minor_cls_dropdown.change(
665
+ fn=update_delta_inputs,
666
+ inputs=[main_cls_dropdown, minor_cls_dropdown],
667
+ outputs=[bool_delta, seq_delta, current_metric_type]
668
+ )
669
+
670
+ # Analog generation handler
671
+ generate_analogs_btn.click(
672
+ fn=handle_generate_analogs,
673
+ inputs=[
674
+ main_cls_dropdown,
675
+ minor_cls_dropdown,
676
+ number_input,
677
+ bool_delta,
678
+ seq_delta,
679
+ constant_frag_input,
680
+ variable_frag_input,
681
+ current_metric_type
682
+ ],
683
+ outputs=[output_table]
684
+ )
685
+
686
+ # Download handlers
687
+ download_all_btn.click(
688
+ lambda df: save_to_csv(df, False),
689
+ inputs=[output_table],
690
+ outputs=[gr.File(label="Download CSV")]
691
+ )
692
+
693
+ download_selected_btn.click(
694
+ lambda df, cols: save_to_csv(df[cols], True),
695
+ inputs=[output_table, selected_columns],
696
+ outputs=[gr.File(label="Download CSV")]
697
+ )
698
+
699
+ return demo
700
+
701
+ # Mount the Gradio app
702
+ combined_demo = create_combined_interface()
703
+ app = gr.mount_gradio_app(app, combined_demo, path="/")
704
+
705
+ if __name__ == "__main__":
706
+ uvicorn.run(app, host="127.0.0.1", port=7890)