ABD-model / predict.py
Safi029's picture
Upload folder using huggingface_hub
ea48bf6 verified
import torch
import cv2
import os
from pathlib import Path
def get_model_path():
"""
Returns the full path to the ABD.pt model file bundled with the package.
"""
return os.path.join(os.path.dirname(__file__), "ABD.pt")
def load_model():
"""
Load the YOLOv8 model from the local ABD.pt file included in the package.
"""
weights_path = get_model_path()
if not os.path.exists(weights_path):
raise FileNotFoundError(f"Model weights not found at: {weights_path}")
model = torch.hub.load('ultralytics/yolov8', 'custom', path=weights_path, force_reload=False)
return model
def predict_image(model, image_path):
"""
Run prediction on the given image using the YOLOv8 model.
"""
if not os.path.exists(image_path):
raise FileNotFoundError(f"Image file not found: {image_path}")
results = model(image_path)
results.print()
results.show()
return results
def run_model(image_path):
"""
Full pipeline: load model and run prediction.
"""
model = load_model()
results = predict_image(model, image_path)
return results
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description="Predict atoms and bonds from a molecular image.")
parser.add_argument("--input_path", type=str, required=True, help="Path to the image (.png, .jpg, etc.)")
args = parser.parse_args()
try:
run_model(args.input_path)
except Exception as e:
print(f"Error: {e}")