# table_test / load_model.py
# Kushalguptaiitb's picture
# Upload 2 files
# 9fc21e6 verified
# from ultralytics import YOLOv10
import torch
from config.set_config import set_configuration
# Build the project configuration once, at import time; the values below are read from it.
set_config_project = set_configuration()
# Location of the layout-model weights; reused below as the default model identifier.
layout_model_weights_path = set_config_project.layout_model_weights_path
# Thread count from the configuration; not referenced in the visible part of this file.
no_of_threads = set_config_project.no_of_threads
# def load_model_for_process(detection_model_path=layout_model_weights_path):
# """
# Load model in each subprocess to avoid CUDA initialization issues
# Returns:
# Model loaded in appropriate device
# """
# # Your model loading logic
# device = "cuda" if torch.cuda.is_available() else "cpu"
# # print(f"Using device: {device}")
# model = YOLOv10(detection_model_path).to(device)
# class_names = model.names
# class_names["11"] = "Table-header"
# class_names["12"] = "Portfolio-Company-Table"
# return model, class_names
import torch
from ultralytics import YOLO
# def load_model_for_process(detection_model_path=layout_model_weights_path):
# """
# Load model in each subprocess to avoid CUDA initialization issues
# Returns:
# Model loaded in appropriate device
# """
# # Your model loading logic
# device = "cuda" if torch.cuda.is_available() else "cpu"
# # print(f"Using device: {device}")
# model = YOLO(detection_model_path).to(device)
# class_names = model.names
# class_names["11"] = "Table-header"
# class_names["12"] = "Portfolio-Company-Table"
# print("YOLOV12"*10)
# return model, class_names
# The code below loads the Docling Heron layout model.
from transformers import RTDetrV2ForObjectDetection, RTDetrImageProcessor
# Default identifier handed to from_pretrained() by load_model_for_process.
# It is taken from the project configuration; the alternative below is kept for reference.
# MODEL_NAME_DOCLING = "ds4sd/docling-layout-heron"
MODEL_NAME_DOCLING = layout_model_weights_path
def load_model_for_process(model_name=MODEL_NAME_DOCLING):
    """
    Build the Docling Heron detector, its image processor and the label map.

    Meant to be called inside each worker process rather than once in the
    parent (per the original note: this avoids CUDA initialization issues).

    Args:
        model_name: Identifier or path passed to ``from_pretrained`` for both
            the image processor and the model. Defaults to MODEL_NAME_DOCLING.

    Returns:
        Tuple of (model, image_processor, class_names), where class_names
        maps integer class ids to label strings.
    """
    # Prefer the GPU when torch reports one, otherwise stay on the CPU.
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    print(f"Using device: {device}")

    # Processor first, then the detector, which is moved onto the chosen device.
    processor = RTDetrImageProcessor.from_pretrained(model_name)
    detector = RTDetrV2ForObjectDetection.from_pretrained(model_name)
    detector = detector.to(device)

    # Position in this tuple is the class id (0, 1, 2, ...).
    labels = (
        "Caption",
        "Footnote",
        "Formula",
        "List-item",
        "Page-footer",
        "Page-header",
        "Picture",
        "Section-header",
        "Table",
        "Text",
        "Title",
        "Document Index",
        "Code",
        "Checkbox-Selected",
        "Checkbox-Unselected",
        "Form",
        "Key-Value Region",
        # Additional classes for compatibility with existing pipeline
        "Table-header",
        "Portfolio-Company-Table",
    )
    class_names = dict(enumerate(labels))

    return detector, processor, class_names