Michael committed on
Commit
0b31237
·
1 Parent(s): b8dcc35

add methods

Browse files
methods/__init__.py ADDED
File without changes
methods/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (151 Bytes). View file
 
methods/__pycache__/gdc_api_calls.cpython-310.pyc ADDED
Binary file (14.2 kB). View file
 
methods/__pycache__/utilities.cpython-310.pyc ADDED
Binary file (11.2 kB). View file
 
methods/gdc_api_calls.py ADDED
@@ -0,0 +1,634 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ import ast
3
+ import glob
4
+ import json
5
+ import os
6
+ from functools import reduce
7
+ from pathlib import Path
8
+
9
+ import pandas as pd
10
+ import requests
11
+
12
# Repository root: two directory levels above this file.
proj_root = Path(__file__).resolve().parent.parent


# Project lookup table, loaded once at import time from csvs/gdc_projects.tsv.
# Each row is "<project_id>\t<python literal>"; the literal is parsed with
# ast.literal_eval, giving {project_id: parsed_value}.
# The TSV is produced by get_gdc_project_ids() in this module (one-time run).
# Intended use: match e.g. "lymphoid leukemia" in a query to
# "lymphoid leukemias" in the GDC disease_type.
project_mappings = (
    pd.read_csv(
        os.path.join(proj_root, "csvs", "gdc_projects.tsv"),
        sep="\t",
        index_col=0,
        names=["project", "desc"],
    )["desc"]
    .apply(ast.literal_eval)
    .to_dict()
)
24
+
25
+
26
def get_gene_mutation_data(start, stop, step):
    """Download ``gene_aa_change`` records from the GDC ``/ssms`` endpoint in windows.

    The record range ``[start, stop)`` is split into consecutive windows of at
    most ``step`` records, because the whole set cannot be queried at once.
    Each window's raw response text is written to
    ``<window_start>_<window_stop>_gene.mutation.txt`` in the current directory.

    Args:
        start: index of the first record to fetch (the API's ``from`` value).
        stop: index one past the last record to fetch.
        step: maximum number of records per request.

    Raises:
        requests.HTTPError: if the API answers with an error status, so that an
            error page is never saved as if it were data.
    """
    ssms_endpt = "https://api.gdc.cancer.gov/ssms"

    def download_window(window_start, window_stop):
        # ``size`` is a record count, not an end index.
        params = {
            "fields": "gene_aa_change",
            "from": window_start,
            "size": window_stop - window_start,
        }
        response = requests.get(ssms_endpt, params=params)
        response.raise_for_status()
        out_file = "_".join(
            [str(window_start), str(window_stop), "gene.mutation.txt"]
        )
        with open(out_file, "w") as response_out:
            response_out.write(response.text)

    window_start = start
    for window_stop in range(start + step, stop, step):
        download_window(window_start, window_stop)
        window_start = window_stop
    # final (possibly shorter) window
    if window_start < stop:
        download_window(window_start, stop)
48
+
49
+
50
def process_gene_mutation_data():
    """Merge downloaded ``*gene.mutation.txt`` files into one gene -> mutations map.

    Reads every file in the current directory matching ``*gene.mutation.txt``
    (JSON responses with a ``data.hits`` list), splits each ``gene_aa_change``
    entry into gene and mutation at the first space, and writes
    ``{gene: [mutation, ...]}`` without duplicates to
    ``gdc_genes_mutations.json``.
    """
    gdc_genes = {}
    # O(1) duplicate detection; the lists in gdc_genes keep first-seen order
    seen = set()
    # sorted so the output does not depend on directory listing order
    for path in sorted(glob.glob("*gene.mutation.txt")):
        with open(path, "r") as f_in:
            data = json.load(f_in)
        for item in data["data"]["hits"]:
            # a hit may carry no gene_aa_change field at all
            for gene_aa_change in item.get("gene_aa_change", []):
                gene, sep, mutation = gene_aa_change.partition(" ")
                if not sep or (gene, mutation) in seen:
                    continue
                seen.add((gene, mutation))
                gdc_genes.setdefault(gene, []).append(mutation)

    with open("gdc_genes_mutations.json", "w") as f_out:
        json.dump(gdc_genes, f_out, indent=4)
68
+
69
+
70
+ # this function creates the project mappings tsv file
71
+ # only to be run once
72
def get_gdc_project_ids(start, stop):
    """Fetch GDC projects and write the ``gdc_projects.tsv`` mapping file.

    Requests ``project_id``, ``disease_type``, ``primary_site`` and ``name``
    from the ``/projects`` endpoint with ``from=start`` and ``size=stop``.
    For each hit, writes one line ``<project_id>\\t<disease_type + [name]>``
    and records the same pair in the returned dict.

    Returns:
        ``{project_id: disease_type_list + [name]}``; whatever was collected
        before a failure (possibly empty) if the request or parsing fails.
    """
    url = (
        "https://api.gdc.cancer.gov/projects"
        "?fields=project_id,disease_type,primary_site,name"
        f"&from={start}&size={stop}"
    )
    mappings = {}
    try:
        response = requests.get(url)
        with open("gdc_projects.tsv", "w") as tsv:
            for hit in response.json()["data"]["hits"]:
                project_id = hit["project_id"]
                description = hit["disease_type"] + [hit["name"]]
                tsv.write(f"{project_id}\t{description}\n")
                mappings[project_id] = description
    except Exception as e:
        print("unable to execute GDC API request {}".format(str(e)))
    return mappings
92
+
93
+
94
def get_ssm_id(gene, mutation):
    """Look up the GDC ssm id for a ``"<gene> <mutation>"`` amino-acid change.

    Queries the ``/ssms`` endpoint filtered on ``ssms.gene_aa_change`` and
    returns the ``id`` of the first hit.

    Returns:
        The id of the first hit, or ``None`` if the request fails or there is
        no hit (the error is printed).
    """
    query_filter = {
        "op": "=",
        "content": {
            "field": "ssms.gene_aa_change",
            "value": gene + " " + mutation,
        },
    }
    params = {
        "filters": json.dumps(query_filter),
        "fields": "mutation_type",
        "expand": ["consequence.transcript"],
        "size": 10,
    }
    try:
        response = requests.get("https://api.gdc.cancer.gov/ssms", params=params)
        return json.loads(response.content)["data"]["hits"][0]["id"]
    except Exception as e:
        print("unable to execute GDC API request {}".format(str(e)))
        return None
