Joey / src /ui /handlers.py
Joey Callanan
changes
78b1639
"""
UI Handlers Module
This module contains event handlers and business logic
for the drug discovery application UI components.
"""
from ..molecules.analysis import (
analyze_molecule_image_only,
validate_smiles_realtime,
get_molecule_properties_for_hover,
)
from ..molecules.generated_variations import generate_variations_from_partial_smiles
from ..ai.services import respond, handle_structure_chat, parse_ai_structures
class VariationHandlers:
    """Handles variation-related functionality.

    Instance state shared between UI events:

    * ``current_variations`` - list of dicts ``{smiles, image, style}`` from the
      most recent generation.
    * ``current_page`` - zero-based page index used by ``navigate_variations``.
    * ``variations_per_page`` - page size, kept as a positive ``int``.
    * ``_display_offset`` - index in ``current_variations`` of the first item the
      gallery currently shows. Gallery selection indices are relative to what is
      displayed, so ``select_variation`` adds this offset to find the molecule.
    """

    def __init__(self):
        self.current_variations = []  # list of dicts: {smiles, image, style}
        self.current_page = 0
        self.variations_per_page = 12
        self._display_offset = 0

    def _reset_variations(self, variations=None):
        """Store ``variations`` (or nothing) and rewind paging to the first item."""
        self.current_variations = list(variations) if variations else []
        self.current_page = 0
        self._display_offset = 0

    # -------------------------------------------------------------
    # Generate variations for the UI grid
    # -------------------------------------------------------------
    def generate_variations_for_display(self, input_smiles, num_variations=12):
        """
        Generate variations of ``input_smiles`` and format them for the gallery.

        Variations come from ``generate_variations_from_partial_smiles``; each is
        expected to be a dict carrying at least ``smiles`` and ``image`` keys.
        Debugging info is printed along the way.

        Args:
            input_smiles: SMILES string entered by the user.
            num_variations: Number of variations to request from the generator.

        Returns:
            ``(gallery_items, input_smiles, status)``. ``gallery_items`` is a
            list of ``(image, smiles)`` tuples covering every generated
            variation. ``status`` is ``"Invalid SMILES"`` when the input cannot
            be parsed, else ``""``. Whenever an empty gallery is returned, the
            stored variations are cleared too, so selections cannot resolve to
            molecules that are no longer shown.
        """
        print("\n==============================")
        print("🚀 generate_variations_for_display CALLED")
        print("==============================")
        print(f"User input SMILES: {input_smiles}")
        print(f"Requested # variations: {num_variations}")

        # ------------------------------------------------------------
        # Validate the user input molecule
        # ------------------------------------------------------------
        # RDKit parses "" into an empty (zero-atom) molecule instead of None, and
        # raises on None, so blank/non-string input is rejected up front.
        mol_input = None
        if isinstance(input_smiles, str) and input_smiles.strip():
            from rdkit import Chem

            mol_input = Chem.MolFromSmiles(input_smiles)
        if mol_input is None:
            print("❌ ERROR: Invalid SMILES input. Cannot generate big preview.")
            self._reset_variations()
            return [], input_smiles, "Invalid SMILES"
        print("✔ Input molecule parsed successfully.")

        # ------------------------------------------------------------
        # Run the generation helper
        # ------------------------------------------------------------
        variations = generate_variations_from_partial_smiles(
            input_smiles,
            n_to_gen=num_variations,
        )
        # Tolerate a None return from the generator.
        variations = list(variations or [])
        print(f"Generator returned {len(variations)} variations.")

        # A fresh batch always starts on the first page with no display offset.
        self._reset_variations(variations)

        # No variations returned ≠ UI failure — show empty grid
        if not variations:
            print(" No variations generated — returning empty gallery.")
            return [], input_smiles, ""

        # ------------------------------------------------------------
        # Build gallery items for Gradio
        # ------------------------------------------------------------
        gallery_items = []
        print("\n Building gallery_items...\n")
        for i, v in enumerate(variations, start=1):
            img = v["image"]
            smi = v["smiles"]
            print(f" #{i}:")
            print(f" SMILES: {smi}")
            print(f" Image type: {type(img)}")
            try:
                print(f"Image size: {img.size}")
            except AttributeError:
                # e.g. img is None or not a PIL-like image
                print("Could not read image size")
            gallery_items.append((img, smi))

        print("\n==============================")
        print("Returning data to UI")
        print(f"Total gallery items: {len(gallery_items)}")
        print("==============================\n")
        return gallery_items, input_smiles, ""

    # -------------------------------------------------------------
    # Handle selection from the variations grid
    # -------------------------------------------------------------
    def select_variation(self, evt):
        """Handle selection of a variation from the grid.

        Args:
            evt: Gradio selection event (anything with an ``index`` attribute),
                a plain ``int``/``float`` index, or ``None``. The index is
                relative to the items currently displayed in the gallery.

        Returns:
            ``(image, smiles, style, properties_text)`` for
            main_structure_display, selected_smiles_display,
            selected_style_display and properties_display. Returns
            ``(None, "", "", "")`` when there is nothing to select or an error
            occurs.
        """
        try:
            print("=== SELECT_VARIATION CALLED ===")
            print(f"Event: {evt}, type: {type(evt)}")
            print(f"Current variations count: {len(self.current_variations)}")
            if not self.current_variations:
                return None, "", "", ""

            # Gradio events might carry an index, or evt might be None/int
            if evt is None:
                index = 0
            elif hasattr(evt, "index"):
                index = evt.index
            elif isinstance(evt, (int, float)):
                index = int(evt)
            else:
                index = 0

            # The gallery may be showing a later page (see navigate_variations);
            # map the display-relative index onto the full variation list.
            index = self._display_offset + int(index)

            # Clamp index
            if index < 0 or index >= len(self.current_variations):
                index = 0

            selected_var = self.current_variations[index]
            smiles = selected_var["smiles"]
            print(f"Selected SMILES: {smiles}")
            properties_text = get_molecule_properties_for_hover(smiles)
            return (
                selected_var["image"],
                smiles,
                selected_var.get("style", ""),
                properties_text,
            )
        except Exception as e:
            # UI boundary: report and fall back to an empty selection.
            print(f"Error in select_variation: {e}")
            import traceback

            traceback.print_exc()
            return None, "", "", ""

    # -------------------------------------------------------------
    # Clear variations
    # -------------------------------------------------------------
    def clear_variations(self):
        """Clear all variations and reset display.

        Returns:
            ``([], "", "")`` - an empty gallery plus two blank text outputs.
        """
        self._reset_variations()
        return [], "", ""

    # -------------------------------------------------------------
    # Pagination (prev/next buttons)
    # -------------------------------------------------------------
    def navigate_variations(self, direction):
        """Navigate through variations pages.

        Args:
            direction: ``"next"`` or ``"prev"``; any other value re-renders the
                current page.

        Returns:
            ``(gallery_items, page_info, first_image, first_smiles, first_style)``
            for the resulting page, where ``first_*`` describe the page's first
            item.
        """
        if not self.current_variations:
            return [], "Page 1 of 1", None, "", ""

        per_page = max(1, int(self.variations_per_page))
        total = len(self.current_variations)
        pages = (total + per_page - 1) // per_page

        if direction == "next":
            page = self.current_page + 1
        elif direction == "prev":
            page = self.current_page - 1
        else:
            page = self.current_page
        # Clamp to the valid range: current_page may be left over from a larger
        # batch or a different page size.
        self.current_page = max(0, min(page, pages - 1))

        start = self.current_page * per_page
        page_variations = self.current_variations[start:start + per_page]
        # Remember where the displayed page begins so selections map correctly.
        self._display_offset = start

        gallery_items = [(v["image"], v["smiles"]) for v in page_variations]
        page_info = f"Page {self.current_page + 1} of {pages}"
        first = page_variations[0] if page_variations else None
        return (
            gallery_items,
            page_info,
            first["image"] if first else None,
            first["smiles"] if first else "",
            first.get("style", "") if first else "",
        )

    # -------------------------------------------------------------
    # Update how many variations per page
    # -------------------------------------------------------------
    def update_variation_count(self, count):
        """Update the number of variations to generate and show per page.

        UI inputs such as sliders may deliver floats, so ``count`` is coerced to
        an ``int`` and clamped to at least 1, keeping the paging arithmetic
        valid. Unparseable values leave the current setting unchanged.

        Returns:
            The page size now in effect.
        """
        try:
            self.variations_per_page = max(1, int(count))
        except (TypeError, ValueError):
            pass
        return self.variations_per_page

    # -------------------------------------------------------------
    # Analyze molecule + tooltip text (used on load / selection)
    # -------------------------------------------------------------
    def analyze_molecule_with_tooltip(self, smiles):
        """Analyze molecule and return RDKit image and tooltip text."""
        molecule_img = analyze_molecule_image_only(smiles)
        tooltip_text = get_molecule_properties_for_hover(smiles)
        return molecule_img, tooltip_text
