vikramvasudevan committed on
Commit
fc44b9f
·
verified ·
1 Parent(s): 3b62c7b

Upload folder using huggingface_hub

Browse files
Files changed (2) hide show
  1. .gitignore +2 -1
  2. graph.py +50 -26
.gitignore CHANGED
@@ -10,4 +10,5 @@ wheels/
10
  .venv
11
  .env
12
  data/jg/
13
- output/
 
 
10
  .venv
11
  .env
12
  data/jg/
13
+ output/
14
+ data/jg_2/
graph.py CHANGED
@@ -88,7 +88,7 @@ class SheamiState(TypedDict):
88
  standardized_reports: List[StandardizedReport]
89
  trends_json: dict
90
  interpreted_report: str
91
-
92
 
93
  import re
94
 
@@ -152,34 +152,46 @@ def fn_init_node(state: SheamiState):
152
  state["standardized_reports"] = []
153
  state["trends_json"] = {}
154
  state["interpreted_report"] = ""
 
 
 
 
 
 
155
  return state
156
 
 
 
 
 
 
 
157
 
158
- async def fn_standardizer_node(state: SheamiState):
159
- logger.info("%s| Standardizing reports: started", state["thread_id"])
160
- state["messages"].append("Standardizing reports: started")
161
  llm_structured = llm.with_structured_output(StandardizedReport)
162
- for idx, report in enumerate(state["uploaded_reports"]):
163
- logger.info("%s| Standardizing report %s", state["thread_id"], report.report_file_name)
164
- state["messages"].append(f"Standardizing report: {report.report_file_name}")
165
- messages = [
166
- SystemMessage(content="Standardize this medical report into the schema."),
167
- # SystemMessage(
168
- # content="Populate the `inferred_range` field as 'low', 'normal', or 'high' by comparing the result value with the reference range. If both min and max are missing, set 'normal' unless the value is clearly out of usual medical ranges."
169
- # ),
170
- HumanMessage(content=report.report_contents),
171
- ]
172
- result: StandardizedReport = llm_structured.invoke(messages)
173
- state["standardized_reports"].append(result)
174
- # save to disk
175
- with open(
176
- os.path.join(SheamiConfig.get_output_dir(state["thread_id"]), f"report_{idx}.json"), "w"
177
- ) as f:
178
- f.write(result.model_dump_json(indent=2))
179
- logger.info("%s| Standardizing reports: finished", state["thread_id"])
180
- state["messages"].append("Standardizing reports: finished")
181
  return state
182
 
 
 
 
 
 
 
 
 
 
183
 
184
  async def fn_testname_standardizer_node(state: SheamiState):
185
  logger.info("%s| Standardizing Test Names: started", state["thread_id"])
@@ -470,7 +482,8 @@ def create_graph(thread_id : str):
470
  memory = InMemorySaver()
471
  workflow = StateGraph(SheamiState)
472
  workflow.add_node("init", fn_init_node)
473
- workflow.add_node("standardizer", fn_standardizer_node)
 
474
  workflow.add_node("testname_standardizer", fn_testname_standardizer_node)
475
  workflow.add_node("unit_normalizer", fn_unit_normalizer_node)
476
  workflow.add_node("trends", fn_trends_aggregator_node)
@@ -485,8 +498,19 @@ def create_graph(thread_id : str):
485
 
486
  workflow.add_edge(START, "init")
487
  workflow.add_edge("init", "standardizer_notifier")
488
- workflow.add_edge("standardizer_notifier","standardizer")
489
- workflow.add_edge("standardizer", "testname_standardizer_notifier")
 
 
 
 
 
 
 
 
 
 
 
490
  workflow.add_edge("testname_standardizer_notifier","testname_standardizer")
491
  workflow.add_edge("testname_standardizer", "unit_normalizer_notifier")
492
  workflow.add_edge("unit_normalizer_notifier", "unit_normalizer")
 
88
  standardized_reports: List[StandardizedReport]
89
  trends_json: dict
90
  interpreted_report: str
91
+ current_index: int
92
 
93
  import re
94
 
 
152
  state["standardized_reports"] = []
153
  state["trends_json"] = {}
154
  state["interpreted_report"] = ""