119
+
120
+
121
def get_ssm_counts(ssm_id):
    """Group the occurrences of one ssm by project.

    Queries ``/ssm_occurrences`` (``size`` 1000) for ``ssm_id`` and, for every
    hit, appends the case id to the project's ``case_id_list`` and increments
    its ``ssm_counts``.

    Returns:
        ``{project_id: {"case_id_list": [...], "ssm_counts": int}}``; whatever
        was accumulated before a failure (possibly empty) if the request or
        parsing fails.
    """
    by_project = {}
    occurrence_filter = {
        "op": "=",
        "content": {"field": "ssm.ssm_id", "value": ssm_id},
    }
    params = {
        "filters": json.dumps(occurrence_filter),
        "fields": "case.project.project_id,case.case_id",
        "size": 1000,
    }
    try:
        response = requests.get(
            "https://api.gdc.cancer.gov/ssm_occurrences", params=params
        )
        for hit in json.loads(response.content)["data"]["hits"]:
            case = hit["case"]
            entry = by_project.setdefault(
                case["project"]["project_id"], {"case_id_list": []}
            )
            entry["case_id_list"].append(case["case_id"])
            entry["ssm_counts"] = entry.get("ssm_counts", 0) + 1
    except Exception as e:
        print("unable to execute GDC API request {}".format(str(e)))
    return by_project
150
+
151
+
152
def get_available_cnv_data_for_project(project):
    """Return the number of cases in ``project`` that have CNV data.

    Queries ``/case_ssms`` filtered on ``available_variation_data`` = ``cnv``
    and the project id, and reads ``data.pagination.total`` from the response.

    Returns:
        The reported total, or ``0`` if the request or parsing fails.
    """
    cnv_cases_filter = {
        "op": "and",
        "content": [
            {
                "op": "in",
                "content": {"field": "available_variation_data", "value": "cnv"},
            },
            {"op": "=", "content": {"field": "project.project_id", "value": project}},
        ],
    }
    params = {
        "filters": json.dumps(cnv_cases_filter),
        "fields": "project.project_id,available_variation_data",
        "size": 1000,
    }
    try:
        response = requests.get("https://api.gdc.cancer.gov/case_ssms", params=params)
        return json.loads(response.content)["data"]["pagination"]["total"]
    except Exception as e:
        print("unable to execute GDC API request {}".format(str(e)))
        return 0
176
+
177
+
178
def get_available_ssm_data_for_project(project):
    """Return the number of cases in ``project`` that have SSM data.

    Queries ``/case_ssms`` filtered on ``available_variation_data`` = ``ssm``
    and the project id, and reads ``data.pagination.total`` from the response.

    Returns:
        The reported total, or ``0`` if the request or parsing fails.
    """
    case_ssm_endpt = "https://api.gdc.cancer.gov/case_ssms"
    fields = ["project.project_id", "available_variation_data"]
    fields = ",".join(fields)

    filters = {
        "op": "and",
        "content": [
            {
                "op": "in",
                "content": {"field": "available_variation_data", "value": "ssm"},
            },
            {"op": "=", "content": {"field": "project.project_id", "value": project}},
        ],
    }
    params = {"filters": json.dumps(filters), "fields": fields, "size": 1000}
    try:
        response = requests.get(case_ssm_endpt, params=params)
        response_json = json.loads(response.content)
        total_case_count = response_json["data"]["pagination"]["total"]
    except Exception as e:
        print("unable to execute GDC API request {}".format(str(e)))
        # without this fallback the return below raised UnboundLocalError
        total_case_count = 0
    return total_case_count
201
+
202
+
203
def get_top_mutated_genes_by_project(cancer_entities, top_k):
    """Return the first ``top_k`` hits of the top-mutated-genes analysis per project.

    Args:
        cancer_entities: project ids to query; if empty, every key of the
            module-level ``project_mappings`` is used.
        top_k: number of leading hits to keep for each project.

    Returns:
        ``{project_id: hits[:top_k]}``; projects whose request fails are
        omitted (the error is printed).
    """
    # need an AI way of recognizing top k from query
    projects = cancer_entities or list(project_mappings.keys())

    endpt = "https://api.gdc.cancer.gov/analysis/top_mutated_genes_by_project"
    fields = "gene_id,symbol"
    top_genes = {}
    for project in projects:
        project_filter = {
            "op": "and",
            "content": [
                {
                    "op": "in",
                    "content": {
                        "field": "case.project.project_id",
                        "value": [project],
                    },
                }
            ],
        }
        params = {
            "filters": json.dumps(project_filter),
            "fields": fields,
            "size": 1000,
        }
        try:
            response = requests.get(endpt, params=params)
            hits = json.loads(response.content)["data"]["hits"]
            top_genes[project] = hits[:top_k]
        except Exception as e:
            print("unable to execute GDC API request {}".format(str(e)))
    return top_genes
233
+
234
+
235
def return_joint_single_cnv_frequency(cnv, cnv_change, cnv_change_5_category):
    """Turn per-project CNV results into human-readable frequency sentences.

    Args:
        cnv: ``{project: {gene: {"case_id_list": [...], "frequency": float}}}``.
        cnv_change: ``"Loss"`` or ``"Gain"``.
        cnv_change_5_category: label used in the single-gene sentence; when
            falsy and ``cnv_change`` is ``"Loss"`` it becomes
            ``"Heterozygous Deletion"``.

    Returns:
        A list of sentences. For a project with more than one gene, one joint
        frequency (cases carrying the change in every gene, divided by the
        project's number of cases with CNV data). Otherwise one sentence per
        gene with its stored frequency. Projects whose CNV case total is 0 or
        unavailable are skipped.
    """
    # set category for heterozygous del
    if not cnv_change_5_category and cnv_change == "Loss":
        cnv_change_5_category = "Heterozygous Deletion"

    result_text = []
    for ce, genes in cnv.items():
        total_number_of_cases_with_cnv_data = get_available_cnv_data_for_project(ce)
        # skip if total number of cnv cases from API is 0
        if not total_number_of_cases_with_cnv_data:
            continue

        if len(genes) > 1:
            shared_cases = set.intersection(
                *(set(stats["case_id_list"]) for stats in genes.values())
            )
            joint_frequency = round(
                (len(shared_cases) / total_number_of_cases_with_cnv_data) * 100, 2
            )
            result_text.append(
                "joint frequency in {} is {}%".format(ce, joint_frequency)
            )
        else:
            for gene, stats in genes.items():
                result_text.append(
                    "The frequency of {} {} in {} is {}%".format(
                        gene, cnv_change_5_category, ce, stats["frequency"]
                    )
                )
    return result_text
