# CV_parsing/src/graph.py
from langgraph.graph import StateGraph, END, START, add_messages
from typing import TypedDict, Any, Annotated
from PIL import Image
from src.utils.utils_segment import (
    preprocess,
    postprocess,
    extract_text,
    draw_bounding_boxes,
)
from src.inference.segment_inference import model
from src.config.llm import llm
from src.prompt.promt import format_prompt
from langchain_core.output_parsers import JsonOutputParser

parser = JsonOutputParser()


class State(TypedDict):
    # Shared state passed between graph nodes.
    image: Any                      # preprocessed image tensor fed to the model
    image_origin: Any               # original input image
    outputs_from_inference: Any     # raw, then post-processed, model outputs
    text_extracted_from_ocr: Any    # OCR text extracted from the detected regions
    threshold_confidence: float     # confidence threshold used in post-processing
    threshold_iou: float            # IoU threshold used in post-processing
    cropped_images: Any
    parser_output: bool             # flag to enable the LLM parser; later holds its parsed output
    image_with_bounding_boxes: Any  # original image annotated with detections
    _image: Annotated[Any, add_messages]  # display image, annotated with the add_messages reducer
    crop_image: Any


class N:
    PRE_PROCESS = "PRE_PROCESS"
    POST_PROCESS = "POST_PROCESS"
    INFERENCE = "INFERENCE"
    EXTRACT_TEXT_FROM_OCR = "EXTRACT_TEXT_FROM_OCR"
    PARSER_WITH_LLM = "PARSER_WITH_LLM"
    IMAGE_WITH_BOUNDING_BOXES = "IMAGE_WITH_BOUNDING_BOXES"


workflow = StateGraph(State)


def pre_process_fn(state: State):
    # Preprocess the original image into the tensor format the model expects,
    # and keep a displayable PIL copy for inspection.
    preprocess_img = preprocess(state["image_origin"])
    print("preprocess_img", preprocess_img.shape)
    image_for_display = (preprocess_img[0] * 255).astype("uint8")  # first item of the batch, back to 0-255
    image_for_display = image_for_display.transpose(1, 2, 0)  # CHW -> HWC
    image_show = Image.fromarray(image_for_display)
    return {"image": preprocess_img, "_image": image_show}


def inference_fn(state: State):
    # Run the segmentation model on the preprocessed image.
    image = state["image"]
    outputs = model.run(None, {"images": image})
    return {"outputs_from_inference": outputs}


def post_process_fn(state: State):
    # Filter the raw model outputs using the confidence and IoU thresholds.
    outputs = state["outputs_from_inference"]
    threshold_confidence = state["threshold_confidence"]
    threshold_iou = state["threshold_iou"]
    post_process_output = postprocess(outputs, threshold_confidence, threshold_iou)
    return {
        "outputs_from_inference": post_process_output,
    }


def extract_text_from_ocr_fn(state: State):
    # OCR the detected regions of the original image.
    image_origin = state["image_origin"]
    output_from_inference = state["outputs_from_inference"]
    text = extract_text(output_from_inference, image_origin)
    return {"text_extracted_from_ocr": text}


def draw_bounding_boxes_fn(state: State):
    # Annotate the original image with the post-processed detections.
    image = state["image_origin"]
    outputs = state["outputs_from_inference"]
    image_with_bounding_boxes = draw_bounding_boxes(image, outputs)
    return {"image_with_bounding_boxes": image_with_bounding_boxes}


def parser_output_fn(state: State):
    # Structure the OCR text with the LLM: prompt -> llm -> JSON parser.
    text_extracted_from_ocr = state["text_extracted_from_ocr"]
    chain = format_prompt | llm | parser
    response = chain.invoke({"user_input": text_extracted_from_ocr})
    print(response)
    return {"parser_output": response}


# Nodes
workflow.add_node(N.PRE_PROCESS, pre_process_fn)
workflow.add_node(N.INFERENCE, inference_fn)
workflow.add_node(N.POST_PROCESS, post_process_fn)
workflow.add_node(N.EXTRACT_TEXT_FROM_OCR, extract_text_from_ocr_fn)
workflow.add_node(N.IMAGE_WITH_BOUNDING_BOXES, draw_bounding_boxes_fn)
workflow.add_node(N.PARSER_WITH_LLM, parser_output_fn)

# Edges
workflow.add_edge(START, N.PRE_PROCESS)
workflow.add_edge(N.PRE_PROCESS, N.INFERENCE)
workflow.add_edge(N.INFERENCE, N.POST_PROCESS)
workflow.add_edge(N.POST_PROCESS, N.IMAGE_WITH_BOUNDING_BOXES)
workflow.add_edge(N.IMAGE_WITH_BOUNDING_BOXES, N.EXTRACT_TEXT_FROM_OCR)
workflow.add_conditional_edges(
    N.EXTRACT_TEXT_FROM_OCR,
    # Route to the LLM parser only when the parser_output flag is truthy in the state;
    # otherwise finish after OCR extraction.
    lambda state: N.PARSER_WITH_LLM if state["parser_output"] else END,
    {
        N.PARSER_WITH_LLM: N.PARSER_WITH_LLM,
        END: END,
    },
)
workflow.add_edge(N.PARSER_WITH_LLM, END)
app = workflow.compile()
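

# A minimal invocation sketch for the compiled graph. The file name and
# threshold values below are illustrative assumptions, not part of this module.
if __name__ == "__main__":
    sample = Image.open("sample_cv.png")  # assumed local test image
    initial_state = {
        "image_origin": sample,
        "threshold_confidence": 0.5,
        "threshold_iou": 0.45,
        "parser_output": True,  # set False to stop after OCR extraction
    }
    result = app.invoke(initial_state)
    print(result.get("text_extracted_from_ocr"))
    print(result.get("parser_output"))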