fangjiang commited on
Commit
9011565
·
1 Parent(s): 4162a00

update cytof classes and utils

Browse files
app.py CHANGED
@@ -378,12 +378,13 @@ hr {
378
  """
379
 
380
  with gr.Blocks() as demo:
381
- gr.HTML(custom_css)
382
-
383
  cytof_state = gr.State(CytofImage())
 
384
  # used in scenarios where users define/remove channels multiple times
385
  cytof_original_state = gr.State(CytofImage())
386
-
 
 
387
  gr.Markdown('<div class="h-1">Step 1. Upload images</div>')
388
  gr.Markdown('<div class="h-2">You may upload one or two files depending on your use case.</div>')
389
  gr.Markdown('<div class="h-2 bold">Case 1: &nbsp; Upload a single file</div>')
@@ -515,7 +516,7 @@ with gr.Blocks() as demo:
515
  with gr.Column(scale=2):
516
  gr.Markdown('<div class="h-2">This analysis measures the degree of co-expression within a pair of neighborhoods.</div>')
517
  gr.Markdown('<div class="h-2">Select the clustering method:</div>')
518
- info_text = gr.Markdown(update_info_text('K-neighbor'))
519
  cluster_method = gr.Radio(['k-neighbor', 'distance'], value='k-neighbor', elem_classes='test', label='')
520
  cluster_threshold = gr.Slider(minimum=1, maximum=100, step=1, value=30, interactive=True, label='Clustering threshold')
521
  spatial_btn = gr.Button('Run spatial interaction analysis')
@@ -531,8 +532,7 @@ with gr.Blocks() as demo:
531
  gr.Markdown('<br>')
532
  gr.Markdown('<div class="h-1">Step 6. Visualize positive markers</div>')
533
  gr.Markdown('<div class="h-2">Select two markers for side-by-side comparison to visualize their positive states in cells. This serves two purposes: </div>')
534
- gr.Markdown('<div class="h-2 bold">(1) Validate the co-expression analysis results.</div>')
535
- gr.Markdown('<div class="h-2 bold">(2) Validate teh spatial interaction analysis results.</div>')
536
 
537
 
538
  with gr.Row(): # two marker positive visualization - dropdown options
 
378
  """
379
 
380
  with gr.Blocks() as demo:
 
 
381
  cytof_state = gr.State(CytofImage())
382
+
383
  # used in scenarios where users define/remove channels multiple times
384
  cytof_original_state = gr.State(CytofImage())
385
+
386
+ gr.HTML(custom_css)
387
+
388
  gr.Markdown('<div class="h-1">Step 1. Upload images</div>')
389
  gr.Markdown('<div class="h-2">You may upload one or two files depending on your use case.</div>')
390
  gr.Markdown('<div class="h-2 bold">Case 1: &nbsp; Upload a single file</div>')
 
516
  with gr.Column(scale=2):
517
  gr.Markdown('<div class="h-2">This analysis measures the degree of co-expression within a pair of neighborhoods.</div>')
518
  gr.Markdown('<div class="h-2">Select the clustering method:</div>')
519
+ info_text = gr.Markdown(update_info_text('k-neighbor'))
520
  cluster_method = gr.Radio(['k-neighbor', 'distance'], value='k-neighbor', elem_classes='test', label='')
521
  cluster_threshold = gr.Slider(minimum=1, maximum=100, step=1, value=30, interactive=True, label='Clustering threshold')
522
  spatial_btn = gr.Button('Run spatial interaction analysis')
 
532
  gr.Markdown('<br>')
533
  gr.Markdown('<div class="h-1">Step 6. Visualize positive markers</div>')
534
  gr.Markdown('<div class="h-2">Select two markers for side-by-side comparison to visualize their positive states in cells. This serves two purposes: </div>')
535
+ gr.Markdown('<div class="h-2 bold">(1) Validate the co-expression analysis results. (2) Validate teh spatial interaction analysis results.</div>')
 
536
 
537
 
538
  with gr.Row(): # two marker positive visualization - dropdown options
cytof/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (233 Bytes). View file
 
cytof/__pycache__/classes.cpython-38.pyc ADDED
Binary file (57.5 kB). View file
 
cytof/__pycache__/hyperion_preprocess.cpython-38.pyc ADDED
Binary file (11 kB). View file
 
cytof/__pycache__/hyperion_segmentation.cpython-38.pyc ADDED
Binary file (11.7 kB). View file
 
cytof/__pycache__/segmentation_functions.cpython-38.pyc ADDED
Binary file (21.6 kB). View file
 
cytof/__pycache__/utils.cpython-38.pyc ADDED
Binary file (15.6 kB). View file
 
cytof/classes.py CHANGED
@@ -137,7 +137,7 @@ class CytofImage():
137
  self.df = pd.concat([self.df, df2])
138
 
139
  def quality_control(self, thres: int = 50) -> None:
140
- setattr(self, "keep", False)
141
  if (max(self.df['X']) < thres) \
142
  or (max(self.df['Y']) < thres):
143
  print("At least one dimension of the image {}-{} is smaller than {}, exclude from analyzing" \
@@ -488,7 +488,6 @@ class CytofImage():
488
  # attach quantile dictionary to self
489
  self.dict_quantiles = quantiles
490
 
491
- print('dict quantiles:', quantiles)
492
  # return quantiles
493
 
494
  def _vis_normalization(self, savename: Optional[str] = None):
@@ -979,8 +978,19 @@ class CytofImageTiff(CytofImage):
979
  return new_instance
980
 
981
  def quality_control(self, thres: int = 50) -> None:
982
- setattr(self, "keep", False)
983
- if any([x < thres for x in self.image.shape]):
 
 
 
 
 
 
 
 
 
 
 
984
  print(f"At least one dimension of the image {self.slide}-{self.roi} is smaller than {thres}, \
985
  hence exclude from analyzing" )
986
  self.keep = False
@@ -1157,7 +1167,7 @@ def apply_threshold_to_column(column, threshold):
1157
  class CytofCohort():
1158
  def __init__(self, cytof_images: Optional[dict] = None,
1159
  df_cohort: Optional[pd.DataFrame] = None,
1160
- dir_out: str = "./",
1161
  cohort_name: str = "cohort1"):
1162
  """
1163
  cytof_images:
@@ -1170,13 +1180,15 @@ class CytofCohort():
1170
  "cell_sum": ["cell_sum", "cell_morphology"],
1171
  "cell_ave": ["cell_ave", "cell_morphology"],
1172
  "cell_sum_only": ["cell_sum"],
1173
- "cell_ave_only": ["cell_ave"]
1174
- }
1175
 
1176
- self.name = cohort_name
1177
- self.dir_out = os.path.join(dir_out, self.name)
1178
- if not os.path.exists(self.dir_out):
1179
- os.makedirs(self.dir_out)
 
 
1180
  def __getitem__(self, key):
1181
  'Extracts a particular cytof image from the cohort'
1182
  return self.cytof_images[key]
@@ -1187,12 +1199,16 @@ class CytofCohort():
1187
  def __repr__(self):
1188
  return f"CytofCohort(name={self.name})"
1189
 
1190
- def save_cytof_cohort(self, savename):
1191
- directory = os.path.dirname(savename)
1192
- if not os.path.exists(directory):
1193
- os.makedirs(directory)
1194
- pkl.dump(self, open(savename, "wb"))
1195
-
 
 
 
 
1196
  def batch_process_feature(self):
1197
  """
1198
  Batch process: if the CytofCohort is initialized by a dictionary of CytofImages
@@ -1204,8 +1220,9 @@ class CytofCohort():
1204
  setattr(self, "dict_feat", cytof_img.features)
1205
  if not hasattr(self, "markers"):
1206
  setattr(self, "markers", cytof_img.markers)
 
 
1207
 
1208
- print('dict quantiles in batch process:', cytof_img.dict_quantiles)
1209
  try:
1210
  qs &= set(list(cytof_img.dict_quantiles.keys()))
1211
  except:
@@ -1226,24 +1243,36 @@ class CytofCohort():
1226
  def batch_process(self, params: Dict):
1227
  sys.path.append("../CLIscripts")
1228
  from process_single_roi import process_single, SetParameters
 
 
1229
  for i, (slide, roi, fname) in self.df_cohort.iterrows():
1230
  paramsi = SetParameters(filename=fname,
1231
- outdir=self.dir_out,
1232
- label_marker_file=params.get('label_marker_file', None),
1233
- slide=slide,
1234
- roi=roi,
1235
- quality_control_thres=params.get("quality_control_thres", 50),
1236
- channels_remove=params.get("channels_remove", None),
1237
- channels_dict=params.get("channels_dict", None),
1238
- use_membrane=params.get("use_membrane",True),
1239
- cell_radius=params.get("cell_radius", 5),
1240
- normalize_qs=params.get("normalize_qs", 75),
1241
- iltype=params.get('iltype', None))
1242
-
1243
- cytof_img = process_single(paramsi, downstream_analysis=False, verbose=False)
1244
- self.cytof_images[f"{slide}_{roi}"] = cytof_img
 
 
 
 
 
 
 
 
 
 
 
1245
 
1246
- self.batch_process_feature()
1247
 
1248
  def get_feature(self,
1249
  normq: int = 75,
@@ -1310,12 +1339,12 @@ class CytofCohort():
1310
  normq: int = 75,
1311
  feat_type: str = "normed_scaled",
1312
  feat_set: str = "all",
1313
- markers: str = "all",
1314
  verbose: bool = False):
1315
 
1316
  assert feat_type in ["normed_scaled", "normed", ""], f"feature type {feat_type} not supported!"
1317
- assert (markers == "all" or isinstance(markers, list))
1318
- assert feat_set in self.feat_sets.keys(), f"feature set {feat_set} not supported!"
1319
 
1320
  description = "original" if feat_type=="" else f"{normq}{feat_type}"
1321
  n_attr = f"df_feature{feat_type}" if feat_type=="" else f"df_feature_{normq}{feat_type}" # the attribute name to achieve from cytof_img
@@ -1330,15 +1359,22 @@ class CytofCohort():
1330
  if "morphology" in y:
1331
  feat_names += self.dict_feat[y]
1332
  else:
1333
- if markers == "all": # features extracted from all markers are kept
1334
- feat_names += self.dict_feat[y]
1335
- markers = self.markers
 
 
 
 
 
 
1336
  else: # only features correspond to markers kept (markers are a subset of self.markers)
1337
- ids = [self.markers.index(x) for x in markers] # TODO: the case where marker in markers not in self.markers???
1338
  feat_names += [self.dict_feat[y][x] for x in ids]
1339
-
 
1340
  df_feature = getattr(self, n_attr)[feat_names]
1341
- return df_feature, markers, feat_names, description, n_attr
1342
 
1343
  ###############################################################
1344
  ################## PhenoGraph Clustering ######################
@@ -1347,21 +1383,46 @@ class CytofCohort():
1347
  normq:int = 75,
1348
  feat_type:str = "normed_scaled",
1349
  feat_set: str = "all",
1350
- pheno_markers: Union[str, List] = "all",
1351
  k: int = None,
1352
  save_vis: bool = False,
1353
  verbose:bool = True):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1354
 
1355
- if pheno_markers == "all":
1356
- pheno_markers_ = "_all"
1357
  else:
1358
- pheno_markers_ = "_subset1"
1359
 
1360
  assert feat_type in ["normed_scaled", "normed", ""], f"feature type {feat_type} not supported!"
1361
- df_feature, pheno_markers, feat_names, description, n_attr = self._get_feature_subset(normq=normq,
1362
  feat_type=feat_type,
1363
  feat_set=feat_set,
1364
- markers=pheno_markers,
1365
  verbose=verbose)
1366
  # set number of nearest neighbors k and run PhenoGraph for phenotype clustering
1367
  k = k if k else int(df_feature.shape[0] / 100)
@@ -1381,13 +1442,13 @@ class CytofCohort():
1381
  if not hasattr(self, "phenograph"):
1382
  setattr(self, "phenograph", {})
1383
  key_pheno = f"{description}_{feat_set}_feature_{k}"
1384
- key_pheno += f"{pheno_markers_}_markers"
1385
 
1386
 
1387
  N = len(np.unique(communities))
1388
  self.phenograph[key_pheno] = {
1389
  "data": df_feature,
1390
- "markers": pheno_markers,
1391
  "features": feat_names,
1392
  "description": {"normalization": description, "feature_set": feat_set}, # normalization and/or scaling | set of feature (in self.feat_sets)
1393
  "communities": communities,
@@ -1428,7 +1489,8 @@ class CytofCohort():
1428
  save_vis: bool = False,
1429
  show_plots: bool = False,
1430
  plot_together: bool = True,
1431
- fig_width: int = 5 # only when plot_together is True
 
1432
  ):
1433
  assert level.upper() in ["COHORT", "SLIDE", "ROI"], "Only 'cohort', 'slide' and 'roi' are accetable values for level"
1434
  this_pheno = self.phenograph[key_pheno]
@@ -1485,15 +1547,17 @@ class CytofCohort():
1485
  fig, axs = plt.subplots(1,ncol, figsize=(ncol*fig_width, fig_width))
1486
  proj_2d = proj_2ds[key]
1487
  commu = commus[key]
 
1488
  # Visualize 1: plot 2d projection together
1489
  print("Visualization in 2d - {}-{}".format(level, key))
1490
  savename = os.path.join(vis_savedir, f"cluster_scatter_{level}_{key}.png") if (save_vis and not plot_together) else None
1491
  ax = axs[0] if plot_together else None
1492
- fig_scatter = visualize_scatter(data=proj_2d, communities=commu, n_community=n_community,
1493
- title=key, savename=savename, show=show_plots, ax=ax)
1494
- figs_scatter[key] = fig_scatter
1495
 
1496
  figs_exps[key] = {}
 
1497
  # Visualize 2: protein expression
1498
  for axid, acm_tpe in enumerate(accumul_type):
1499
  ids = [i for (i, x) in enumerate(feat_names) if re.search(".{}".format(acm_tpe), x)]
@@ -1526,10 +1590,10 @@ class CytofCohort():
1526
  if (save_vis and not plot_together) else None
1527
  vis_exp = cluster_protein_exp_norm if normalize else cluster_protein_exp
1528
  ax = axs[axid+1] if plot_together else None
1529
- fig_exps = visualize_expression(data=vis_exp, markers=markers,
1530
  group_ids=group_ids, title="{} - {}-{}".format(level, acm_tpe, key),
1531
  savename=savename, show=show_plots, ax=ax)
1532
- figs_exps[key][acm_tpe] = fig_exps
1533
  cluster_protein_exps[key] = vis_exp
1534
  plt.tight_layout()
1535
  if plot_together:
@@ -1892,3 +1956,38 @@ class CytofCohort():
1892
  slide_co_expression_dict[slide_key] = (edge_percentage_norm, df_expected.columns)
1893
 
1894
  return slide_co_expression_dict
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
137
  self.df = pd.concat([self.df, df2])
138
 
139
  def quality_control(self, thres: int = 50) -> None:
140
+ setattr(self, "keep", True)
141
  if (max(self.df['X']) < thres) \
142
  or (max(self.df['Y']) < thres):
143
  print("At least one dimension of the image {}-{} is smaller than {}, exclude from analyzing" \
 
488
  # attach quantile dictionary to self
489
  self.dict_quantiles = quantiles
490
 
 
491
  # return quantiles
492
 
493
  def _vis_normalization(self, savename: Optional[str] = None):
 
978
  return new_instance
979
 
980
  def quality_control(self, thres: int = 50) -> None:
981
+ setattr(self, "keep", True)
982
+
983
+ shape = self.image.shape
984
+
985
+ if len(shape) == 2:
986
+ # Just height and width
987
+ dims_to_check = shape
988
+ else:
989
+ # Assume the channel dimension is the smallest one
990
+ channel_dim = min(shape)
991
+ dims_to_check = [d for d in shape if d != channel_dim]
992
+
993
+ if any(x < thres for x in dims_to_check):
994
  print(f"At least one dimension of the image {self.slide}-{self.roi} is smaller than {thres}, \
995
  hence exclude from analyzing" )
996
  self.keep = False
 
1167
  class CytofCohort():
1168
  def __init__(self, cytof_images: Optional[dict] = None,
1169
  df_cohort: Optional[pd.DataFrame] = None,
1170
+ dir_out: Optional[str] = "./",
1171
  cohort_name: str = "cohort1"):
1172
  """
1173
  cytof_images:
 
1180
  "cell_sum": ["cell_sum", "cell_morphology"],
1181
  "cell_ave": ["cell_ave", "cell_morphology"],
1182
  "cell_sum_only": ["cell_sum"],
1183
+ "cell_ave_only": ["cell_ave"],
1184
+ } # need at least cell sum/ave; cannot be just morphology; can't make cluster v. channel heatmap
1185
 
1186
+ self.name = cohort_name
1187
+ self.dir_out = os.path.join(dir_out, self.name) if isinstance(dir_out, str) else None
1188
+ if self.dir_out:
1189
+ os.makedirs(self.dir_out, exist_ok=True)
1190
+ print('Output folder created:', self.dir_out)
1191
+
1192
  def __getitem__(self, key):
1193
  'Extracts a particular cytof image from the cohort'
1194
  return self.cytof_images[key]
 
1199
  def __repr__(self):
1200
  return f"CytofCohort(name={self.name})"
1201
 
1202
+ def save_cytof_cohort(self):
1203
+ if self.dir_out:
1204
+ save_path = f'{os.path.join(self.dir_out, self.name)}.pkl'
1205
+ pkl.dump(self, open(save_path, "wb"))
1206
+
1207
+ return save_path
1208
+ else:
1209
+ raise FileNotFoundError('self.dir_out not specified')
1210
+
1211
+
1212
  def batch_process_feature(self):
1213
  """
1214
  Batch process: if the CytofCohort is initialized by a dictionary of CytofImages
 
1220
  setattr(self, "dict_feat", cytof_img.features)
1221
  if not hasattr(self, "markers"):
1222
  setattr(self, "markers", cytof_img.markers)
1223
+ if not hasattr(self, "channels"):
1224
+ setattr(self, "channels", cytof_img.channels)
1225
 
 
1226
  try:
1227
  qs &= set(list(cytof_img.dict_quantiles.keys()))
1228
  except:
 
1243
  def batch_process(self, params: Dict):
1244
  sys.path.append("../CLIscripts")
1245
  from process_single_roi import process_single, SetParameters
1246
+
1247
+ success_rows = []
1248
  for i, (slide, roi, fname) in self.df_cohort.iterrows():
1249
  paramsi = SetParameters(filename=fname,
1250
+ outdir=self.dir_out,
1251
+ label_marker_file=params.get('label_marker_file', None),
1252
+ slide=slide,
1253
+ roi=roi,
1254
+ quality_control_thres=params.get("quality_control_thres", 50),
1255
+ channels_remove=params.get("channels_remove", None),
1256
+ channels_dict=params.get("channels_dict", None),
1257
+ use_membrane=params.get("use_membrane",True),
1258
+ cell_radius=params.get("cell_radius", 5),
1259
+ normalize_qs=params.get("normalize_qs", 75),
1260
+ iltype=params.get('iltype', None))
1261
+
1262
+ try:
1263
+ cytof_img = process_single(paramsi, downstream_analysis=False, verbose=False)
1264
+ self.cytof_images[f"{slide}_{roi}"] = cytof_img
1265
+
1266
+ # image successfully processed, record index
1267
+ success_rows.append(i)
1268
+
1269
+ except Exception as e:
1270
+ print(f"Skipping {slide}_{roi} due to error: {e}")
1271
+ continue
1272
+
1273
+ # update df_cohort to contain only successfully calculated rows
1274
+ self.df_cohort = self.df_cohort.loc[success_rows].reset_index(drop=True)
1275
 
 
1276
 
1277
  def get_feature(self,
1278
  normq: int = 75,
 
1339
  normq: int = 75,
1340
  feat_type: str = "normed_scaled",
1341
  feat_set: str = "all",
1342
+ channels: Union[str, List] = "all",
1343
  verbose: bool = False):
1344
 
1345
  assert feat_type in ["normed_scaled", "normed", ""], f"feature type {feat_type} not supported!"
1346
+ assert (channels == "all" or set(channels).issubset(set(self.channels))), f"input channels {channels} not a subset of self.channels"
1347
+ assert feat_set in self.feat_sets.keys() , f"feature set {feat_set} not supported!"
1348
 
1349
  description = "original" if feat_type=="" else f"{normq}{feat_type}"
1350
  n_attr = f"df_feature{feat_type}" if feat_type=="" else f"df_feature_{normq}{feat_type}" # the attribute name to achieve from cytof_img
 
1359
  if "morphology" in y:
1360
  feat_names += self.dict_feat[y]
1361
  else:
1362
+ if channels == "all": # features extracted from all channels are kept
1363
+ feat_names += self.dict_feat[y]
1364
+ channels_return = self.channels.copy() # return all channel names except nuclei and membrane
1365
+ channels_return.remove('nuclei') # all instances have nuclei channel
1366
+ try:
1367
+ channels_return.remove('membrane') # some might not have membrane
1368
+ except ValueError:
1369
+ pass
1370
+
1371
  else: # only features correspond to markers kept (markers are a subset of self.markers)
1372
+ ids = [self.channels.index(x) for x in channels]
1373
  feat_names += [self.dict_feat[y][x] for x in ids]
1374
+ channels_return = channels.copy() # return only subset
1375
+
1376
  df_feature = getattr(self, n_attr)[feat_names]
1377
+ return df_feature, channels_return, feat_names, description, n_attr
1378
 
1379
  ###############################################################
1380
  ################## PhenoGraph Clustering ######################
 
1383
  normq:int = 75,
1384
  feat_type:str = "normed_scaled",
1385
  feat_set: str = "all",
1386
+ pheno_channels: Union[str, List] = "all",
1387
  k: int = None,
1388
  save_vis: bool = False,
1389
  verbose:bool = True):
1390
+ """performs PhenoGraph clustering on normalized and/or scaled features
1391
+
1392
+ Parameters
1393
+ ----------
1394
+ normq : int, optional
1395
+ xth quantile of normalization; for finding df_feature attribute, by default 75
1396
+ feat_type : str, optional
1397
+ for finding df_feature attribute for PhenoGraph, by default "normed_scaled"
1398
+ feat_set : str, optional
1399
+ element in [cell_sum, cell_ave, cell_sum_only, cell_ave_only, all]; all will include all aforementioned feature sets, by default "all"
1400
+ pheno_channels : Union[str, List], optional
1401
+ list of channels used for PhenoGraph, by default "all"
1402
+ k : int, optional
1403
+ k neighbors, by default None
1404
+ save_vis : bool, optional
1405
 + whether to save visualization, by default False
1406
+ verbose : bool, optional
1407
+ whether to print progress details, by default True
1408
+
1409
+ Returns
1410
+ -------
1411
+ key_pheno
1412
+ string literal that can be indexed in self.phenograph
1413
+ """
1414
+
1415
 
1416
+ if pheno_channels == "all":
1417
+ pheno_channels_ = "_all"
1418
  else:
1419
+ pheno_channels_ = "_subset1"
1420
 
1421
  assert feat_type in ["normed_scaled", "normed", ""], f"feature type {feat_type} not supported!"
1422
+ df_feature, channels, feat_names, description, n_attr = self._get_feature_subset(normq=normq,
1423
  feat_type=feat_type,
1424
  feat_set=feat_set,
1425
+ channels=pheno_channels,
1426
  verbose=verbose)
1427
  # set number of nearest neighbors k and run PhenoGraph for phenotype clustering
1428
  k = k if k else int(df_feature.shape[0] / 100)
 
1442
  if not hasattr(self, "phenograph"):
1443
  setattr(self, "phenograph", {})
1444
  key_pheno = f"{description}_{feat_set}_feature_{k}"
1445
+ key_pheno += f"{pheno_channels_}_markers"
1446
 
1447
 
1448
  N = len(np.unique(communities))
1449
  self.phenograph[key_pheno] = {
1450
  "data": df_feature,
1451
+ "markers": channels, # preserve key for downstream
1452
  "features": feat_names,
1453
  "description": {"normalization": description, "feature_set": feat_set}, # normalization and/or scaling | set of feature (in self.feat_sets)
1454
  "communities": communities,
 
1489
  save_vis: bool = False,
1490
  show_plots: bool = False,
1491
  plot_together: bool = True,
1492
+ fig_width: int = 5, # only when plot_together is True,
1493
+ scatter_dot_size: int = 2
1494
  ):
1495
  assert level.upper() in ["COHORT", "SLIDE", "ROI"], "Only 'cohort', 'slide' and 'roi' are accetable values for level"
1496
  this_pheno = self.phenograph[key_pheno]
 
1547
  fig, axs = plt.subplots(1,ncol, figsize=(ncol*fig_width, fig_width))
1548
  proj_2d = proj_2ds[key]
1549
  commu = commus[key]
1550
+
1551
  # Visualize 1: plot 2d projection together
1552
  print("Visualization in 2d - {}-{}".format(level, key))
1553
  savename = os.path.join(vis_savedir, f"cluster_scatter_{level}_{key}.png") if (save_vis and not plot_together) else None
1554
  ax = axs[0] if plot_together else None
1555
+ fig_scatter, ax_scatter = visualize_scatter(data=proj_2d, communities=commu, n_community=n_community,
1556
+ title=key, scatter_dot_size=scatter_dot_size, savename=savename, show=show_plots, ax=ax)
1557
+ figs_scatter[key] = (fig_scatter, ax_scatter)
1558
 
1559
  figs_exps[key] = {}
1560
+
1561
  # Visualize 2: protein expression
1562
  for axid, acm_tpe in enumerate(accumul_type):
1563
  ids = [i for (i, x) in enumerate(feat_names) if re.search(".{}".format(acm_tpe), x)]
 
1590
  if (save_vis and not plot_together) else None
1591
  vis_exp = cluster_protein_exp_norm if normalize else cluster_protein_exp
1592
  ax = axs[axid+1] if plot_together else None
1593
+ fig_exps, ax_exps = visualize_expression(data=vis_exp, markers=markers,
1594
  group_ids=group_ids, title="{} - {}-{}".format(level, acm_tpe, key),
1595
  savename=savename, show=show_plots, ax=ax)
1596
+ figs_exps[key][acm_tpe] = (fig_exps, ax_exps)
1597
  cluster_protein_exps[key] = vis_exp
1598
  plt.tight_layout()
1599
  if plot_together:
 
1956
  slide_co_expression_dict[slide_key] = (edge_percentage_norm, df_expected.columns)
1957
 
1958
  return slide_co_expression_dict
1959
+
1960
+
1961
+ def cohort_interaction_graphs(self, feature_name, accumul_type, method: str = "distance", threshold=50):
1962
+ assert method in ["distance", "k-neighbor"], "Method can be either 'distance' or 'k-neighbor'!"
1963
+
1964
+ # used to store ROI-level interaction graphs
1965
+ marker_roi_list = list()
1966
+
1967
+ for roi_keys, cytof_img in self.cytof_images.items():
1968
+ print(f"Processing ROI {roi_keys}")
1969
+ df_expected_prob, df_cell_interaction_prob = cytof_img.roi_interaction_graphs(feature_name=feature_name, accumul_type=accumul_type, method=method, threshold=threshold, return_components=False)
1970
+
1971
+ # do some post processing
1972
+ marker_all = df_expected_prob.columns
1973
+ epsilon = 1e-6
1974
+
1975
+ # Normalize and fix Nan
1976
+ edge_percentage_norm = np.log10(df_cell_interaction_prob.values / (df_expected_prob.values+epsilon) + epsilon)
1977
+
1978
+ # if observed/expected = 0, then log odds ratio will have log10(epsilon)
1979
+ # no observed means interaction cannot be determined, does not mean strong negative interaction
1980
+ edge_percentage_norm[edge_percentage_norm == np.log10(epsilon)] = 0
1981
+
1982
+ edge_perc_remapped = pd.DataFrame(edge_percentage_norm, index=marker_all, columns=marker_all)
1983
+ edge_perc_remapped["roi_id"] = roi_keys
1984
+ marker_roi_list.append(edge_perc_remapped)
1985
+
1986
+ # concatenate all pt df
1987
+ edge_percentage_cohort = pd.concat(marker_roi_list, axis=0)
1988
+ edge_percentage_cohort = edge_percentage_cohort.reset_index(names='marker')
1989
+
1990
+ # cohort specific: 0 was used to indicate not observed, but average over will skew the df
1991
+ edge_percentage_cohort = edge_percentage_cohort.replace(0, np.nan)
1992
+
1993
+ return edge_percentage_cohort, marker_all
cytof/utils.py CHANGED
@@ -358,7 +358,7 @@ def check_feature_distribution(feature_summary_df, features):
358
  # return None
359
  # return fig
360
 
361
- def visualize_scatter(data, communities, n_community, title, figsize=(5,5), savename=None, show=False, ax=None):
362
  """
363
  data = data to visualize (N, 2)
364
  communities = group indices correspond to each sample in data (N, 1) or (N, )
@@ -372,10 +372,15 @@ def visualize_scatter(data, communities, n_community, title, figsize=(5,5), save
372
  else:
373
  fig = None
374
  ax.set_title(title)
375
- sns.scatterplot(x=data[:,0], y=data[:,1], hue=communities, palette='tab20',
376
- hue_order=np.arange(n_community), ax=ax)
377
- # legend=legend,
378
- # hue_order=np.arange(n_community))
 
 
 
 
 
379
 
380
  ax.legend(bbox_to_anchor=(1.01, 1), loc=2, borderaxespad=0.)
381
  # plt.axis('tight')
@@ -387,7 +392,7 @@ def visualize_scatter(data, communities, n_community, title, figsize=(5,5), save
387
  plt.show()
388
  if clos:
389
  plt.close('all')
390
- return fig
391
 
392
  def visualize_expression(data, markers, group_ids, title, figsize=(5,5), savename=None, show=False, ax=None):
393
  clos = not show and ax is None
@@ -403,8 +408,8 @@ def visualize_expression(data, markers, group_ids, title, figsize=(5,5), savenam
403
  yticklabels=group_ids,
404
  ax=ax
405
  )
406
- ax.set_xlabel("Markers")
407
- ax.set_ylabel("Phenograph clusters")
408
  ax.set_title("normalized expression - {}".format(title))
409
  ax.xaxis.set_tick_params(labelsize=8)
410
  if savename is not None:
@@ -414,7 +419,7 @@ def visualize_expression(data, markers, group_ids, title, figsize=(5,5), savenam
414
  plt.show()
415
  if clos:
416
  plt.close('all')
417
- return fig
418
 
419
  def _get_thresholds(df_feature: pd.DataFrame,
420
  features: List[str],
 
358
  # return None
359
  # return fig
360
 
361
+ def visualize_scatter(data, communities, n_community, title, scatter_dot_size, figsize=(5,5), savename=None, show=False, ax=None):
362
  """
363
  data = data to visualize (N, 2)
364
  communities = group indices correspond to each sample in data (N, 1) or (N, )
 
372
  else:
373
  fig = None
374
  ax.set_title(title)
375
+ sns.scatterplot(x=data[:,0],
376
+ y=data[:,1],
377
+ hue=communities,
378
+ palette='tab20',
379
+ s=scatter_dot_size,
380
+ alpha=0.9,
381
+ linewidth=0,
382
+ hue_order=np.arange(n_community), ax=ax
383
+ )
384
 
385
  ax.legend(bbox_to_anchor=(1.01, 1), loc=2, borderaxespad=0.)
386
  # plt.axis('tight')
 
392
  plt.show()
393
  if clos:
394
  plt.close('all')
395
+ return fig, ax
396
 
397
  def visualize_expression(data, markers, group_ids, title, figsize=(5,5), savename=None, show=False, ax=None):
398
  clos = not show and ax is None
 
408
  yticklabels=group_ids,
409
  ax=ax
410
  )
411
+ # ax.set_xlabel("Markers")
412
+ ax.set_ylabel("PhenoGraph clusters")
413
  ax.set_title("normalized expression - {}".format(title))
414
  ax.xaxis.set_tick_params(labelsize=8)
415
  if savename is not None:
 
419
  plt.show()
420
  if clos:
421
  plt.close('all')
422
+ return fig, ax
423
 
424
  def _get_thresholds(df_feature: pd.DataFrame,
425
  features: List[str],