271
+
272
+
273
def get_cnv_filter_with_cnv_change_category(cnv_change, ce, ge, cnv_change_5_category):
    """Build the GDC filter selecting one kind of CNV in one gene and project.

    Args:
        cnv_change: value matched against ``cnv.cnv_change``.
        ce: project id matched against ``case.project.project_id``.
        ge: gene symbol matched against ``cnv.consequence.gene.symbol``.
        cnv_change_5_category: value matched against
            ``cnv.cnv_change_5_category``.

    Returns:
        A dict combining the four conditions with ``and``.
    """
    # named cnv_filter so the builtin ``filter`` is not shadowed
    cnv_filter = {
        "op": "and",
        "content": [
            {"op": "in", "content": {"field": "cnv.cnv_change", "value": [cnv_change]}},
            {
                "op": "in",
                "content": {
                    "field": "cnv.cnv_change_5_category",
                    "value": [cnv_change_5_category],
                },
            },
            {
                "op": "=",
                "content": {"field": "cnv.consequence.gene.symbol", "value": ge},
            },
            {"op": "=", "content": {"field": "case.project.project_id", "value": ce}},
        ],
    }
    return cnv_filter
294
+
295
+
296
def get_freq_cnv_loss_or_gain(gene_entities, cancer_entities, query, cnv_and_ssm_flag):
    """Compute per-project, per-gene CNV loss/gain frequencies from the GDC API.

    The kind of change is inferred from ``query`` (case-insensitive): any of
    the loss terms selects ``"Loss"`` (``"Homozygous Deletion"`` if the query
    mentions "homozygous"); otherwise ``"Gain"`` (``"Amplification"`` if the
    query mentions "amplification").

    Args:
        gene_entities: gene symbols to query.
        cancer_entities: project ids; if empty, every key of the module-level
            ``project_mappings`` is used.
        query: the user's natural-language question.
        cnv_and_ssm_flag: if truthy, return the raw result dict instead of text.

    Returns:
        If ``cnv_and_ssm_flag`` is truthy:
        ``{project: {gene: {"case_id_list": [...], "frequency": float}}}``.
        Otherwise ``(result_text, projects)``, where ``result_text`` comes from
        ``return_joint_single_cnv_frequency`` and ``projects`` are the projects
        with results. Project/gene pairs whose request fails, or whose project
        has no CNV case total, are skipped.
    """
    cnv = {}
    lc_query = query.lower()
    # need to figure out how to get deletion and gain
    # V1 is only co-deletion, or co-gain
    # all terms lower-case: they are matched against the lower-cased query
    loss_terms = ("loss", "loh", "deletion", "co-deletion", "lost")
    if any(term in lc_query for term in loss_terms):
        cnv_change = "Loss"
        if "homozygous" in lc_query:
            cnv_change_5_category = "Homozygous Deletion"
        else:
            cnv_change_5_category = "Loss"
    else:
        cnv_change = "Gain"
        if "amplification" in lc_query:
            cnv_change_5_category = "Amplification"
        else:
            cnv_change_5_category = "Gain"

    if not cancer_entities:
        cancer_entities = list(project_mappings.keys())

    endpt = "https://api.gdc.cancer.gov/cnv_occurrences"
    fields = ",".join(
        [
            "cnv.chromosome",
            "cnv.cnv_change",
            # these two were fused into one bogus field by a missing comma
            "cnv.cnv_change_5_category",
            "cnv.consequence.gene.symbol",
            "case.case_id",
            "case.project.project_id",
        ]
    )

    for ce in cancer_entities:
        for ge in gene_entities:
            filters = get_cnv_filter_with_cnv_change_category(
                cnv_change, ce, ge, cnv_change_5_category
            )
            params = {"filters": json.dumps(filters), "fields": fields, "size": 1000}
            try:
                # skip if response not successful or not shaped as expected
                response = requests.get(endpt, params=params)
                hits = json.loads(response.content)["data"]["hits"]
            except Exception as e:
                print("exception: {}".format(str(e)))
                continue

            total_number_of_cases_with_cnv_data = get_available_cnv_data_for_project(ce)
            # skip if cannot obtain total # of cnv cases from API
            if not total_number_of_cases_with_cnv_data:
                continue

            case_id_list = [
                item["case"]["case_id"] for item in hits if item["case"]["case_id"]
            ]
            freq = len(case_id_list) / total_number_of_cases_with_cnv_data
            cnv.setdefault(ce, {})[ge] = {
                "case_id_list": case_id_list,
                "frequency": round(freq * 100, 2),
            }

    if cnv_and_ssm_flag:
        return cnv
    result_text = return_joint_single_cnv_frequency(
        cnv, cnv_change, cnv_change_5_category
    )
    return result_text, list(cnv.keys())
375
+
376
+
377
def get_msi_frequency(cancer_entities):
    """Compute the share of files with ``msi_status == "MSI"`` per project.

    For each project, queries ``/files`` for BAM files with experimental
    strategy WXS or WGS. Only hits with a non-empty ``msi_status`` are
    counted; the frequency is the fraction of those whose status is ``"MSI"``.

    Args:
        cancer_entities: project ids; if empty, every key of the module-level
            ``project_mappings`` is used.

    Returns:
        ``(result_text, projects)``: one sentence per project with a computed
        frequency, and the list of those projects. Projects whose request
        fails, or which have no MSI status at all, are skipped with a printed
        message.
    """
    msi_h_frequency = {}
    result_text = []
    # init some starting cancer entities if none
    if not cancer_entities:
        cancer_entities = list(project_mappings.keys())

    endpt = "https://api.gdc.cancer.gov/files"
    fields = ",".join(
        [
            "cases.project.project_id",
            "msi_score",
            "msi_status",
            "experimental_strategy",
        ]
    )
    for ce in cancer_entities:
        filters = {
            "op": "and",
            "content": [
                {"op": "=", "content": {"field": "data_format", "value": "BAM"}},
                {
                    "op": "in",
                    "content": {
                        "field": "experimental_strategy",
                        "value": ["WXS", "WGS"],
                    },
                },
                {
                    "op": "in",
                    "content": {"field": "cases.project.project_id", "value": [ce]},
                },
            ],
        }
        params = {"filters": json.dumps(filters), "fields": fields, "size": 10000}
        try:
            response = requests.get(endpt, params=params)
            response_json = json.loads(response.content)
            # only score tumors where MSI status is computed (excludes None)
            msi_results = [
                item["msi_status"]
                for item in response_json["data"]["hits"]
                if item.get("msi_status")
            ]
        except Exception as e:
            print("unable to execute GDC API request {}".format(str(e)))
            continue

        # previously surfaced as a misleading "division by zero" API error
        if not msi_results:
            print("no MSI status available for {}".format(ce))
            continue

        freq = msi_results.count("MSI") / len(msi_results)
        msi_h_frequency[ce] = {"frequency": round(freq * 100, 2)}
        result_text.append(
            "The frequency of MSI in {} is {}%".format(
                ce, msi_h_frequency[ce]["frequency"]
            )
        )
    ce_api_success = list(msi_h_frequency.keys())
    return result_text, ce_api_success
