Spaces:
Runtime error
Runtime error
Commit
·
e197ad0
1
Parent(s):
b123ef7
update: integrate FigureAnnotatorFromPageImage into MedQAAssistant
Browse files
medrag_multi_modal/assistant/figure_annotation.py
CHANGED
|
@@ -92,44 +92,48 @@ Here are some clues you need to follow:
|
|
| 92 |
)
|
| 93 |
|
| 94 |
@weave.op()
|
| 95 |
-
def predict(self, image_artifact_address: str):
|
| 96 |
"""
|
| 97 |
-
Predicts figure annotations for
|
| 98 |
|
| 99 |
-
This function retrieves
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
`extract_structured_output` method and
|
| 105 |
|
| 106 |
Args:
|
| 107 |
-
|
|
|
|
| 108 |
|
| 109 |
Returns:
|
| 110 |
-
|
|
|
|
| 111 |
"""
|
| 112 |
artifact_dir = get_wandb_artifact(image_artifact_address, "dataset")
|
| 113 |
metadata = read_jsonl_file(os.path.join(artifact_dir, "metadata.jsonl"))
|
| 114 |
-
annotations =
|
| 115 |
for item in track(metadata, description="Annotating images:"):
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
)
|
| 120 |
-
if len(figure_image_files) > 0:
|
| 121 |
-
page_image = cv2.imread(page_image_file)
|
| 122 |
-
page_image = cv2.cvtColor(page_image, cv2.COLOR_BGR2RGB)
|
| 123 |
-
page_image = Image.fromarray(page_image)
|
| 124 |
-
figure_extracted_annotations = self.annotate_figures(
|
| 125 |
-
page_image=page_image
|
| 126 |
)
|
| 127 |
-
|
| 128 |
-
{
|
| 129 |
-
"page_idx": item["page_idx"],
|
| 130 |
-
"annotations": self.extract_structured_output(
|
| 131 |
-
figure_extracted_annotations["annotations"]
|
| 132 |
-
).model_dump(),
|
| 133 |
-
}
|
| 134 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 135 |
return annotations
|
|
|
|
| 92 |
)
|
| 93 |
|
| 94 |
@weave.op()
|
| 95 |
+
def predict(self, page_idx: int, image_artifact_address: str):
|
| 96 |
"""
|
| 97 |
+
Predicts figure annotations for a specific page in a document.
|
| 98 |
|
| 99 |
+
This function retrieves the artifact directory from the given image artifact address,
|
| 100 |
+
reads the metadata from the 'metadata.jsonl' file, and iterates through the metadata
|
| 101 |
+
to find the specified page index. If the page index matches, it reads the page image
|
| 102 |
+
and associated figure images, and then uses the `annotate_figures` method to extract
|
| 103 |
+
figure annotations from the page image. The extracted annotations are then structured
|
| 104 |
+
using the `extract_structured_output` method and returned as a dictionary.
|
| 105 |
|
| 106 |
Args:
|
| 107 |
+
page_idx (int): The index of the page to annotate.
|
| 108 |
+
image_artifact_address (str): The address of the image artifact containing the page images.
|
| 109 |
|
| 110 |
Returns:
|
| 111 |
+
dict: A dictionary containing the page index as the key and the extracted figure annotations
|
| 112 |
+
as the value.
|
| 113 |
"""
|
| 114 |
artifact_dir = get_wandb_artifact(image_artifact_address, "dataset")
|
| 115 |
metadata = read_jsonl_file(os.path.join(artifact_dir, "metadata.jsonl"))
|
| 116 |
+
annotations = {}
|
| 117 |
for item in track(metadata, description="Annotating images:"):
|
| 118 |
+
if item["page_idx"] == page_idx:
|
| 119 |
+
page_image_file = os.path.join(
|
| 120 |
+
artifact_dir, f"page{item['page_idx']}.png"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 121 |
)
|
| 122 |
+
figure_image_files = glob(
|
| 123 |
+
os.path.join(artifact_dir, f"page{item['page_idx']}_fig*.png")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 124 |
)
|
| 125 |
+
if len(figure_image_files) > 0:
|
| 126 |
+
page_image = cv2.imread(page_image_file)
|
| 127 |
+
page_image = cv2.cvtColor(page_image, cv2.COLOR_BGR2RGB)
|
| 128 |
+
page_image = Image.fromarray(page_image)
|
| 129 |
+
figure_extracted_annotations = self.annotate_figures(
|
| 130 |
+
page_image=page_image
|
| 131 |
+
)
|
| 132 |
+
figure_extracted_annotations = self.extract_structured_output(
|
| 133 |
+
figure_extracted_annotations["annotations"]
|
| 134 |
+
).model_dump()
|
| 135 |
+
annotations[item["page_idx"]] = figure_extracted_annotations[
|
| 136 |
+
"annotations"
|
| 137 |
+
]
|
| 138 |
+
break
|
| 139 |
return annotations
|
medrag_multi_modal/assistant/medqa_assistant.py
CHANGED
|
@@ -1,6 +1,9 @@
|
|
|
|
|
|
|
|
| 1 |
import weave
|
| 2 |
|
| 3 |
from ..retrieval import SimilarityMetric
|
|
|
|
| 4 |
from .llm_client import LLMClient
|
| 5 |
|
| 6 |
|
|
@@ -9,11 +12,12 @@ class MedQAAssistant(weave.Model):
|
|
| 9 |
|
| 10 |
llm_client: LLMClient
|
| 11 |
retriever: weave.Model
|
|
|
|
| 12 |
top_k_chunks: int = 2
|
| 13 |
retrieval_similarity_metric: SimilarityMetric = SimilarityMetric.COSINE
|
| 14 |
|
| 15 |
@weave.op()
|
| 16 |
-
def predict(self, query: str) -> str:
|
| 17 |
retrieved_chunks = self.retriever.predict(
|
| 18 |
query, top_k=self.top_k_chunks, metric=self.retrieval_similarity_metric
|
| 19 |
)
|
|
@@ -23,13 +27,24 @@ class MedQAAssistant(weave.Model):
|
|
| 23 |
for chunk in retrieved_chunks:
|
| 24 |
retrieved_chunk_texts.append(chunk["text"])
|
| 25 |
page_indices.add(int(chunk["page_idx"]))
|
| 26 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
|
| 28 |
system_prompt = """
|
| 29 |
You are an expert in medical science. You are given a query and a list of chunks from a medical document.
|
| 30 |
"""
|
| 31 |
response = self.llm_client.predict(
|
| 32 |
-
system_prompt=system_prompt,
|
|
|
|
| 33 |
)
|
|
|
|
| 34 |
response += f"\n\n**Source:** {'Pages' if len(page_numbers) > 1 else 'Page'} {page_numbers} from Gray's Anatomy"
|
| 35 |
return response
|
|
|
|
| 1 |
+
from typing import Optional
|
| 2 |
+
|
| 3 |
import weave
|
| 4 |
|
| 5 |
from ..retrieval import SimilarityMetric
|
| 6 |
+
from .figure_annotation import FigureAnnotatorFromPageImage
|
| 7 |
from .llm_client import LLMClient
|
| 8 |
|
| 9 |
|
|
|
|
| 12 |
|
| 13 |
llm_client: LLMClient
|
| 14 |
retriever: weave.Model
|
| 15 |
+
figure_annotator: FigureAnnotatorFromPageImage
|
| 16 |
top_k_chunks: int = 2
|
| 17 |
retrieval_similarity_metric: SimilarityMetric = SimilarityMetric.COSINE
|
| 18 |
|
| 19 |
@weave.op()
|
| 20 |
+
def predict(self, query: str, image_artifact_address: Optional[str] = None) -> str:
|
| 21 |
retrieved_chunks = self.retriever.predict(
|
| 22 |
query, top_k=self.top_k_chunks, metric=self.retrieval_similarity_metric
|
| 23 |
)
|
|
|
|
| 27 |
for chunk in retrieved_chunks:
|
| 28 |
retrieved_chunk_texts.append(chunk["text"])
|
| 29 |
page_indices.add(int(chunk["page_idx"]))
|
| 30 |
+
|
| 31 |
+
figure_descriptions = []
|
| 32 |
+
if image_artifact_address is not None:
|
| 33 |
+
for page_idx in page_indices:
|
| 34 |
+
figure_annotations = self.figure_annotator.predict(
|
| 35 |
+
page_idx=page_idx, image_artifact_address=image_artifact_address
|
| 36 |
+
)
|
| 37 |
+
figure_descriptions += [
|
| 38 |
+
item["figure_description"] for item in figure_annotations[page_idx]
|
| 39 |
+
]
|
| 40 |
|
| 41 |
system_prompt = """
|
| 42 |
You are an expert in medical science. You are given a query and a list of chunks from a medical document.
|
| 43 |
"""
|
| 44 |
response = self.llm_client.predict(
|
| 45 |
+
system_prompt=system_prompt,
|
| 46 |
+
user_prompt=[query, *retrieved_chunk_texts, *figure_descriptions],
|
| 47 |
)
|
| 48 |
+
page_numbers = ", ".join([str(int(page_idx) + 1) for page_idx in page_indices])
|
| 49 |
response += f"\n\n**Source:** {'Pages' if len(page_numbers) > 1 else 'Page'} {page_numbers} from Gray's Anatomy"
|
| 50 |
return response
|