class BookmarkHandlers:
    """Handles bookmark functionality.

    Bookmarks are held in memory in ``bookmarked_molecules`` as dicts with
    ``smiles``, ``name`` and ``timestamp`` keys. ``timestamp`` is a 1-based
    counter (list length + 1 at insertion time), not a wall-clock time.
    """

    def __init__(self):
        self.bookmarked_molecules = []

    def _is_bookmarked(self, smiles):
        """Return True if ``smiles`` is already in the bookmark list."""
        return any(bm['smiles'] == smiles for bm in self.bookmarked_molecules)

    def bookmark_molecule(self, smiles, molecule_name=""):
        """Add a molecule to the bookmarked collection.

        Args:
            smiles: SMILES string to bookmark. Blank, non-string or unparseable
                input is rejected.
            molecule_name: Display name; defaults to ``Bookmarked_<n>``.

        Returns:
            A user-facing status message.
        """
        # RDKit parses "" into an empty, zero-atom molecule rather than None, and
        # errors on None, so blank/non-string input is rejected before parsing.
        if not isinstance(smiles, str) or not smiles.strip():
            return "❌ Invalid SMILES string - cannot bookmark"
        # This method only stores SMILES that passed validation, so the cheap
        # duplicate check can run before the RDKit parse.
        if self._is_bookmarked(smiles):
            return "⚠️ Molecule already bookmarked"

        from rdkit import Chem

        if Chem.MolFromSmiles(smiles) is None:
            return "❌ Invalid SMILES string - cannot bookmark"

        if not molecule_name:
            molecule_name = f"Bookmarked_{len(self.bookmarked_molecules) + 1}"
        self.bookmarked_molecules.append({
            'smiles': smiles,
            'name': molecule_name,
            'timestamp': len(self.bookmarked_molecules) + 1,
        })
        return f"✅ Bookmarked: {molecule_name}"

    def get_bookmarked_molecules(self):
        """Return the live list of bookmark dicts (not a copy)."""
        return self.bookmarked_molecules

    def remove_bookmark(self, smiles):
        """Remove every bookmark whose SMILES equals ``smiles``.

        Returns the same status message whether or not anything was removed.
        """
        self.bookmarked_molecules = [
            bm for bm in self.bookmarked_molecules if bm['smiles'] != smiles
        ]
        return "🗑️ Removed from bookmarks"

    def _build_gallery(self):
        """Render every bookmark as ``(image, "name: smiles")`` gallery items."""
        from rdkit import Chem
        from rdkit.Chem import Draw

        gallery_items = []
        for bm in self.get_bookmarked_molecules():
            mol_obj = Chem.MolFromSmiles(bm['smiles'])
            if mol_obj is None:
                continue  # skip entries that no longer parse
            img = Draw.MolToImage(mol_obj, size=(150, 150), kekulize=True)
            gallery_items.append((img, f"{bm['name']}: {bm['smiles']}"))
        return gallery_items

    def bookmark_current_molecule(self, smiles, name):
        """Bookmark current molecule and update gallery.

        Returns:
            ``(status_message, gallery_items)`` where ``gallery_items`` covers
            all bookmarks, including the one just added.
        """
        result = self.bookmark_molecule(smiles, name)
        return result, self._build_gallery()