433
+
434
+
435
def get_ensembl_gene_ids(gene_entities):
    """Resolve gene symbols to GDC ``gene_id`` values.

    For each symbol, queries ``/genes`` and takes the ``gene_id`` of the first
    hit.

    Returns:
        A list of gene ids in input order. Symbols whose lookup fails or has
        no hit are omitted (the error is printed), so the list may be shorter
        than ``gene_entities``.
    """
    endpt = "https://api.gdc.cancer.gov/genes"
    gene_ids = []
    for symbol in gene_entities:
        symbol_filter = {
            "op": "and",
            "content": [{"op": "=", "content": {"field": "symbol", "value": symbol}}],
        }
        params = {
            "filters": json.dumps(symbol_filter),
            "fields": "gene_id",
            "size": 100,
        }
        try:
            response = requests.get(endpt, params=params)
            first_hit = json.loads(response.content)["data"]["hits"][0]
            gene_ids.append(first_hit["gene_id"])
        except Exception as e:
            print("unable to execute GDC API request {}".format(str(e)))
    return gene_ids
453
+
454
+
455
def get_total_variation_data_for_project(project):
    """Return the number of cases in ``project`` matching SSM or CNV data.

    Queries ``/case_ssms`` with ``available_variation_data`` in
    ``["ssm", "cnv"]`` and the project id, and reads
    ``data.pagination.total`` from the response.

    Returns:
        The reported total, or ``0`` if the request or parsing fails.
    """
    variation_filter = {
        "op": "and",
        "content": [
            {
                "op": "in",
                "content": {
                    "field": "available_variation_data",
                    "value": ["ssm", "cnv"],
                },
            },
            {"op": "=", "content": {"field": "project.project_id", "value": project}},
        ],
    }
    params = {
        "filters": json.dumps(variation_filter),
        "fields": "project.project_id,available_variation_data",
        "size": 1000,
    }
    try:
        response = requests.get("https://api.gdc.cancer.gov/case_ssms", params=params)
        return json.loads(response.content)["data"]["pagination"]["total"]
    except Exception as e:
        print("unable to execute GDC API request {}".format(str(e)))
        return 0
483
+
484
+
485
def get_cases_with_ssms_in_a_gene(project, gene_name):
    """Collect the distinct case ids in ``project`` with an SSM in ``gene_name``.

    Queries ``/ssm_occurrences`` (``size`` 1000) filtered on the project id
    and ``ssm.consequence.transcript.gene.symbol``.

    Returns:
        ``{"case_id_list": [unique case ids]}``, or an empty dict if the
        request or parsing fails (the error is printed).
    """
    occurrence_filter = {
        "op": "and",
        "content": [
            {
                "op": "=",
                "content": {"field": "case.project.project_id", "value": project},
            },
            {
                "op": "in",
                "content": {
                    "field": "ssm.consequence.transcript.gene.symbol",
                    "value": gene_name,
                },
            },
        ],
    }
    params = {
        "filters": json.dumps(occurrence_filter),
        "fields": "case.case_id",
        "size": 1000,
    }
    result = {}
    try:
        response = requests.get(
            "https://api.gdc.cancer.gov/ssm_occurrences", params=params
        )
        hits = json.loads(response.content)["data"]["hits"]
        # the set drops repeated case ids and the falsy ones are filtered out
        unique_case_ids = {
            hit["case"]["case_id"] for hit in hits if hit["case"]["case_id"]
        }
        result["case_id_list"] = list(unique_case_ids)
    except Exception as e:
        print("unable to execute GDC API request {}".format(str(e)))
    return result
520
+
521
+
522
def run_cnv_ssm_api(decompose_result, cancer_entities, query):
    """Report, per project, how often a CNV in one gene co-occurs with an SSM in another.

    Args:
        decompose_result: dict read for ``"cnv_gene"`` (gene with the copy
            number change) and ``"mut_gene"`` (gene with the simple somatic
            mutation).
        cancer_entities: project ids to report on.
        query: the user's question, passed to ``get_freq_cnv_loss_or_gain``
            to infer the kind of CNV change.

    Returns:
        ``(sentences, cancer_entities)`` with one sentence per project: the
        joint frequency (cases with both events divided by the project's
        total from ``get_total_variation_data_for_project``), or a
        "not available" sentence if any step fails for that project.
    """
    cnv_gene = decompose_result["cnv_gene"]
    cnv_result = get_freq_cnv_loss_or_gain(
        [cnv_gene], cancer_entities, query, cnv_and_ssm_flag=True
    )

    sentences = []
    for project in cancer_entities:
        try:
            # dict holding the list of case ids with ssms in the mutated gene
            ssm_result = get_cases_with_ssms_in_a_gene(
                project=project, gene_name=decompose_result["mut_gene"]
            )
            # calculate overlap of cases and return freq
            cnv_cases = set(cnv_result[project][cnv_gene]["case_id_list"])
            ssm_cases = set(ssm_result["case_id_list"])
            overlap = cnv_cases & ssm_cases
            total_case_count = get_total_variation_data_for_project(project=project)
            freq = round((len(overlap) / total_case_count) * 100, 2)
            sentence = "The joint frequency in {} is {}%".format(project, freq)
        except Exception:
            sentence = "joint freq in {} is not available".format(project)
        sentences.append(sentence)
    return sentences, cancer_entities
559
+
560
+
561
def get_top_cases_counts_by_gene(gene_entities, cancer_entities):
    """Report, per project, the share of cases with mutations in the given genes.

    Resolves ``gene_entities`` to gene ids, queries
    ``/analysis/top_cases_counts_by_genes`` once, and for each project found
    in the response's project buckets divides its ``doc_count`` by the
    project's total from ``get_total_variation_data_for_project``.

    Args:
        gene_entities: gene symbols.
        cancer_entities: project ids; if empty, every key of the module-level
            ``project_mappings`` is used.

    Returns:
        ``(sentences, projects)``. ``projects`` lists every project considered.
        A project absent from the response contributes no sentence; a project
        whose computation fails contributes a "frequency unavailable" sentence.
    """
    top_cases_counts_by_gene = {}
    result = []
    ensembl_gene_ids = get_ensembl_gene_ids(gene_entities)
    if not cancer_entities:
        cancer_entities = list(project_mappings.keys())
    if not cancer_entities:
        return result, []

    # note this gives you ssm + cnv
    # The URL does not depend on the project, so fetch once instead of once
    # per project.
    endpt = "https://api.gdc.cancer.gov/analysis/top_cases_counts_by_genes?gene_ids={}".format(
        ",".join(ensembl_gene_ids)
    )
    response = requests.get(endpt)
    response_json = json.loads(response.content)

    for ce in cancer_entities:
        top_cases_counts_by_gene[ce] = {}
        try:
            for item in response_json["aggregations"]["projects"]["buckets"]:
                if item["key"] != ce:
                    continue
                cases_with_mutations = item["doc_count"]
                total_case_count = get_total_variation_data_for_project(project=ce)
                freq = cases_with_mutations / total_case_count
                top_cases_counts_by_gene[ce] = {
                    "cases_with_mutations": cases_with_mutations,
                    "cases_without_mutations": total_case_count - cases_with_mutations,
                    "total_case_count": total_case_count,
                    "frequency": round(freq * 100, 2),
                }
                result.append(
                    "The frequency of cases with mutations in {} is {}%".format(
                        ce, top_cases_counts_by_gene[ce]["frequency"]
                    )
                )
        except Exception:
            result.append("frequency unavailable from API for {}".format(ce))
    cancer_entities = list(top_cases_counts_by_gene.keys())
    return result, cancer_entities
