# Custom SVG depiction helpers built on top of CGRtools' Depict* mixins.
#
# NOTE(review): in this copy of the file the SVG templates inside the
# f-string literals are empty (e.g. ``f' '``), so the renderers below emit
# only whitespace placeholders. The element markup appears to have been lost
# and should be restored from version control; the surrounding control flow
# is preserved so each template can be dropped back in place.
from bisect import bisect_left
from collections import defaultdict
from functools import partial
from math import hypot
from uuid import uuid4

from CGRtools.algorithms.depict import (
    Depict,
    DepictMolecule,
    DepictCGR,
    rotate_vector,
    _render_charge,
)
from CGRtools.containers import ReactionContainer, MoleculeContainer, CGRContainer


class WideBondDepictCGR(DepictCGR):
    """
    Like DepictCGR, but meant to draw DynamicBonds 2.5× wider than the
    standard bond width.
    """

    __slots__ = ()

    def _render_bonds(self):
        """
        Render the bonds of the CGR as a list of SVG fragments.

        Overrides ``DepictCGR._render_bonds``. Each bond is dispatched on its
        reactant order (``bond.order``) and product order (``bond.p_order``).
        Single, double and triple bonds get one, two or three line fragments,
        offset by ``double_space`` / ``triple_space``. Bonds that are aromatic
        on either side are recorded in ``ar_bond_colors`` (``formed``,
        ``broken`` or ``None``). Those entries are later passed to
        ``__render_aromatic_bond`` when the aromatic rings are drawn.
        ``wide_width`` (2.5 × ``bond_width``) is the stroke intended for
        dynamic bonds.

        Returns:
            list: SVG fragment strings, bond lines first, then aromatic ring
            lines.
        """
        plane = self._plane
        config = self._render_config
        # Standard stroke width (0.02 when the config has no "bond_width")
        # and the 2.5x wider stroke used for dynamic bonds.
        normal_width = config.get("bond_width", 0.02)
        wide_width = normal_width * 2.5
        broken = config["broken_color"]
        formed = config["formed_color"]
        dash1, dash2 = config["dashes"]
        double_space = config["double_space"]
        triple_space = config["triple_space"]

        svg = []
        # Colour per aromatic bond, stored symmetrically (n->m and m->n) and
        # read back by the aromatic ring pass below; missing -> None.
        ar_bond_colors = defaultdict(dict)
        for n, m, bond in self.bonds():
            order, p_order = bond.order, bond.p_order
            nx, ny = plane[n]
            mx, my = plane[m]
            # SVG's Y axis points down, so flip Y.
            ny, my = -ny, -my
            # rv(d) -> (dx, dy): offset used to place the parallel lines of
            # multi-line bonds (presumably perpendicular to n->m).
            rv = partial(rotate_vector, 0, x2=mx - nx, y2=ny - my)
            if order == 1:
                if p_order == 1:
                    svg.append(f' ')
                elif p_order == 4:
                    ar_bond_colors[n][m] = ar_bond_colors[m][n] = formed
                    svg.append(f' ')
                elif p_order == 2:
                    dx, dy = rv(double_space)
                    svg.append(f' ')
                    svg.append(f' ')
                elif p_order == 3:
                    dx, dy = rv(triple_space)
                    svg.append(f' ')
                    svg.append(f' ')
                    svg.append(f' ')
                elif p_order is None:
                    svg.append(f' ')
                else:
                    dx, dy = rv(double_space)
                    svg.append(f' ')
                    svg.append(f' ')
            elif order == 4:
                if p_order == 4:
                    svg.append(f' ')
                elif p_order == 1:
                    ar_bond_colors[n][m] = ar_bond_colors[m][n] = broken
                    svg.append(f' ')
                elif p_order == 2:
                    ar_bond_colors[n][m] = ar_bond_colors[m][n] = broken
                    dx, dy = rv(double_space)
                    svg.append(f' ')
                    svg.append(f' ')
                elif p_order == 3:
                    ar_bond_colors[n][m] = ar_bond_colors[m][n] = broken
                    dx, dy = rv(triple_space)
                    svg.append(f' ')
                    svg.append(f' ')
                    svg.append(f' ')
                elif p_order is None:
                    ar_bond_colors[n][m] = ar_bond_colors[m][n] = broken
                    svg.append(f' ')
                else:
                    ar_bond_colors[n][m] = ar_bond_colors[m][n] = None
                    svg.append(f' ')
            elif order == 2:
                if p_order == 2:
                    dx, dy = rv(double_space)
                    svg.append(f' ')
                    svg.append(f' ')
                elif p_order == 1:
                    dx, dy = rv(double_space)
                    svg.append(f' ')
                    svg.append(f' ')
                elif p_order == 4:
                    ar_bond_colors[n][m] = ar_bond_colors[m][n] = formed
                    dx, dy = rv(double_space)
                    svg.append(f' ')
                    svg.append(f' ')
                elif p_order == 3:
                    dx, dy = rv(triple_space)
                    svg.append(f' ')
                    svg.append(f' ')
                    svg.append(f' ')
                elif p_order is None:
                    dx, dy = rv(double_space)
                    svg.append(f' ')
                    svg.append(f' ')
                else:
                    dx, dy = rv(triple_space)
                    svg.append(f' ')
                    svg.append(f' ')
                    svg.append(f' ')
            elif order == 3:
                if p_order == 3:
                    dx, dy = rv(triple_space)
                    svg.append(f' ')
                    svg.append(f' ')
                    svg.append(f' ')
                elif p_order == 1:
                    dx, dy = rv(triple_space)
                    svg.append(f' ')
                    svg.append(f' ')
                    svg.append(f' ')
                elif p_order == 4:
                    ar_bond_colors[n][m] = ar_bond_colors[m][n] = formed
                    dx, dy = rv(triple_space)
                    svg.append(f' ')
                    svg.append(f' ')
                    svg.append(f' ')
                elif p_order == 2:
                    dx, dy = rv(triple_space)
                    svg.append(f' ')
                    svg.append(f' ')
                    svg.append(f' ')
                elif p_order is None:
                    dx, dy = rv(triple_space)
                    svg.append(f' ')
                    svg.append(f' ')
                    svg.append(f' ')
                else:
                    # four lines: offsets at 1x and 3x the double spacing
                    dx, dy = rv(double_space)
                    dx3 = 3 * dx
                    dy3 = 3 * dy
                    svg.append(f' ')
                    svg.append(f' ')
                    svg.append(f' ')
                    svg.append(f' ')
            elif order is None:
                if p_order == 1:
                    svg.append(f' ')
                elif p_order == 4:
                    ar_bond_colors[n][m] = ar_bond_colors[m][n] = formed
                    svg.append(f' ')
                elif p_order == 2:
                    dx, dy = rv(double_space)
                    svg.append(f' ')
                    svg.append(f' ')
                elif p_order == 3:
                    dx, dy = rv(triple_space)
                    svg.append(f' ')
                    svg.append(f' ')
                    svg.append(f' ')
                else:
                    svg.append(f' ')
            else:
                # any other reactant order (e.g. 8) -- TODO confirm which
                # values reach this branch
                if p_order == 8:
                    svg.append(f' ')
                elif p_order == 1:
                    dx, dy = rv(double_space)
                    svg.append(f' ')
                    svg.append(f' ')
                elif p_order == 4:
                    ar_bond_colors[n][m] = ar_bond_colors[m][n] = None
                    svg.append(f' ')
                elif p_order == 2:
                    dx, dy = rv(triple_space)
                    svg.append(f' ')
                    svg.append(f' ')
                    svg.append(f' ')
                elif p_order == 3:
                    dx, dy = rv(double_space)
                    dx3 = 3 * dx
                    dy3 = 3 * dy
                    svg.append(f' ')
                    svg.append(f' ')
                    svg.append(f' ')
                    svg.append(f' ')
                else:
                    svg.append(f' ')

        # aromatic rings - unchanged from DepictCGR
        for ring in self.aromatic_rings:
            # ring centroid, used to push the dashed line towards the inside
            cx = sum(plane[x][0] for x in ring) / len(ring)
            cy = sum(plane[x][1] for x in ring) / len(ring)
            for n, m in zip(ring, ring[1:]):
                nx, ny = plane[n]
                mx, my = plane[m]
                aromatic = self.__render_aromatic_bond(
                    nx, ny, mx, my, cx, cy, ar_bond_colors[n].get(m)
                )
                if aromatic:
                    svg.append(aromatic)
            # closing bond between the last and first ring atoms
            n, m = ring[-1], ring[0]
            nx, ny = plane[n]
            mx, my = plane[m]
            aromatic = self.__render_aromatic_bond(
                nx, ny, mx, my, cx, cy, ar_bond_colors[n].get(m)
            )
            if aromatic:
                svg.append(aromatic)
        return svg

    def __render_aromatic_bond(self, n_x, n_y, m_x, m_y, c_x, c_y, color):
        """
        Render the dashed inner line of one aromatic ring bond.

        Args:
            n_x, n_y: coordinates of the first atom.
            m_x, m_y: coordinates of the second atom.
            c_x, c_y: ring centroid.
            color: truthy colour string for a changed bond, ``None`` for a
                bond without a recorded colour; any other falsy value falls
                back to ``aromatic_dashes``.

        Returns:
            str or None: the SVG fragment, or ``None`` when the centroid is
            too close to (or on) the bond axis.
        """
        config = self._render_config
        dash1, dash2 = config["dashes"]
        dash3, dash4 = config["aromatic_dashes"]
        aromatic_space = config["cgr_aromatic_space"]
        normal_width = config.get("bond_width", 0.02)
        # NOTE(review): 2x here vs 2.5x in _render_bonds -- confirm intent.
        wide_width = normal_width * 2

        # coordinates relative to atom n
        mn_x, mn_y, cn_x, cn_y = m_x - n_x, m_y - n_y, c_x - n_x, c_y - n_y
        # rotate so that n->m lies on the X axis
        mr_x, mr_y = hypot(mn_x, mn_y), 0
        cr_x, cr_y = rotate_vector(cn_x, cn_y, mn_x, -mn_y)

        if cr_y and aromatic_space / cr_y < 0.65:
            if cr_y > 0:
                r_y = aromatic_space
            else:
                r_y = -aromatic_space
                cr_y = -cr_y

            ar_x = aromatic_space * cr_x / cr_y
            br_x = mr_x - aromatic_space * (mr_x - cr_x) / cr_y

            # rotate back into the drawing frame
            an_x, an_y = rotate_vector(ar_x, r_y, mn_x, mn_y)
            bn_x, bn_y = rotate_vector(br_x, r_y, mn_x, mn_y)

            if color:
                return (f' ')
            elif color is None:
                # uncoloured bonds use the regular dash pattern
                dash3, dash4 = dash1, dash2
            return (f' ')


