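# app.py for a Hugging Face Space: detects "mark" objects in construction drawings with a
# YOLO ONNX model, reads each detected mark with Florence-2 OCR, and extracts tables from
# uploaded PDFs via PyMuPDF/Camelot into an Excel workbook, all behind a Gradio UI.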
from ultralytics import YOLO
import supervision as sv
import cv2
import gradio as gr
import os
import numpy as np
from transformers import AutoProcessor, AutoModelForCausalLM
import torch
import requests
from PIL import Image
import glob
import pandas as pd
import time
from pdf2image import convert_from_path
import pymupdf
import camelot
import fitz  # fitz and pymupdf are the same library; both import names are used below
import subprocess

# Install flash-attn at startup (a common Hugging Face Spaces workaround); the CUDA build is skipped.
subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
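# Model setup: Florence-2 performs OCR on the cropped detections, while two exported YOLO
# ONNX weights (assumed to ship with the Space under models/) handle mark and table detection.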
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = AutoModelForCausalLM.from_pretrained("microsoft/Florence-2-base-ft", trust_remote_code=True).to(device).eval()
processor = AutoProcessor.from_pretrained("microsoft/Florence-2-base-ft", trust_remote_code=True)
onnx_model = YOLO("models/best.onnx", task='detect')
onnx_model_table = YOLO("models/tables/best.onnx", task='detect')
def filter_detections(detections, target_class_name="mark"):
    # Keep only detections whose class name matches the target class (default: "mark").
    indices_to_keep = [i for i, class_name in enumerate(detections.data['class_name']) if
                       class_name == target_class_name]
    filtered_xyxy = detections.xyxy[indices_to_keep]
    filtered_confidence = detections.confidence[indices_to_keep]
    filtered_class_id = detections.class_id[indices_to_keep]
    filtered_class_name = detections.data['class_name'][indices_to_keep]
    detections.xyxy = filtered_xyxy
    detections.confidence = filtered_confidence
    detections.class_id = filtered_class_id
    detections.data['class_name'] = filtered_class_name
    return detections
def add_label_detection(detections):
    # Number each detection ("table 1", "table 2", ...) and offset its class_id by the
    # detection index so each box gets a distinct label from the annotators.
    updated_class = [f"{class_name} {i + 1}" for i, class_name in enumerate(detections.data['class_name'])]
    updated_id = [class_id + i for i, class_id in enumerate(detections.class_id)]
    detections.data['class_name'] = np.array(updated_class)
    detections.class_id = np.array(updated_id)
    return detections
def ends_with_number(s):
    # Guard against empty strings before checking the last character.
    return bool(s) and s[-1].isdigit()
def ocr(image, prompt="<OCR>"):
    original_height, original_width = image.shape[:2]
    # OpenCV loads images as BGR; convert to RGB before handing them to the Florence-2 processor.
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    inputs = processor(text=prompt, images=image, return_tensors="pt").to(device)
    generated_ids = model.generate(
        input_ids=inputs["input_ids"],
        pixel_values=inputs["pixel_values"],
        max_new_tokens=1024,
        early_stopping=False,
        do_sample=False,
        num_beams=3
    )
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
    parsed_answer = processor.post_process_generation(
        generated_text,
        task=prompt,
        image_size=(original_width, original_height)
    )
    return parsed_answer
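# Note: post_process_generation returns a dict keyed by the task prompt,
# e.g. {"<OCR>": "recognized text"}; analysis() below reads the text back out with that key.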
def parse_detection(detections):
    parsed_rows = []
    for i in range(len(detections.xyxy)):
        x_min = float(detections.xyxy[i][0])
        y_min = float(detections.xyxy[i][1])
        x_max = float(detections.xyxy[i][2])
        y_max = float(detections.xyxy[i][3])
        width = int(x_max - x_min)
        height = int(y_max - y_min)
        row = {
            "top": int(y_min),
            "left": int(x_min),
            "width": width,
            "height": height,
            "class_id": ""
            if detections.class_id is None
            else int(detections.class_id[i]),
            "confidence": ""
            if detections.confidence is None
            else float(detections.confidence[i]),
            "tracker_id": ""
            if detections.tracker_id is None
            else int(detections.tracker_id[i]),
        }
        if hasattr(detections, "data"):
            for key, value in detections.data.items():
                row[key] = (
                    str(value[i])
                    if hasattr(value, "__getitem__") and value.ndim != 0
                    else str(value)
                )
        parsed_rows.append(row)
    return parsed_rows
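# Each parsed row is a plain dict with keys 'top', 'left', 'width', 'height', 'class_id',
# 'confidence' and 'tracker_id', plus any entries from detections.data such as 'class_name'.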
def cut_and_save_image(image, parsed_detections, output_dir):
    output_path_list = []
    for i, det in enumerate(parsed_detections):
        # Only crop detections of class 'mark'
        if det['class_name'] == 'mark':
            top = det['top']
            left = det['left']
            width = det['width']
            height = det['height']
            # Crop the detection from the full image
            cut_image = image[top:top + height, left:left + width]
            # Upscale the crop so OCR has more pixels to work with, then save it as PNG
            output_path = f"{output_dir}/cut_image_{i}.png"
            scaled_image = sv.scale_image(image=cut_image, scale_factor=4)
            cv2.imwrite(output_path, scaled_image)
            output_path_list.append(output_path)
    return output_path_list
def analysis(progress=gr.Progress()):
    progress(0, desc="Analyzing...")
    list_files = glob.glob("output/*.png")
    prompt = "<OCR>"
    results = {}
    for filepath in progress.tqdm(list_files):
        basename = os.path.basename(filepath)
        image = cv2.imread(filepath)
        start_time = time.time()
        parsed_answer = ocr(image, prompt)
        # Normalize marks that do not end with a number by appending "1"
        if not ends_with_number(parsed_answer[prompt]):
            parsed_answer[prompt] += "1"
        results[parsed_answer[prompt]] = results.get(parsed_answer[prompt], 0) + 1
        print(basename, parsed_answer[prompt])
        print("Time taken:", time.time() - start_time)
    return pd.DataFrame(results.items(), columns=['Mark', 'Total']).reset_index(drop=False).rename(columns={'index': 'No.'})
def inference(
        image_path,
        conf_threshold,
        iou_threshold,
):
    """
    YOLOv8 inference function

    Args:
        image_path: Path to the image
        conf_threshold: Confidence threshold
        iou_threshold: IoU threshold

    Returns:
        Annotated image and the list of cropped 'mark' image paths
    """
    image = cv2.imread(image_path)
    original_height, original_width = image.shape[:2]
    print(image.shape)
    results = onnx_model(image, conf=conf_threshold, iou=iou_threshold)[0]
    detections = sv.Detections.from_ultralytics(results)
    detections = filter_detections(detections)
    parsed_detections = parse_detection(detections)
    output_dir = "output"
    # Create the output directory if needed, otherwise clear files left over from a previous run
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    else:
        for f in os.listdir(output_dir):
            os.remove(os.path.join(output_dir, f))
    output_path_list = cut_and_save_image(image, parsed_detections, output_dir)
    box_annotator = sv.BoxAnnotator()
    label_annotator = sv.LabelAnnotator(text_position=sv.Position.TOP_LEFT, text_thickness=1, text_padding=2)
    annotated_image = image.copy()
    annotated_image = box_annotator.annotate(
        scene=annotated_image,
        detections=detections
    )
    annotated_image = label_annotator.annotate(scene=annotated_image, detections=detections)
    return annotated_image, output_path_list
def read_table(sheet):
    excel_path = "output_tables.xlsx"
    if os.path.exists(excel_path):
        sheetnames = pd.ExcelFile(excel_path).sheet_names
        if sheet in sheetnames:
            df = pd.read_excel(excel_path, sheet_name=sheet)
        else:
            df = pd.DataFrame()
    else:
        df = pd.DataFrame()
    return df
def validate_df(df):
    # Camelot returns DataFrames with integer column labels; replace them with
    # generic "Col N" names so the sheet writes to Excel cleanly.
    columns = []
    count = 1
    for col in df.columns:
        if isinstance(col, (int, np.integer)):
            columns.append(f"Col {count}")
            count += 1
        else:
            columns.append(col)
    df.columns = columns
    return df
def analyze_table(file, conf_threshold, iou_threshold, progress=gr.Progress()):
    progress(0, desc="Parsing table...")
    img = convert_from_path(file)[0]
    doc = pymupdf.open(file)
    zoom_x = 1.0  # horizontal zoom
    zoom_y = 1.0  # vertical zoom
    mat = pymupdf.Matrix(zoom_x, zoom_y)
    # Only the first page is analyzed; render it to an image for table detection
    page = doc[0]
    pix = page.get_pixmap(matrix=mat)
    pix.save("temp.png")
    image = cv2.imread("temp.png")
    file_height, file_width, _ = image.shape
    results = onnx_model_table(image, conf=conf_threshold, iou=iou_threshold, imgsz=640)[0]
    detections = sv.Detections.from_ultralytics(results)
    detections = add_label_detection(detections)
    parsed_detections = parse_detection(detections)
    output_dir = "output_table"
    # Create the output directory if needed, otherwise clear files left over from a previous run
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    else:
        for f in os.listdir(output_dir):
            os.remove(os.path.join(output_dir, f))
    box_annotator = sv.BoxAnnotator()
    label_annotator = sv.LabelAnnotator(text_position=sv.Position.TOP_LEFT, text_thickness=1, text_padding=2)
    annotated_image = image.copy()
    annotated_image = box_annotator.annotate(
        scene=annotated_image,
        detections=detections
    )
    annotated_image = label_annotator.annotate(scene=annotated_image, detections=detections)
    pdf = fitz.open(file)
    pdf_page = pdf[0]
    table_area = [(ind,
                   fitz.Rect(det['left'], det['top'], det['left'] + det['width'], det['top'] + det['height']))
                  for ind, det in enumerate(parsed_detections)
                  ]
    table_list = []
    for ind, area in progress.tqdm(table_area):
        # Try PyMuPDF's table finder first, fall back to Camelot when it finds nothing
        pdf_tabs = pdf_page.find_tables(clip=area)
        if len(pdf_tabs.tables) > 0:
            pdf_df = pdf_tabs.tables[0].to_pandas()
            print("Fitz Table Found!")
        else:
            # Camelot expects PDF coordinates (origin at the bottom-left) and 1-based page numbers
            cur = parsed_detections[ind]
            table_areas = [f"{cur['left']},{file_height - cur['top']},{cur['left'] + cur['width']},{file_height - (cur['top'] + cur['height'])}"]
            tables = camelot.read_pdf(file, pages='1', flavor='stream', row_tol=10, table_areas=table_areas)
            pdf_df = tables[0].df
            print("Camelot Table Found!")
        pdf_df = validate_df(pdf_df)
        table_list.append(pdf_df)
    excel_path = "output_tables.xlsx"
    sheet_list = []
    with pd.ExcelWriter(excel_path, engine='xlsxwriter') as writer:
        for i in range(len(table_list)):
            sheet_name = f"Table_{i + 1}"
            table_list[i].to_excel(writer, sheet_name=sheet_name, index=False)
            sheet_list.append(sheet_name)
    return img, annotated_image, excel_path, ", ".join(sheet_list)
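# analyze_table returns (first-page preview, annotated detection image, Excel file path,
# comma-separated sheet names); these map onto the outputs wired up in the PDF tab below.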
TITLE = "<h1 style='font-size: 2.5em; text-align: center;'>Identify objects in construction design</h1>"
DESCRIPTION = """<p style='font-size: 1.5em; line-height: 1.6em; text-align: left;'>Welcome to the object
identification application. This tool allows you to upload an image, and it will identify and annotate objects within
the image. Additionally, you can perform OCR analysis on the detected objects.</p>
"""
CSS = """
#output {
    height: 500px;
    overflow: auto;
    border: 1px solid #ccc;
}
h1 {
    text-align: center;
}
"""
EXAMPLES = [
    ['examples/train1.png', 0.6, 0.25],
    ['examples/train2.png', 0.9, 0.25],
    ['examples/train3.png', 0.6, 0.25]
]
SHEET_LIST = ['Table_1', 'Table_2', 'Table_3', 'Table_4', 'Table_5', 'Table_6']
with gr.Blocks(theme=gr.themes.Soft(), css=CSS) as demo:
    gr.HTML(TITLE)
    gr.HTML(DESCRIPTION)
    with gr.Tab(label="Identify objects"):
        with gr.Row(equal_height=False):
            input_img = gr.Image(type="filepath", label="Upload Image")
            output_img = gr.Image(type="filepath", label="Output Image")
        with gr.Row():
            with gr.Column():
                conf_thres = gr.Slider(minimum=0.0, maximum=1.0, value=0.6, step=0.05, label="Confidence Threshold")
            with gr.Column():
                iou = gr.Slider(minimum=0.0, maximum=1.0, value=0.25, step=0.05, label="IOU Threshold")
        with gr.Row():
            with gr.Column():
                submit_btn = gr.Button(value="Predict")
            with gr.Column():
                analysis_btn = gr.Button(value="Analysis")
        with gr.Row():
            output_df = gr.Dataframe(label="Results")
        with gr.Row():
            with gr.Accordion("Gallery", open=False):
                gallery = gr.Gallery(label="Detected Mark Object", columns=3)
        submit_btn.click(inference, [input_img, conf_thres, iou], [output_img, gallery])
        analysis_btn.click(analysis, [], [output_df])
        examples = gr.Examples(
            EXAMPLES,
            fn=inference,
            inputs=[input_img, conf_thres, iou],
            outputs=[output_img, gallery],
            cache_examples=False,
        )
    with gr.Tab(label="Detect and read table"):
        with gr.Row():
            with gr.Column():
                # Shows the first page of the uploaded PDF; the actual upload happens via the button
                pdf_preview = gr.Image(label="PDF Preview (first page)")
                upload_button = gr.UploadButton(label="Upload PDF file", file_types=[".pdf"])
            with gr.Column():
                table_img = gr.Image(label="Output Image", interactive=False)
        with gr.Row():
            with gr.Column():
                conf_thres_table = gr.Slider(minimum=0.0, maximum=1.0, value=0.45, step=0.05,
                                             label="Confidence Threshold")
            with gr.Column():
                iou_table = gr.Slider(minimum=0.0, maximum=1.0, value=0.25, step=0.05, label="IOU Threshold")
        with gr.Row():
            with gr.Column():
                text_output = gr.Textbox(label="Table List")
            with gr.Column():
                file_output = gr.File()
        with gr.Row():
            sheet_name = gr.Dropdown(choices=SHEET_LIST, allow_custom_value=True, label="Sheet Name")
        with gr.Row():
            table_df = gr.Dataframe(label="Results")
        upload_button.upload(analyze_table, [upload_button, conf_thres_table, iou_table],
                             [pdf_preview, table_img, file_output, text_output])
        conf_thres_table.change(analyze_table, [upload_button, conf_thres_table, iou_table],
                                [pdf_preview, table_img, file_output, text_output])
        iou_table.change(analyze_table, [upload_button, conf_thres_table, iou_table],
                         [pdf_preview, table_img, file_output, text_output])
        sheet_name.change(read_table, sheet_name, table_df)
demo.launch(debug=True)