Safi029 commited on
Commit
ea48bf6
·
verified ·
1 Parent(s): 3e279a4

Upload folder using huggingface_hub

Browse files
Files changed (2) hide show
  1. __init__.py +3 -0
  2. predict.py +55 -0
__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ # ABD_model/__init__.py
2
+
3
+ from .predict import run_model # if you define a function in predict.py
predict.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import cv2
3
+ import os
4
+ from pathlib import Path
5
+
6
+ def get_model_path():
7
+ """
8
+ Returns the full path to the ABD.pt model file bundled with the package.
9
+ """
10
+ return os.path.join(os.path.dirname(__file__), "ABD.pt")
11
+
12
+ def load_model():
13
+ """
14
+ Load the YOLOv8 model from the local ABD.pt file included in the package.
15
+ """
16
+ weights_path = get_model_path()
17
+
18
+ if not os.path.exists(weights_path):
19
+ raise FileNotFoundError(f"Model weights not found at: {weights_path}")
20
+
21
+ model = torch.hub.load('ultralytics/yolov8', 'custom', path=weights_path, force_reload=False)
22
+ return model
23
+
24
+ def predict_image(model, image_path):
25
+ """
26
+ Run prediction on the given image using the YOLOv8 model.
27
+ """
28
+ if not os.path.exists(image_path):
29
+ raise FileNotFoundError(f"Image file not found: {image_path}")
30
+
31
+ results = model(image_path)
32
+ results.print()
33
+ results.show()
34
+ return results
35
+
36
+ def run_model(image_path):
37
+ """
38
+ Full pipeline: load model and run prediction.
39
+ """
40
+ model = load_model()
41
+ results = predict_image(model, image_path)
42
+ return results
43
+
44
+ if __name__ == "__main__":
45
+ import argparse
46
+
47
+ parser = argparse.ArgumentParser(description="Predict atoms and bonds from a molecular image.")
48
+ parser.add_argument("--input_path", type=str, required=True, help="Path to the image (.png, .jpg, etc.)")
49
+
50
+ args = parser.parse_args()
51
+
52
+ try:
53
+ run_model(args.input_path)
54
+ except Exception as e:
55
+ print(f"Error: {e}")