598
+
599
+
600
def get_project_summary(cancer_entities):
    """Fetch the expanded summary of each project from the GDC ``/projects`` endpoint.

    Args:
        cancer_entities: project ids.

    Returns:
        ``{project_id: {"project_summary": <response "data" object>}}``.

    Raises:
        Whatever the request, JSON decoding or the ``"data"`` lookup raises;
        errors are not caught here.
    """
    project_summary = {}
    for ce in cancer_entities:
        endpt = "https://api.gdc.cancer.gov/projects/{}?expand=summary,summary.experimental_strategies,summary.data_categories".format(
            ce
        )
        response = requests.get(endpt)
        response_json = json.loads(response.content)
        # create the per-project entry; indexing the missing key raised KeyError
        project_summary[ce] = {"project_summary": response_json["data"]}
    return project_summary
610
+
611
+
612
def map_cancer_entities_to_project(initial_cancer_entities, project_mappings):
    """Match cancer names to GDC project ids via the ``/projects`` endpoint.

    For each entity, queries projects whose ``name`` equals the entity; if
    several hits come back, the last one wins.

    Args:
        initial_cancer_entities: cancer names extracted from the query.
        project_mappings: accepted but not used by this function.

    Returns:
        ``{entity: project_id}`` for the entities that matched. Request or
        parsing failures are ignored silently so the caller can run further
        checks.
    """
    endpoint = "https://api.gdc.cancer.gov/projects"
    fields = "project_id,disease_type,name"
    matches = {}
    for entity in initial_cancer_entities:
        name_filter = {"op": "=", "content": {"field": "name", "value": [entity]}}
        params = {
            "filters": json.dumps(name_filter),
            "fields": fields,
            "size": 10000,
        }
        try:
            response = requests.get(endpoint, params=params)
            for hit in json.loads(response.content)["data"]["hits"]:
                matches[entity] = hit["project_id"]
        except Exception:
            # best effort: an unmatched entity is simply left out
            pass
    return matches
methods/utilities.py ADDED
@@ -0,0 +1,523 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # various utility functions employed by the pipeline
3
+ import json
4
+ import re
5
+ import time
6
+ from functools import reduce, wraps
7
+
8
+ import numpy as np
9
+ import pandas as pd
10
+ import spacy
11
+ import torch
12
+
13
+ from guidance.models import Transformers
14
+ from guidance import gen as guidance_gen
15
+
16
+ from huggingface_hub import HfFolder, hf_hub_download
17
+ from transformers import AutoTokenizer, BertTokenizer, AutoModelForCausalLM, BertForSequenceClassification
18
+
19
+
20
+ from methods import gdc_api_calls
21
+
22
+
23
def load_llama_llm(AUTH_TOKEN):
    """Load the Llama 3.2 3B Instruct model and tokenizer from Hugging Face.

    The model is loaded in float16, moved to the ``cuda`` device and put in
    eval mode.

    Args:
        AUTH_TOKEN: Hugging Face access token passed to both loaders.

    Returns:
        ``(model, tokenizer)``.
    """
    # hugging face model
    # https://huggingface.co/blog/llama32
    model_id = "meta-llama/Llama-3.2-3B-Instruct"
    hub_kwargs = {"trust_remote_code": True, "token": AUTH_TOKEN}

    tok = AutoTokenizer.from_pretrained(model_id, **hub_kwargs)
    model = AutoModelForCausalLM.from_pretrained(
        model_id, torch_dtype=torch.float16, **hub_kwargs
    )
    model = model.to("cuda").eval()
    return model, tok
41
+
42
+
43
def load_gdc_genes_mutations_hf(AUTH_TOKEN):
    """Download and parse ``gdc_genes_mutations.json`` from the Hugging Face Hub.

    Args:
        AUTH_TOKEN: Hugging Face access token used for the download.

    Returns:
        The parsed JSON content of the file.
    """
    local_path = hf_hub_download(
        repo_id="uc-ctds/GDC-QAG-genes-mutations",
        filename="gdc_genes_mutations.json",
        repo_type="dataset",
        token=AUTH_TOKEN,
    )
    with open(local_path, "r") as handle:
        return json.load(handle)
55
+
56
+
57
+
58
def load_intent_model_hf(AUTH_TOKEN):
    """Load the query-intent classifier and its tokenizer from Hugging Face.

    Args:
        AUTH_TOKEN: Hugging Face access token passed to both loaders.

    Returns:
        ``(model, tokenizer)`` for the ``uc-ctds/query_intent`` repository.
    """
    model_id = 'uc-ctds/query_intent'
    tok = AutoTokenizer.from_pretrained(
        model_id, trust_remote_code=True,
        token=AUTH_TOKEN
    )
    # Authenticate the model download the same way as the tokenizer download,
    # so a private or gated repo does not fail halfway through loading.
    model = BertForSequenceClassification.from_pretrained(
        model_id,
        token=AUTH_TOKEN
    )
    return model, tok
67
+
68
+
69
+
70
def infer_user_intent(query, intent_model, intent_tok):
    """Classify ``query`` into one of the pipeline's intent labels.

    Args:
        query: the user's question.
        intent_model: sequence-classification model; called with the
            tokenizer's output and expected to return an object with a
            ``logits`` attribute.
        intent_tok: tokenizer, called with ``return_tensors="pt"``.

    Returns:
        The label whose index has the highest probability, or ``None`` if the
        predicted index has no label.
    """
    # position in this tuple == class index predicted by the model
    intent_labels = (
        "ssm_frequency",
        "msi_h_frequency",
        "freq_cnv_loss_or_gain",
        "top_cases_counts_by_gene",
        "cnv_and_ssm",
    )
    # set device and load both model and query on the same device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    intent_model.to(device)
    inputs = intent_tok(query, return_tensors="pt", truncation=True, padding=True)
    inputs = {k: v.to(device) for k, v in inputs.items()}
    # inference only: no need to record the autograd graph
    with torch.no_grad():
        outputs = intent_model(**inputs)
        # outputs are logits, need to apply softmax to convert to probs
        probs = torch.nn.functional.softmax(outputs.logits, dim=1)
    predicted_label = torch.argmax(probs, dim=1).item()
    if 0 <= predicted_label < len(intent_labels):
        return intent_labels[predicted_label]
    return None