def cgr_display(cgr: CGRContainer) -> str:
    """
    Generate an SVG string for a CGR, using WideBondDepictCGR's bond renderer.

    The bond-rendering methods of ``CGRContainer`` are swapped for those of
    ``WideBondDepictCGR`` only for the duration of the ``depict()`` call. The
    class attributes are restored (or removed, if they were not defined
    directly on ``CGRContainer``) afterwards, even if depiction raises, so
    other CGR depictions in the process are not affected.

    The 2D coordinates of ``cgr`` are recomputed with ``clean2d()`` first
    (this mutates ``cgr``).

    Args:
        cgr (CGRContainer): The CGRContainer object to be depicted.

    Returns:
        str: The SVG produced by ``cgr.depict()``.
    """
    patches = {
        "_render_bonds": WideBondDepictCGR._render_bonds,
        # WideBondDepictCGR._render_bonds calls self.__render_aromatic_bond,
        # which is name-mangled to this attribute.
        "_WideBondDepictCGR__render_aromatic_bond": (
            WideBondDepictCGR._WideBondDepictCGR__render_aromatic_bond
        ),
        "_CGRContainer__render_aromatic_bond": (
            WideBondDepictCGR._WideBondDepictCGR__render_aromatic_bond
        ),
    }
    missing = object()
    # Snapshot raw class-dict entries so descriptors are restored verbatim
    # and inherited attributes are not copied onto CGRContainer.
    saved = {name: CGRContainer.__dict__.get(name, missing) for name in patches}

    cgr.clean2d()
    for name, func in patches.items():
        setattr(CGRContainer, name, func)
    try:
        return cgr.depict()
    finally:
        for name, original in saved.items():
            if original is missing:
                delattr(CGRContainer, name)
            else:
                setattr(CGRContainer, name, original)


