nkarthikeyan commited on
Commit
83940f8
·
verified ·
1 Parent(s): 42b7f70

Upload mega_detector.py

Browse files
Files changed (1) hide show
  1. mega_detector.py +16 -11
mega_detector.py CHANGED
@@ -10,22 +10,22 @@ from abc import ABC, abstractmethod
10
  class BaseModel(ABC):
11
  @abstractmethod
12
  def pre_process(self, filename: str):
13
- """
14
- Pre-process the input file and return it as a tensor.
15
- """
16
  pass
17
 
18
  @abstractmethod
19
  def predict(self, input_data):
20
- """
21
- Run inference on the pre-processed input and return predictions.
22
- """
23
  pass
24
 
25
  class MegaDetectorModel(BaseModel):
26
  """
27
- MegaDetectorModel loads the MegaDetector V5 checkpoint from Hugging Face,
28
  preprocesses input images, runs inference, and returns detections (label/confidence).
 
 
 
 
29
  """
30
 
31
  def __init__(self, device='cpu', conf_thres=0.25, iou_thres=0.45, labels_path=None):
@@ -41,8 +41,11 @@ class MegaDetectorModel(BaseModel):
41
  @classmethod
42
  def from_pretrained(cls, repo_id: str, device: str = 'cpu', **kwargs):
43
  """
44
- Loads the model checkpoint from the given Hugging Face repository and returns
45
- an instance of MegaDetectorModel ready for inference.
 
 
 
46
 
47
  Args:
48
  repo_id (str): The Hugging Face repository ID (e.g. "nkarthikeyan/MegaDetectorV5").
@@ -52,8 +55,10 @@ class MegaDetectorModel(BaseModel):
52
  MegaDetectorModel: An instance with the model loaded.
53
  """
54
  instance = cls(device=device, **kwargs)
55
- # Download the model checkpoint (assumes the file is named 'model.pt')
56
- model_path = hf_hub_download(repo_id=repo_id, filename="model.pt")
 
 
57
  checkpoint = torch.load(model_path, map_location=instance.device)
58
  instance.model = checkpoint['model'].float().fuse().eval()
59
  if instance.device.type != 'cpu':
 
10
  class BaseModel(ABC):
11
  @abstractmethod
12
  def pre_process(self, filename: str):
13
+ """Pre-process the input file and return it as a tensor."""
 
 
14
  pass
15
 
16
  @abstractmethod
17
  def predict(self, input_data):
18
+ """Run inference on the pre-processed input and return predictions."""
 
 
19
  pass
20
 
21
  class MegaDetectorModel(BaseModel):
22
  """
23
+ MegaDetectorModel loads the MegaDetector checkpoint from a Hugging Face repository,
24
  preprocesses input images, runs inference, and returns detections (label/confidence).
25
+
26
+ The repository ID is the only input required. The model filename, class name, and weight file
27
+ are all expected to match the repository's base name. For example, if the repository ID is
28
+ "nkarthikeyan/MegaDetectorV5", then the model weight file should be "MegaDetectorV5.pt".
29
  """
30
 
31
  def __init__(self, device='cpu', conf_thres=0.25, iou_thres=0.45, labels_path=None):
 
41
  @classmethod
42
  def from_pretrained(cls, repo_id: str, device: str = 'cpu', **kwargs):
43
  """
44
+ Loads the model checkpoint from the given Hugging Face repository and returns an instance
45
+ of MegaDetectorModel ready for inference.
46
+
47
+ The repository's base name is used to derive the model weight filename. For example, if
48
+ repo_id is "nkarthikeyan/MegaDetectorV5", then the weight file is expected to be "MegaDetectorV5.pt".
49
 
50
  Args:
51
  repo_id (str): The Hugging Face repository ID (e.g. "nkarthikeyan/MegaDetectorV5").
 
55
  MegaDetectorModel: An instance with the model loaded.
56
  """
57
  instance = cls(device=device, **kwargs)
58
+ # Use the repository base name as the weight filename.
59
+ model_name = repo_id.split("/")[-1]
60
+ weight_filename = f"{model_name}.pt"
61
+ model_path = hf_hub_download(repo_id=repo_id, filename=weight_filename)
62
  checkpoint = torch.load(model_path, map_location=instance.device)
63
  instance.model = checkpoint['model'].float().fuse().eval()
64
  if instance.device.type != 'cpu':