class AIHandler:
    """Handles AI chat functionality with both general questions and structure generation."""

    # Lower-case substrings that mark a message as a structure-generation request.
    STRUCTURE_KEYWORDS = ('generate', 'create', 'modify', 'derivative', 'variant', 'structure')
    STRUCTURE_SYSTEM_PROMPT = (
        "You are an expert medicinal chemist. Generate new chemical structures "
        "based on user requests."
    )
    GENERAL_SYSTEM_PROMPT = (
        "You are an expert medicinal chemist and drug discovery specialist. "
        "Help with molecular analysis, drug design, and medicinal chemistry questions."
    )
    # Generation settings passed positionally to ``respond``.
    MAX_TOKENS = 512
    TOP_P = 0.9

    def __init__(self):
        self.chat_history = []

    def _complete(self, message, prior_history, system_prompt, temperature, hf_token):
        """Consume ``respond`` and return the last chunk it yielded ("" if none).

        NOTE(review): only the final chunk is kept, which presumes ``respond``
        yields the cumulative response so far - confirm against ai.services.
        """
        ai_response = ""
        for chunk in respond(
            message,
            prior_history,
            system_prompt,
            self.MAX_TOKENS,
            temperature,
            self.TOP_P,
            hf_token,
        ):
            ai_response = chunk
        return ai_response

    def handle_ai_chat(self, message, history, selected_smiles, hf_token, temperature):
        """Handle AI chat with both general questions and structure generation.

        Args:
            message: The user's chat message.
            history: List of ``{"role", "content"}`` dicts. It is appended to in
                place; ``None`` is treated as an empty history.
            selected_smiles: SMILES of the currently selected molecule. Needed
                for a message to be handled as a structure request.
            hf_token: Token forwarded to ``respond``; blank/``None`` aborts.
            temperature: Sampling temperature forwarded to ``respond``.

        Returns:
            ``(history, structures)``. ``structures`` is the result of
            ``parse_ai_structures`` for structure requests, else ``[]``.
        """
        if history is None:
            history = []
        # Guard against None as well as blank strings (e.g. no token supplied).
        if not (message or "").strip() or not (hf_token or "").strip():
            return history, []

        # Turns before this message; equivalent to history[:-1] after the append.
        prior_history = list(history)
        history.append({"role": "user", "content": message})

        lowered = message.lower()
        is_structure_request = bool(selected_smiles) and any(
            keyword in lowered for keyword in self.STRUCTURE_KEYWORDS
        )
        system_prompt = (
            self.STRUCTURE_SYSTEM_PROMPT if is_structure_request else self.GENERAL_SYSTEM_PROMPT
        )

        ai_response = self._complete(message, prior_history, system_prompt, temperature, hf_token)
        history.append({"role": "assistant", "content": ai_response})

        if is_structure_request:
            return history, parse_ai_structures(ai_response, selected_smiles)
        return history, []