| import torch | |
| import cv2 | |
| import os | |
| from pathlib import Path | |
def get_model_path():
    """Return the path of the ``ABD.pt`` weights file located next to this module.

    The path is built from the directory of ``__file__``; this function does
    not check that the file actually exists.
    """
    package_dir = os.path.dirname(__file__)
    return os.path.join(package_dir, "ABD.pt")
def load_model():
    """Load the detection model from the bundled ``ABD.pt`` weights via ``torch.hub``.

    Returns:
        The object returned by ``torch.hub.load`` for the ``'custom'`` entry
        point, given the local weights path.

    Raises:
        FileNotFoundError: If the weights file is missing at the path
            reported by ``get_model_path()``.
    """
    checkpoint = get_model_path()
    if os.path.exists(checkpoint):
        # force_reload=False lets torch.hub reuse a previously cached copy of
        # the hub repository instead of fetching it again.
        # NOTE(review): confirm that the 'ultralytics/yolov8' hub repo exists
        # and exposes a 'custom' entry point accepting `path=` -- TODO verify.
        return torch.hub.load(
            'ultralytics/yolov8',
            'custom',
            path=checkpoint,
            force_reload=False,
        )
    raise FileNotFoundError(f"Model weights not found at: {checkpoint}")
def predict_image(model, image_path, *, show=True, verbose=True):
    """Run *model* on the image at *image_path* and return its results.

    Args:
        model: Callable invoked as ``model(image_path)``. The object it
            returns must provide ``print()`` when *verbose* is true and
            ``show()`` when *show* is true.
        image_path: Path to an existing image file.
        show: When true (the default), call ``results.show()``. Pass
            ``False`` to skip it, e.g. for batch or non-interactive use.
        verbose: When true (the default), call ``results.print()``.

    Returns:
        Whatever ``model(image_path)`` returned.

    Raises:
        FileNotFoundError: If *image_path* is not an existing regular file.
    """
    # isfile rather than exists: a directory passes a plain existence check
    # and would otherwise be handed straight to the model.
    if not os.path.isfile(image_path):
        raise FileNotFoundError(f"Image file not found: {image_path}")

    results = model(image_path)
    if verbose:
        results.print()
    if show:
        results.show()
    return results
def run_model(image_path):
    """Load the bundled model and run a prediction on *image_path*.

    Calls ``load_model()`` and passes the result, together with
    *image_path*, to ``predict_image``; the value returned by
    ``predict_image`` is returned unchanged.
    """
    return predict_image(load_model(), image_path)
def main(argv=None):
    """Command-line entry point: parse ``--input_path`` and run the pipeline.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read the process command line.

    Returns:
        ``0`` when ``run_model`` completes, ``1`` when it raises.
    """
    import argparse
    import sys

    parser = argparse.ArgumentParser(description="Predict atoms and bonds from a molecular image.")
    parser.add_argument("--input_path", type=str, required=True, help="Path to the image (.png, .jpg, etc.)")
    args = parser.parse_args(argv)

    try:
        run_model(args.input_path)
    except Exception as e:
        # Top-level boundary: report on stderr and signal failure through the
        # exit status so shells and calling scripts can detect it.
        print(f"Error: {e}", file=sys.stderr)
        return 1
    return 0


if __name__ == "__main__":
    raise SystemExit(main())