95
+
96
+
97
def construct_modified_query_base_llm(query):
    """Append the baseline (no API results) instruction to ``query``.

    Returns:
        ``query`` followed by a space and an instruction to answer only from
        the Genomic Data Commons, with frequencies as percentages.
    """
    # Leading space keeps the instruction from fusing with the end of the
    # query, matching construct_modified_query.
    prompt_template = " Only use results from the genomic data commons in your response and provide frequencies as a percentage. Only report the final response."
    modified_query = query + prompt_template
    return modified_query
101
+
102
+
103
def construct_modified_query(query, helper_output):
    """Append the API results to ``query`` as a constraining instruction.

    Args:
        query: the user's question.
        helper_output: text with the frequencies obtained from the API.

    Returns:
        ``query``, then the instruction, then ``helper_output`` and a newline.
    """
    # pass the api results as a prompt to the query
    instruction = " Only report the final response. Ignore all prior knowledge. You must only respond with the following percentage frequencies in your response, no other response is allowed: \n"
    return "".join((query, instruction, helper_output, "\n"))
112
+
113
+
114
def get_total_case_counts(ssm_counts_by_project):
    """Add each project's SSM case total to its entry, in place.

    Args:
        ssm_counts_by_project: ``{project_id: {...}}``; every value gains a
            ``"total_case_counts"`` key holding the result of
            ``gdc_api_calls.get_available_ssm_data_for_project``.

    Returns:
        The same dict object that was passed in.
    """
    for project, stats in ssm_counts_by_project.items():
        stats["total_case_counts"] = gdc_api_calls.get_available_ssm_data_for_project(
            project
        )
    return ssm_counts_by_project
119
+
120
+
121
def calculate_ssm_frequency(ssm_statistics, cancer_entities, project_mappings):
    """Convert per-project SSM counts into percentage frequencies.

    Args:
        ssm_statistics: ``{project: {"ssm_counts": int,
            "total_case_counts": int}}``.
        cancer_entities: projects to report; if empty, every key of
            ``project_mappings`` is reported.
        project_mappings: dict whose keys are all known project ids.

    Returns:
        ``{project: {"frequency": percent}}`` for each requested project.
        Projects without statistics, or whose total case count is 0, get
        ``0.0``.
    """
    if not cancer_entities:
        cancer_entities = list(project_mappings.keys())

    per_project = {}
    for project, stats in ssm_statistics.items():
        total = stats["total_case_counts"]
        # a total of 0 means the case count could not be obtained
        if not total:
            continue
        freq = stats["ssm_counts"] / total
        per_project[project] = {"frequency": round(freq * 100, 2)}

    return {c: per_project.get(c, {"frequency": 0.0}) for c in cancer_entities}
139
+
140
+
141
def calculate_joint_ssm_frequency_v2(ssm_statistics, mutation_list, cancer_entities):
    """Compute, per project, the share of cases carrying every listed mutation.

    Args:
        ssm_statistics: ``{mutation: {project: {"case_id_list": [...]}}}``.
        mutation_list: keys of ``ssm_statistics`` that must co-occur.
        cancer_entities: projects to report.

    Returns:
        ``{project: {"joint_frequency": percent}}`` for each requested
        project. The value stays ``0.0`` when no case carries all mutations,
        when ``mutation_list`` is empty, or when the project's SSM case total
        is 0.
    """
    joint_ssm_frequency_for_cancer = {
        c: {"joint_frequency": 0.0} for c in cancer_entities
    }
    # nothing to intersect (reduce over an empty list used to raise TypeError)
    if not mutation_list:
        return joint_ssm_frequency_for_cancer

    overlapping_projects = set.intersection(
        *(set(ssm_statistics[mutation].keys()) for mutation in mutation_list)
    )
    for project in overlapping_projects:
        # only requested projects are reported, so skip the API call otherwise
        if project not in joint_ssm_frequency_for_cancer:
            continue
        shared_cases = set.intersection(
            *(
                set(ssm_statistics[mutation][project]["case_id_list"])
                for mutation in mutation_list
            )
        )
        if not shared_cases:
            continue
        total_case_counts = gdc_api_calls.get_available_ssm_data_for_project(project)
        # a total of 0 means the case count could not be obtained
        if not total_case_counts:
            continue
        joint_frequency = len(shared_cases) / total_case_counts
        joint_ssm_frequency_for_cancer[project]["joint_frequency"] = round(
            joint_frequency * 100, 2
        )
    return joint_ssm_frequency_for_cancer
182
+
183
+
184
def flatten_ssm_results_to_text(result, result_type):
    """Render a nested SSM frequency result as a list of sentences.

    Args:
        result: mapping of mutation name to ``{project: {"frequency": x}}``,
            plus a ``"joint_frequency"`` entry.
        result_type: ``"joint_frequency"`` to report only the joint entry;
            any other value reports the per-mutation frequencies.

    Returns:
        A list of strings, one per (mutation, project) pair or, in joint
        mode, one per project of the joint entry.
    """
    if result_type == "joint_frequency":
        # Only the joint entry is reported in this mode.
        joint = result["joint_frequency"] if "joint_frequency" in result else {}
        return [
            "joint frequency in {} is {}%".format(project, stats["joint_frequency"])
            for project, stats in joint.items()
        ]

    sentences = []
    for mutation, per_project in result.items():
        # The joint entry is a placeholder in single-mutation mode; skip it.
        if mutation == "joint_frequency":
            continue
        for project, stats in per_project.items():
            sentences.append(
                "The frequency of {} in {} is {}%".format(
                    mutation, project, stats["frequency"]
                )
            )
    return sentences
