Spaces:
Running
Running
| """Get/put submission results concerning attention from/on COS.""" | |
| import os | |
| import json | |
| import dill | |
| import logging | |
| import numpy as np | |
| from typing import Iterable | |
| from configuration import GENES | |
| from cos import ( | |
| RESULTS_PREFIX, | |
| bytes_from_key, | |
| string_from_key, | |
| bytes_to_key, | |
| ) | |
| from utils import Drug | |
| from plots import embed_barplot | |
| from smiles import smiles_attention_to_svg | |
| logger = logging.getLogger("openapi_server:attention") | |
def download_attention(workspace_id: str, task_id: str, sample_name: str) -> dict:
    """
    Fetch the stored attention data of a sample and render it as figures.

    The drug description is read from ``drug.json`` under the task, the gene
    and SMILES attention values from the pickles stored under the sample.

    Args:
        workspace_id (str): workspace identifier.
        task_id (str): task identifier.
        sample_name (str): name of the sample.

    Returns:
        dict: the drug, the sample name, the SVGs for the SMILES attention
            and its color bar, and the JS/HTML of the gene attention barplot.
    """

    def _load_pickled(basename: str):
        """Read a sample-level object from COS and deserialize it with dill."""
        remote_key = os.path.join(
            RESULTS_PREFIX,
            os.path.join(workspace_id, task_id, sample_name, basename),
        )
        # NOTE(review): dill.loads can execute arbitrary code while
        # deserializing; this is only safe if the bucket content is trusted.
        return dill.loads(bytes_from_key(remote_key))

    # the drug description lives at task level, not at sample level
    drug_key = os.path.join(
        RESULTS_PREFIX, os.path.join(workspace_id, task_id, "drug.json")
    )
    drug = Drug(**json.loads(string_from_key(drug_key)))
    logger.debug(f"download attention results from COS for {drug.smiles}.")

    # omic
    logger.debug("gene attention.")
    gene_alphas = _load_pickled("gene_attention.pkl")
    gene_names = np.array(GENES)
    # reversed argsort, so the highest attention comes first
    descending = gene_alphas.argsort()[::-1]
    barplot_js, barplot_html = embed_barplot(
        gene_names[descending], gene_alphas[descending]
    )
    logger.debug("gene attention plots created.")

    # smiles
    logger.debug("SMILES attention.")
    smiles_alphas = _load_pickled("smiles_attention.pkl")
    molecule_svg, color_bar_svg = smiles_attention_to_svg(drug.smiles, smiles_alphas)
    logger.debug("SMILES attention plots created.")

    figures = {
        "drug": drug,
        "sample_name": sample_name,
        "sample_drug_attention_svg": molecule_svg,
        "sample_drug_color_bar_svg": color_bar_svg,
        "sample_gene_attention_js": barplot_js,
        "sample_gene_attention_html": barplot_html,
    }
    return figures
def _upload_ndarray(sample_prefix: str, array: np.ndarray, filename: str) -> None:
    """
    Serialize an array with dill and store it on COS.

    Args:
        sample_prefix (str): prefix the object key is built from.
        array (np.ndarray): array to store.
        filename (str): object name without extension; ".pkl" is appended.
    """
    payload = dill.dumps(array)
    destination_key = os.path.join(sample_prefix, f"{filename}.pkl")
    bytes_to_key(payload, destination_key)
def upload_attention(
    prefix: str,
    sample_names: Iterable[str],
    omic_attention: np.ndarray,
    smiles_attention: np.ndarray,
) -> dict:
    """
    Validate attention profiles and compute their average over the samples.

    NOTE(review): despite its name, this function does not upload anything
    in its current form; it returns the averaged profiles and `prefix` is
    not used. Confirm with the callers whether they perform the upload.

    Args:
        prefix (str): base prefix used as a root (currently unused).
        sample_names (Iterable[str]): name of the samples.
        omic_attention (np.ndarray): attention values for genes, indexed as
            (sample, gene).
        smiles_attention (np.ndarray): attention values for SMILES, with
            samples along the first axis.

    Returns:
        dict: "gene_attention" and "smiles_attention", each mapped to the
            mean of the corresponding input over its first axis.

    Raises:
        ValueError: mismatch in sample names and gene attention.
        ValueError: mismatch in sample names and SMILES attention.
        ValueError: mismatch in number of genes and gene attention.
    """
    omic_entities = np.array(GENES)
    # An Iterable is not guaranteed to support len() (e.g. a generator), so
    # materialize the names once before counting them.
    sample_names = list(sample_names)

    # sanity checks
    if len(sample_names) != omic_attention.shape[0]:
        raise ValueError(
            f"length of sample_names {len(sample_names)} does not "
            f"match omic_attention {omic_attention.shape[0]}"
        )
    if len(sample_names) != len(smiles_attention):
        raise ValueError(
            f"length of sample_names {len(sample_names)} does not "
            f"match smiles_attention {len(smiles_attention)}"
        )
    if len(omic_entities) != omic_attention.shape[1]:
        raise ValueError(
            f"length of omic_entities {len(omic_entities)} "
            f"does not match omic_attention.shape[1] {omic_attention.shape[1]}"
        )

    # Only the "average" profile is produced; per-sample profiles are not
    # included in the result.
    return {
        # omic
        "gene_attention": omic_attention.mean(axis=0),
        # smiles
        "smiles_attention": smiles_attention.mean(axis=0),
    }