Update interpretability.py
Browse files
- interpretability.py +0 -92
interpretability.py
CHANGED
|
@@ -17,98 +17,6 @@ tokenizer = AutoTokenizer.from_pretrained(model_name)
|
|
| 17 |
|
| 18 |
cls_explainer = SequenceClassificationExplainer(model, tokenizer)
|
| 19 |
|
| 20 |
-
pipe = pipeline("text-classification", model=model_name)
|
| 21 |
-
|
| 22 |
-
def get_taste_from_smiles(smiles):
|
| 23 |
-
# Original output
|
| 24 |
-
output = pipe(smiles)
|
| 25 |
-
|
| 26 |
-
# Mapping of labels to tastes
|
| 27 |
-
taste_labels = ['BITTER', 'SOUR', 'SWEET', 'UMAMI', 'UNDEFINED']
|
| 28 |
-
|
| 29 |
-
# Extract label and score
|
| 30 |
-
label_info = output[0]
|
| 31 |
-
label_index = int(label_info['label'].split('_')[1]) # Get the numeric part of the label
|
| 32 |
-
score = label_info['score']
|
| 33 |
-
|
| 34 |
-
# Reassign label
|
| 35 |
-
new_label = taste_labels[label_index]
|
| 36 |
-
|
| 37 |
-
# Format the title string
|
| 38 |
-
title_string = f"{new_label} score: {score:.2f}"
|
| 39 |
-
|
| 40 |
-
# Output the title string
|
| 41 |
-
return title_string
|
| 42 |
-
|
| 43 |
-
def calculate_aspect_ratio(molecule, base_size):
|
| 44 |
-
"""
|
| 45 |
-
Calculates the canvas width and height based on the molecule's aspect ratio.
|
| 46 |
-
|
| 47 |
-
Parameters:
|
| 48 |
-
- molecule (Mol): RDKit molecule object.
|
| 49 |
-
- base_size (int): The base size of the canvas, typically 400.
|
| 50 |
-
|
| 51 |
-
Returns:
|
| 52 |
-
- (int, int): Calculated width and height for the canvas.
|
| 53 |
-
"""
|
| 54 |
-
conf = molecule.GetConformer()
|
| 55 |
-
atom_positions = [conf.GetAtomPosition(i) for i in range(molecule.GetNumAtoms())]
|
| 56 |
-
x_coords = [pos.x for pos in atom_positions]
|
| 57 |
-
y_coords = [pos.y for pos in atom_positions]
|
| 58 |
-
width = max(x_coords) - min(x_coords)
|
| 59 |
-
height = max(y_coords) - min(y_coords)
|
| 60 |
-
aspect_ratio = width / height if height > 0 else 1
|
| 61 |
-
|
| 62 |
-
canvas_width = max(base_size, int(base_size * aspect_ratio)) if aspect_ratio > 1 else base_size
|
| 63 |
-
canvas_height = max(base_size, int(base_size / aspect_ratio)) if aspect_ratio < 1 else base_size
|
| 64 |
-
|
| 65 |
-
return canvas_width, canvas_height
|
| 66 |
-
|
| 67 |
-
def visualize_gradients(smiles, bw=True, padding=0.05):
|
| 68 |
-
"""
|
| 69 |
-
Visualizes atom-wise gradients or importance scores for a given molecule
|
| 70 |
-
based on the SMILES representation as a similarity map.
|
| 71 |
-
|
| 72 |
-
Parameters:
|
| 73 |
-
- smiles (str): The SMILES string of the molecule to visualize.
|
| 74 |
-
- bw (bool): If True, renders the molecule in black and white (default is True).
|
| 75 |
-
|
| 76 |
-
Returns:
|
| 77 |
-
- Image: The generated similarity map as an Image built from the PNG drawing data, returned so it can be displayed in the notebook.
|
| 78 |
-
"""
|
| 79 |
-
|
| 80 |
-
print(get_taste_from_smiles(smiles))
|
| 81 |
-
|
| 82 |
-
# Convert SMILES string to RDKit molecule object
|
| 83 |
-
molecule = Chem.MolFromSmiles(smiles)
|
| 84 |
-
Chem.rdDepictor.Compute2DCoords(molecule)
|
| 85 |
-
|
| 86 |
-
# Set up canvas size based on aspect ratio
|
| 87 |
-
base_size = 400
|
| 88 |
-
width, height = calculate_aspect_ratio(molecule, base_size)
|
| 89 |
-
d = Draw.MolDraw2DCairo(width, height)
|
| 90 |
-
#Draw.SetACS1996Mode(d.drawOptions(),Draw.MeanBondLength(molecule))
|
| 91 |
-
d.drawOptions().padding = padding
|
| 92 |
-
|
| 93 |
-
# Optionally set black and white palette
|
| 94 |
-
if bw:
|
| 95 |
-
d.drawOptions().useBWAtomPalette()
|
| 96 |
-
|
| 97 |
-
# Get token importance scores and map to atoms
|
| 98 |
-
token_importance = cls_explainer(smiles)
|
| 99 |
-
atom_importance = [c[1] for c in token_importance if c[0].isalpha()]
|
| 100 |
-
num_atoms = molecule.GetNumAtoms()
|
| 101 |
-
atom_importance = atom_importance[:num_atoms]
|
| 102 |
-
|
| 103 |
-
# Generate and display a similarity map based on atom importance scores
|
| 104 |
-
SimilarityMaps.GetSimilarityMapFromWeights(molecule, atom_importance, draw2d=d)
|
| 105 |
-
|
| 106 |
-
# Convert drawing to image and display
|
| 107 |
-
d.FinishDrawing()
|
| 108 |
-
png_data = d.GetDrawingText()
|
| 109 |
-
img = Image(data=png_data)
|
| 110 |
-
return img
|
| 111 |
-
|
| 112 |
def save_high_quality_png(smiles, title, bw=True, padding=0.05):
|
| 113 |
"""
|
| 114 |
Generates a high-quality PNG of atom-wise gradients or importance scores for a molecule.
|
|
|
|
| 17 |
|
| 18 |
cls_explainer = SequenceClassificationExplainer(model, tokenizer)
|
| 19 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
def save_high_quality_png(smiles, title, bw=True, padding=0.05):
|
| 21 |
"""
|
| 22 |
Generates a high-quality PNG of atom-wise gradients or importance scores for a molecule.
|