203
+
204
+
205
def get_ssm_frequency(
    gene_entities, mutation_entities, cancer_entities, project_mappings
):
    """Build frequency sentences for each gene/mutation pair.

    For every (gene, mutation) pair the SSM id and per-project counts are
    fetched through ``gdc_api_calls``, turned into per-project frequencies,
    and restricted to ``cancer_entities`` (falling back to all projects when
    none of the requested entities is present). With more than one mutation
    a joint frequency is computed and reported instead of the single ones.

    Returns:
        ``(result_text, cancer_entities)`` where ``result_text`` is a list of
        sentences and ``cancer_entities`` is the list of projects of the last
        processed mutation (or the input list if no mutation was processed).
    """
    ssm_statistics = {}
    result = {}
    mutation_list = []

    # to match the genes with mutations: repeat the gene list so that zip()
    # below yields one pair per mutation
    if len(gene_entities) < len(mutation_entities):
        gene_entities = gene_entities * len(mutation_entities)

    for gene, mutation in zip(gene_entities, mutation_entities):
        mutation_name = "_".join([gene, mutation])  # e.g. JAK2_V617F
        mutation_list.append(mutation_name)

        ssm_id = gdc_api_calls.get_ssm_id(gene, mutation)
        counts_by_project = gdc_api_calls.get_ssm_counts(ssm_id)
        ssm_statistics[mutation_name] = get_total_case_counts(counts_by_project)

        # full_result format:
        # {'project1': {'frequency': }, ..., 'projectn': {'frequency': }}
        full_result = calculate_ssm_frequency(
            ssm_statistics[mutation_name], cancer_entities, project_mappings
        )
        requested_only = {
            project: stats
            for project, stats in full_result.items()
            if project in cancer_entities
        }
        # if no entity matches specific gdc projects, return all
        result[mutation_name] = requested_only if requested_only else full_result

    # final cancer entities: the projects of the last mutation processed
    for per_project in result.values():
        cancer_entities = list(per_project)

    # only supporting two mutations atm
    if len(mutation_list) > 1:
        result["joint_frequency"] = calculate_joint_ssm_frequency_v2(
            ssm_statistics, mutation_list=mutation_list, cancer_entities=cancer_entities
        )
        result_type = "joint_frequency"
    else:
        result["joint_frequency"] = 0
        result_type = "single_frequency"

    result_text = flatten_ssm_results_to_text(result, result_type=result_type)
    return result_text, cancer_entities
266
+
267
+
268
def decompose_mutation_and_cnv(query, match_term, gdc_genes_mutations):
    """Split a combined CNV + mutation query into its two gene names.

    The query is tokenised on single spaces; tokens that are keys of
    ``gdc_genes_mutations`` are treated as gene names. The query must name
    the CNV gene first, followed by the mutated gene.

    Raises:
        IndexError: if fewer than two known gene names occur in ``query``.
    """
    known_genes = gdc_genes_mutations.keys()
    found = []
    for token in query.split(" "):
        if token in known_genes:
            found.append(token)

    # query must have cnv first, followed by mutation
    return {
        "cnv_and_ssm": True,
        "cnv_gene": found[0],
        "mut_gene": found[1],
        "cnv_change_type": match_term,
    }
280
+
281
+
282
def get_freq_of_cnv_and_ssms(
    query, cancer_entities, gene_entities, gdc_genes_mutations
):
    """Dispatch a CNV/SSM frequency query to the matching ``gdc_api_calls`` helper.

    If the lower-cased query contains one of the known CNV terms, the query
    is decomposed into CNV gene and mutated gene and passed to
    ``gdc_api_calls.run_cnv_ssm_api``. Otherwise the top case counts by gene
    are requested.

    Returns:
        ``(result, cancer_entities)`` exactly as returned by the helper used.
    """
    cnv_terms = (
        "amplification",
        "deletion",
        "loss",
        "gain",
        "homozygous deletion",
        "heterozygous deletion",
    )
    lowered = query.lower()
    # When several terms occur, the one listed last wins, so the more
    # specific "homozygous/heterozygous deletion" overrides "deletion".
    match_term = next(
        (term for term in reversed(cnv_terms) if term in lowered), ""
    )

    if not match_term:
        # no specific match terms, return freq of cnvs + ssm
        result, cancer_entities = gdc_api_calls.get_top_cases_counts_by_gene(
            gene_entities, cancer_entities
        )
        return result, cancer_entities

    decompose_result = decompose_mutation_and_cnv(
        query, match_term, gdc_genes_mutations
    )
    result, cancer_entities = gdc_api_calls.run_cnv_ssm_api(
        decompose_result, cancer_entities, query
    )
    return result, cancer_entities
314
+
315
+
316
def return_initial_cancer_entities(query, model):
    """Return the text of every entity labelled ``DISEASE`` found in ``query``.

    The pipeline named by ``model`` is loaded with ``spacy.load`` on each
    call and applied to ``query``.
    """
    doc = spacy.load(model)(query)
    diseases = []
    for entity in doc.ents:
        if entity.label_ == "DISEASE":
            diseases.append(entity.text)
    return diseases
322
+
323
+
324
def infer_gene_entities_from_query(query, gdc_genes_mutations):
    """Return the known gene names that appear as whole tokens in ``query``.

    Gene recognition is a simple dictionary lookup: the query is split on
    single spaces and every key of ``gdc_genes_mutations`` that equals one of
    the tokens is reported, in the dictionary's iteration order.

    Args:
        query: free-text question.
        gdc_genes_mutations: mapping whose keys are gene names.

    Returns:
        List of matching gene names.
    """
    # Split once and use a set: a token match already implies a substring
    # match, and membership becomes O(1) per gene instead of a re-split and
    # list scan for every gene.
    query_tokens = set(query.split(" "))
    return [gene for gene in gdc_genes_mutations if gene in query_tokens]
331
+
332
+
333
def check_if_project_id_in_query(project_list, query):
    """Return the space-separated query tokens that are project ids.

    Detects explicit mentions of project keys, e.g. ``TCGA-BRCA``, in the
    query. Tokens are reported in query order; repeated mentions repeat.
    """
    mentioned = []
    for token in query.split(" "):
        if token in project_list:
            mentioned.append(token)
    return mentioned
342
+
343
+
344
def proj_id_and_partial_match(query, project_mappings, initial_cancer_entities):
    """Map cancer mentions to project ids by substring match on descriptions.

    When ``initial_cancer_entities`` is non-empty, each entity is searched
    for (case-insensitively) inside the descriptions of every project, e.g.
    "ovarian serous cystadenocarcinoma" matches the TCGA-OV project.
    Otherwise each space-separated term of the query is used instead.

    Args:
        query: free-text question; only used when no entities are given.
        project_mappings: mapping ``project_id -> iterable of description
            strings``.
        initial_cancer_entities: cancer names extracted from the query; may
            be empty.

    Returns:
        De-duplicated list of matching project ids, in first-match order.
    """
    if initial_cancer_entities:
        # Descriptions are lower-cased below, so the entities must be too;
        # otherwise any capitalised entity could never match.
        search_terms = [entity.lower() for entity in initial_cancer_entities]
    else:
        search_terms = query.lower().split(" ")

    final_entities = []
    for term in search_terms:
        # An empty term (e.g. from a double space) is a substring of every
        # description and would match all projects.
        if not term:
            continue
        for project_id, descriptions in project_mappings.items():
            if any(term in description.lower() for description in descriptions):
                final_entities.append(project_id)

    # dict.fromkeys de-duplicates while keeping a deterministic order.
    return list(dict.fromkeys(final_entities))