155
+ state["current_index"] = -1
156
+ return state
157
+
158
+
159
async def fn_increment_index_node(state: SheamiState):
    """Advance the report cursor by one position.

    Adds 1 to ``state["current_index"]`` in place and returns the same
    state object that was passed in.
    """
    next_index = state["current_index"] + 1
    state["current_index"] = next_index
    return state
162
 
163
async def fn_standardizer_node_one(state: SheamiState):
    """Standardize the single uploaded report selected by ``current_index``.

    Reads the report at ``state["uploaded_reports"][state["current_index"]]``,
    asks the structured-output LLM to convert it into a ``StandardizedReport``,
    appends the result to ``state["standardized_reports"]`` and writes it to
    ``report_<idx>.json`` in the thread's output directory.

    Returns:
        The same ``state`` object, mutated in place.

    Raises:
        IndexError: if ``current_index`` is outside ``uploaded_reports``.
    """
    idx = state["current_index"]
    report = state["uploaded_reports"][idx]

    logger.info("%s| Standardizing report %s", state["thread_id"], report.report_file_name)
    state["messages"].append(f"Standardizing report: {report.report_file_name}")

    llm_structured = llm.with_structured_output(StandardizedReport)
    messages = [
        SystemMessage(content="Standardize this medical report into the schema."),
        HumanMessage(content=report.report_contents),
    ]
    # Awaited so the event loop is not blocked while the model responds.
    result: StandardizedReport = await llm_structured.ainvoke(messages)

    state["standardized_reports"].append(result)

    # Persist the standardized report. The encoding is pinned to UTF-8 so the
    # dump does not depend on the platform's default encoding, which can fail
    # on non-ASCII characters in report text.
    output_path = os.path.join(
        SheamiConfig.get_output_dir(state["thread_id"]), f"report_{idx}.json"
    )
    with open(output_path, "w", encoding="utf-8") as f:
        f.write(result.model_dump_json(indent=2))

    return state
185
 
186
+ # edge
187
def fn_check_if_report_available_to_process(state: SheamiState) -> str:
    """Decide whether another uploaded report is waiting to be standardized.

    Returns ``"continue"`` while ``state["current_index"]`` is below the
    number of uploaded reports, and ``"done"`` otherwise. In both cases a
    progress line is appended to ``state["messages"]``.

    NOTE(review): this function is used as a routing callback and mutates
    ``state["messages"]``; confirm that the graph runtime persists mutations
    made inside a conditional-edge function.
    """
    cursor = state["current_index"]
    pending_reports = state["uploaded_reports"]

    # Guard clause: the cursor has moved past the last report.
    if cursor >= len(pending_reports):
        state["messages"].append("Standardizing reports: finished")
        return "done"

    current_report = pending_reports[cursor]
    state["messages"].append(
        f"Initiating report standardization for: {current_report.report_file_name}"
    )
    return "continue"
195
 
196
  async def fn_testname_standardizer_node(state: SheamiState):
197
  logger.info("%s| Standardizing Test Names: started", state["thread_id"])
 
482
  memory = InMemorySaver()
483
  workflow = StateGraph(SheamiState)
484
  workflow.add_node("init", fn_init_node)
485
+ workflow.add_node("standardizer_one", fn_standardizer_node_one)
486
+ workflow.add_node("increment_index", fn_increment_index_node)
487
  workflow.add_node("testname_standardizer", fn_testname_standardizer_node)
488
  workflow.add_node("unit_normalizer", fn_unit_normalizer_node)
489
  workflow.add_node("trends", fn_trends_aggregator_node)
 
498
 
499
  workflow.add_edge(START, "init")
500
  workflow.add_edge("init", "standardizer_notifier")
501
+ workflow.add_edge("standardizer_notifier","increment_index")
502
+
503
+ # loop back if continue
504
+ workflow.add_conditional_edges(
505
+ "increment_index",
506
+ fn_check_if_report_available_to_process,
507
+ {
508
+ "continue": "standardizer_one",
509
+ "done": "testname_standardizer_notifier",
510
+ }
511
+ )
512
+ workflow.add_edge("standardizer_one", "increment_index")
513
+
514
  workflow.add_edge("testname_standardizer_notifier","testname_standardizer")
515
  workflow.add_edge("testname_standardizer", "unit_normalizer_notifier")
516
  workflow.add_edge("unit_normalizer_notifier", "unit_normalizer")