class CustomDepictMolecule(DepictMolecule):
    """ Custom molecule depiction class that uses atom.symbol for rendering. """

    def _render_atoms(self):
        """
        Render atom labels (symbol, hydrogens, charge, radical, isotope and
        optional mapping numbers) as SVG fragments.

        Differs from DepictMolecule._render_atoms by taking the label from
        ``atom.symbol`` when available (falling back to ``atom.atomic_symbol``).

        Returns:
            tuple: ``(svg, mask)`` -- a list of SVG fragments and a
            ``defaultdict(list)`` of mask fragments keyed by "other",
            "center", "symbols", "span" and "aam".
        """
        bonds = self._bonds
        plane = self._plane
        charges = self._charges
        radicals = self._radicals
        # Containers without implicit-hydrogen data render no H labels.
        hydrogens = getattr(self, "_hydrogens", {})
        config = self._render_config

        carbon = config["carbon"]
        mapping = config["mapping"]
        span_size = config["span_size"]
        font_size = config["font_size"]
        monochrome = config["monochrome"]
        other_size = config["other_size"]
        atoms_colors = config["atoms_colors"]
        mapping_font = config["mapping_size"]
        dx_m, dy_m = config["dx_m"], config["dy_m"]
        dx_ci, dy_ci = config["dx_ci"], config["dy_ci"]
        symbols_font_style = config["symbols_font_style"]

        # Inner atoms of cumulene chains always get an explicit label.
        try:
            cumulenes = {
                y
                for x in self._cumulenes(heteroatoms=True)
                if len(x) > 2
                for y in x[1:-1]
            }
        except AttributeError:
            # container without a _cumulenes method
            cumulenes = set()

        if monochrome:
            map_fill = other_fill = "black"
        else:
            map_fill = config["mapping_color"]
            other_fill = config["other_color"]

        svg = []
        maps = []
        others = []
        font2 = 0.2 * font_size
        font3 = 0.3 * font_size
        font4 = 0.4 * font_size
        font5 = 0.5 * font_size
        font6 = 0.6 * font_size
        font7 = 0.7 * font_size
        font15 = 0.15 * font_size
        font25 = 0.25 * font_size
        mask = defaultdict(list)

        for n, atom in self._atoms.items():
            x, y = plane[n]
            y = -y

            # Prefer atom.symbol (may be customised); fall back to the
            # element symbol.
            try:
                symbol = atom.symbol
            except AttributeError:
                symbol = atom.atomic_symbol

            # Label the atom unless it is a plain, bonded, uncharged,
            # non-radical, non-isotopic carbon (and carbons are hidden).
            if (
                not bonds.get(n)
                or symbol != "C"
                or carbon
                or atom.charge
                or atom.is_radical
                or atom.isotope
                or n in cumulenes
            ):
                try:
                    h = hydrogens[n]
                except (KeyError, AttributeError):
                    h = 0  # no hydrogen count recorded for this atom

                if h == 1:
                    h_str = "H"
                    span = ""
                elif h and h > 1:
                    span = f'{h}'
                    h_str = "H"
                else:
                    h_str = ""
                    span = ""

                charge_val = charges.get(n, 0)
                is_radical = radicals.get(n, False)
                if charge_val:
                    t = f'{_render_charge.get(charge_val, "")}{"↑" if is_radical else ""}'
                    if t:  # unknown charge values render nothing
                        others.append(f' ' f"{t}")
                        mask["other"].append(f' ' f"{t}")
                elif is_radical:
                    others.append(f' ')
                    mask["other"].append(f' ')

                try:
                    iso = atom.isotope
                    if iso:
                        t = iso
                        others.append(f' {t}')
                        mask["other"].append(f' {t}')
                except AttributeError:
                    pass  # atom type without an isotope attribute

                # atoms_colors is indexed by atomic number - 1; index 5 is
                # carbon and serves as the fallback.
                atom_color = "black"
                if not monochrome:
                    try:
                        an = atom.atomic_number
                        if 0 < an <= len(atoms_colors):
                            atom_color = atoms_colors[an - 1]
                        else:
                            atom_color = atoms_colors[5]
                    except AttributeError:
                        atom_color = atoms_colors[5]

                svg.append(f' ')
                # Horizontal offsets depend on label width.
                if len(symbol) > 1:
                    dx = font7
                    dx_mm = dx_m + font5
                    if symbol[-1].lower() in ("l", "i", "r", "t"):
                        # narrow trailing letter
                        rx = font6
                        ax = font25
                    else:
                        rx = font7
                        ax = font15
                    mask["center"].append(f' ')
                else:
                    if symbol == "I":
                        dx = font15
                        dx_mm = dx_m
                    else:
                        dx = font4
                        dx_mm = dx_m + font2
                    mask["center"].append(f' ')

                svg.append(f' {symbol}{h_str}{span}')
                mask["symbols"].append(f' {symbol}{h_str}')
                if span:
                    mask["span"].append(f' ' f"{symbol}{h_str}{span}")
                svg.append(" ")

                if mapping:
                    maps.append(f' {n}')
                    mask["aam"].append(f' {n}')
            elif mapping:
                # unlabelled atom: still place its mapping number
                if len(symbol) > 1:
                    dx_mm = dx_m + font5
                else:
                    dx_mm = dx_m + font2 if symbol != "I" else dx_m
                maps.append(f' {n}')
                mask["aam"].append(f' {n}')

        if others:
            svg.append(f' ')
            svg.extend(others)
            svg.append(" ")
        if mapping:
            svg.append(f' ')
            svg.extend(maps)
            svg.append(" ")
        return svg, mask