365
+
366
+
367
def postprocess_cancer_entities(project_mappings, initial_cancer_entities, query):
    """Resolve the cancer mentions of a query to project ids.

    Resolution order:
      1. project ids written literally in the query;
      2. if initial entities exist, the mapping returned by
         ``gdc_api_calls.map_cancer_entities_to_project``;
      3. partial matching of entities (or, without entities, query terms)
         against the descriptions in ``project_mappings``.

    Returns:
        List of project ids from the first step that yields any.
    """
    # Step 1: explicit project ids in the query win outright.
    explicit_ids = check_if_project_id_in_query(project_mappings.keys(), query)
    if explicit_ids:
        return explicit_ids

    # Step 2: only possible when entities were extracted from the query.
    if initial_cancer_entities:
        gdc_project_match = gdc_api_calls.map_cancer_entities_to_project(
            initial_cancer_entities, project_mappings
        )
        mapped_ids = list(gdc_project_match.values())
        if mapped_ids:
            return mapped_ids

    # Step 3: fall back to matching against project_mappings descriptions.
    return proj_id_and_partial_match(
        query, project_mappings, initial_cancer_entities
    )
399
+
400
+
401
def infer_mutation_entities(gene_entities, query, gdc_genes_mutations):
    """Return the known mutations of the given genes that occur in ``query``.

    For each gene, every mutation listed under it in ``gdc_genes_mutations``
    is reported if it is a substring of ``query``.

    Raises:
        KeyError: if a gene is not a key of ``gdc_genes_mutations``.
    """
    return [
        mutation
        for gene in gene_entities
        for mutation in gdc_genes_mutations[gene]
        if mutation in query
    ]
408
+
409
+
410
def _extract_percentage(text, pattern):
    """Return the first percentage ``pattern`` captures in ``text``, or NaN."""
    try:
        return float(re.search(pattern, text).group(1))
    except (AttributeError, TypeError, ValueError):
        # no match (None has no .group), non-string input, or a bare "."
        return np.nan


def postprocess_response(row):
    """Compare generated percentages with the helper (ground-truth) value.

    Args:
        row: mapping-like object (e.g. a DataFrame row) with the keys
            ``"helper_output"``, ``"pre_final_llama_with_helper_output"``
            and ``"llama_base_output"``.

    Returns:
        A 9-element ``pd.Series`` in this order: ``llama_base_stat``,
        ``delta_llama``, ``value_changed``, ``ground_truth_stat``,
        ``generated_stat_prefinal``, ``delta_prefinal``,
        ``generated_stat_final``, ``delta_final``, ``final_response``.
        If ``"helper_output"`` cannot be read, all nine entries are NaN.
    """
    # NOTE(review): the pattern requires a decimal point, so an integer
    # percentage such as "5%" is not extracted. Left unchanged because it
    # determines the reported statistics.
    pattern = r".*?(\d*\.\d*)%.*?"

    try:
        helper_output = row["helper_output"]
    except (KeyError, IndexError, TypeError):
        # Same width as the regular return so callers can assign the result
        # to a fixed set of columns; real NaN rather than the text "np.nan".
        return pd.Series([np.nan] * 9)

    pre_final_response = row["pre_final_llama_with_helper_output"]
    llama_base_output = row["llama_base_output"]

    llama_base_stat = _extract_percentage(llama_base_output, pattern)
    generated_stat_prefinal = _extract_percentage(pre_final_response, pattern)
    ground_truth_stat = _extract_percentage(helper_output, pattern)

    # Float arithmetic with NaN cannot raise; NaN simply propagates.
    delta_llama = llama_base_stat - ground_truth_stat

    if np.isnan(generated_stat_prefinal) or np.isnan(ground_truth_stat):
        value_changed = "na"
        delta_prefinal = np.nan
        generated_stat_final = np.nan
        delta_final = np.nan
        final_response = "unable to postprocess, check generated or truth stat"
    else:
        delta_prefinal = generated_stat_prefinal - ground_truth_stat
        if delta_prefinal != 0.0:
            # Generated value disagrees with the ground truth: override it.
            final_response = "The final answer is: {}%".format(ground_truth_stat)
            value_changed = "yes"
            generated_stat_final = ground_truth_stat
        else:
            final_response = pre_final_response
            value_changed = "no"
            generated_stat_final = generated_stat_prefinal
        # The final statistic is known in both branches, so it is not parsed
        # back out of final_response (which fails when the float is printed
        # in exponent notation, e.g. "1e-05%").
        delta_final = generated_stat_final - ground_truth_stat

    return pd.Series(
        [
            llama_base_stat,
            delta_llama,
            value_changed,
            ground_truth_stat,
            generated_stat_prefinal,
            delta_prefinal,
            generated_stat_final,
            delta_final,
            final_response,
        ]
    )
485
+
486
+
487
+
488
def set_hf_token(token_path):
    """Read the Hugging Face token stored at ``token_path`` and register it.

    The file content is stripped of surrounding whitespace and passed to
    ``HfFolder.save_token``.
    """
    with open(token_path, "r") as token_file:
        raw_token = token_file.read()
    HfFolder.save_token(raw_token.strip())
493
+
494
+
495
def get_final_columns():
    """Return the column names, in order, of the final output CSV."""
    question_columns = ["questions", "intent"]
    output_columns = ["llama_base_output", "helper_output"]
    entity_columns = ["cancer_entities", "gene_entities", "mutation_entities"]
    evaluation_columns = [
        "modified_prompt",
        "ground_truth_stat",
        "llama_base_stat",
        "delta_llama",
        "final_response",
    ]
    return question_columns + output_columns + entity_columns + evaluation_columns
513
+
514
+
515
def timeit(fn):
    """Decorator that prints how long each call to ``fn`` took.

    The wrapped function's return value is passed through unchanged. The
    elapsed time is measured with ``time.perf_counter`` and printed only
    when the call returns normally.
    """

    @wraps(fn)
    def timed(*args, **kwargs):
        started_at = time.perf_counter()
        outcome = fn(*args, **kwargs)
        finished_at = time.perf_counter()
        print(f"{fn.__name__} took {finished_at - started_at:.4f} seconds")
        return outcome

    return timed