from pathlib import Path

import numpy as np
import torch
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Chem.Draw import rdMolDraw2D


def generate_final_viz(output_path="final_hiv_viz.png", *, threshold=0.3):
    """Draw the hard-coded molecule with attention-weighted atom highlights.

    The molecule is parsed from a fixed SMILES string, expanded with explicit
    hydrogens, and laid out in 2D. Per-atom scores from a fixed table are
    folded from hydrogens onto the atom they are bonded to, rescaled so the
    largest aggregated score is 1.0, and every atom whose rescaled score
    exceeds ``threshold`` is highlighted in red on a 1000x600 Cairo drawing.

    Args:
        output_path: Where the image is written. Missing parent directories
            are created. Defaults to ``final_hiv_viz.png`` in the current
            working directory.
        threshold: Minimum normalized score (exclusive) an atom needs in
            order to be highlighted.

    Raises:
        ValueError: If RDKit cannot parse the SMILES string.
    """
    smiles = "Cc1cc(C)cc(c1)CN2C(=O)N(Cc3cc(C)cc(C)c3)[C@H](Cc4cc(C)cc(C)c4)[C@H](O)[C@@H]2O"
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        # MolFromSmiles reports a parse failure by returning None; fail here
        # with a clear message instead of deep inside AddHs.
        raise ValueError(f"RDKit could not parse SMILES: {smiles!r}")
    mol = Chem.AddHs(mol)

    # Coordinates are computed on the hydrogen-expanded molecule, which is
    # also the molecule that gets drawn below.
    AllChem.Compute2DCoords(mol)

    # Scores keyed by atom index in the hydrogen-expanded molecule.
    # NOTE(review): presumably exported from a model's attention output;
    # the values are hard-coded here, so confirm their provenance.
    raw_attention = {
        73: 1.000, 71: 0.968,
        33: 0.699, 35: 0.670,
        39: 0.484, 43: 0.484, 44: 0.484, 49: 0.484, 53: 0.484, 57: 0.484,
        0: 0.479, 4: 0.479,
    }

    num_atoms = mol.GetNumAtoms()
    aggregated_weights = np.zeros(num_atoms)

    for idx, score in raw_attention.items():
        atom = mol.GetAtomWithIdx(idx)
        target_idx = idx
        if atom.GetSymbol() == "H":
            neighbors = atom.GetNeighbors()
            if not neighbors:
                # A hydrogen with no bonded atom has nowhere to put its
                # score, so the score is dropped.
                continue
            # Credit the hydrogen's score to the atom it is attached to.
            target_idx = neighbors[0].GetIdx()
        aggregated_weights[target_idx] += score

    # Rescale so the strongest atom has weight 1.0; skip when every weight
    # is zero to avoid dividing by zero.
    peak = aggregated_weights.max()
    if peak > 0:
        aggregated_weights /= peak

    # Red highlight whose fourth colour component is the normalized weight.
    # float() turns NumPy scalars into plain Python floats for RDKit.
    highlight_colors = {
        atom_idx: (1.0, 0.0, 0.0, float(weight))
        for atom_idx, weight in enumerate(aggregated_weights)
        if weight > threshold
    }
    # Dicts keep insertion order, so this list is in ascending atom index.
    highlight_atoms = list(highlight_colors)

    # NOTE(review): the molecule is drawn with explicit hydrogens. If a
    # heavy-atom-only picture is wanted, draw Chem.RemoveHs(mol) instead
    # after confirming the highlighted indices are still valid for it.
    drawer = rdMolDraw2D.MolDraw2DCairo(1000, 600)
    options = drawer.drawOptions()
    options.annotationFontScale = 0.7
    options.addAtomIndices = True

    rdMolDraw2D.PrepareAndDrawMolecule(
        drawer,
        mol,
        highlightAtoms=highlight_atoms,
        highlightAtomColors=highlight_colors,
    )
    drawer.FinishDrawing()

    path = Path(output_path)
    path.parent.mkdir(parents=True, exist_ok=True)
    drawer.WriteDrawingText(str(path))
    # Report the path that was actually written.
    print(f"✅ Картинка готова: {path}")


# Entry point: build the image only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    generate_final_viz()