def depict_custom_reaction(reaction: ReactionContainer):
    """
    Depict a ReactionContainer using CustomDepictMolecule's atom rendering.

    Each MoleculeContainer/CGRContainer in the reaction is temporarily given
    a generated subclass whose MRO puts CustomDepictMolecule first. Its atoms
    are therefore rendered with ``atom.symbol`` (e.g. to show 'X' instead of
    'At'). The original class of every swapped molecule is restored in a
    ``finally`` block, even when depiction fails.

    Molecule depictions, the reaction arrow and the '+' signs are combined
    into a single SVG document.

    Args:
        reaction (ReactionContainer): The reaction to depict. Its layout is
            computed with ``fix_positions()`` if no arrow position exists yet.

    Returns:
        str: The SVG document as a string.
    """
    if not reaction._arrow:
        reaction.fix_positions()  # ensure layout is calculated

    r_atoms = []
    r_bonds = []
    r_masks = []
    r_max_x = r_max_y = r_min_y = 0

    # (molecule, original class) pairs. A list rather than a dict keyed by
    # the molecule: keys could collide if containers hash by structure, and
    # the same object may be visited more than once.
    swapped = []
    try:
        for mol in reaction.molecules():
            original_class = mol.__class__
            if (
                isinstance(mol, (MoleculeContainer, CGRContainer))
                and CustomDepictMolecule not in original_class.__mro__
            ):
                custom_class_name = (
                    f"TempCustom_{original_class.__name__}_{uuid4().hex}"
                )
                # CustomDepictMolecule first so its _render_atoms wins.
                mol.__class__ = type(
                    custom_class_name, (CustomDepictMolecule, original_class), {}
                )
                swapped.append((mol, original_class))

            atoms, bonds, masks, min_x, min_y, max_x, max_y = mol.depict(
                embedding=True
            )
            r_atoms.append(atoms)
            r_bonds.append(bonds)
            r_masks.append(masks)
            if max_x > r_max_x:
                r_max_x = max_x
            if max_y > r_max_y:
                r_max_y = max_y
            if min_y < r_min_y:
                r_min_y = min_y
    finally:
        for mol, original_class in reversed(swapped):
            mol.__class__ = original_class

    config = DepictMolecule._render_config
    font_size = config["font_size"]
    font125 = 1.25 * font_size
    width = r_max_x + 3.0 * font_size
    height = r_max_y - r_min_y + 2.5 * font_size
    viewbox_x = -font125
    viewbox_y = -r_max_y - font125

    svg = [
        f'\n'
        ' \n \n \n \n \n'
        f' '
    ]

    signs_plus = reaction._signs
    if signs_plus:
        svg.append(f' ')
        for x in signs_plus:
            svg.append(f' ')
            svg.append(f' ')
        svg.append(" ")

    for atoms, bonds, masks in zip(r_atoms, r_bonds, r_masks):
        svg.extend(
            Depict._graph_svg(atoms, bonds, masks, viewbox_x, viewbox_y, width, height)
        )
    svg.append("")
    return "\n".join(svg)


def remove_and_shift(nested_dict, to_remove):  # Under development
    """
    Drop the given inner keys from a dict of dicts and close the gaps.

    For every inner dictionary, entries whose key is in ``to_remove`` are
    discarded. Each remaining key ``k`` is shifted down by the number of
    removed keys smaller than ``k``. For inner dicts keyed ``0..n-1`` this
    yields a contiguous ``0..n-len(removed)-1`` numbering. Inputs are not
    modified.

    Args:
        nested_dict (dict): dict whose values are dicts with integer-like,
            mutually comparable keys.
        to_remove (iterable): inner keys to remove.

    Returns:
        dict: a new dict with the same outer keys and shifted inner dicts.

    >>> remove_and_shift({"a": {0: "x", 1: "y", 2: "z"}}, [1])
    {'a': {0: 'x', 1: 'z'}}
    """
    rem_set = set(to_remove)
    # bisect_left(removed, k) == number of removed keys strictly below k
    removed = sorted(rem_set)
    return {
        outer_k: {
            old_k - bisect_left(removed, old_k): v
            for old_k, v in inner.items()
            if old_k not in rem_set
        }
        for outer_k, inner in nested_dict.items()
    }