andrewromanenco committed on
Commit
35290ca
·
0 Parent(s):

Add Hugging Face-ready wrapper for HitDetector model

Browse files

This adds a standalone script for running inference with the HitDetector model,
originally trained using code from:

https://github.com/andrewromanenco/hit-detector

The wrapper is Hugging Face-compatible and includes a pipeline interface
for integration with the Hugging Face Hub and Spaces.

Files changed (9) hide show
  1. LICENSE.txt +21 -0
  2. README.md +80 -0
  3. app.py +14 -0
  4. example.png +0 -0
  5. input.png +0 -0
  6. model.py +29 -0
  7. pipeline.py +79 -0
  8. requirements.txt +6 -0
  9. test_pipeline.py +9 -0
LICENSE.txt ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2025 Andrew Romanenco <andrew@romanenco.com>
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: mit
3
+ library_name: hitdetector
4
+ pipeline_tag: image-classification
5
+ tags:
6
+ - pytorch
7
+ - sliding-window
8
+ - computer-vision
9
+ - hole-detection
10
+ - custom-pipeline
11
+ ---
12
+
13
+ # 🎯 Hit Detector Model
14
+
15
+ This PyTorch-based CNN detects holes on boards or paper using a sliding window approach. It was trained on image patches of size **24×24**. The model scans larger images with this patch size to detect regions of interest.
16
+
17
+ > Holes or defects must approximately fit within a 20×20 region to be accurately detected.
18
+
19
+ ## 📥 Model Inputs & Outputs
20
+
21
+ - Input: RGB or grayscale image (PIL.Image)
22
+ - Output: Annotated PIL.Image with red (or specified color) squares highlighting detected holes
23
+
24
+ ![Example result](example.png)
25
+
26
+ ## 🚀 Quick Start
27
+
28
+ ### 🧠 Inference in Python
29
+
30
+ ```python
31
+ from PIL import Image
32
+ from pipeline import HitDetectorPipeline
33
+
34
+ pipe = HitDetectorPipeline("model.pt")
35
+
36
+ img = Image.open("input.png")
37
+ result = pipe(img)
38
+ result.save("output.png")
39
+ print("✅ Output saved to output.png")
40
+ ```
41
+
42
+ ### 📦 Installation
43
+
44
+ ```bash
45
+ pip install -r requirements.txt
46
+ ```
47
+ ## 🧪 Testing in Docker
48
+
49
+ To test the model or pipeline scripts inside a clean container:
50
+
51
+ ```bash
52
+ cd <project folder>
53
+ docker run -it --rm -p 7860:7860 -v $PWD:/appx:rw romanenco/python-tool-chain /bin/bash
54
+ cd /appx
55
+ pip install -r requirements.txt
56
+ python test_pipeline.py
57
+ ```
58
+
59
+ You should see output.png generated as a result.
60
+
61
+ ### 🌐 Run Gradio UI
62
+
63
+ ```bash
64
+ pip install gradio
65
+ python app.py
66
+ ```
67
+
68
+ Open [http://127.0.0.1:7860](http://127.0.0.1:7860) to test the interactive web UI.
69
+
70
+ ## 🛠 Retrain or Fine-Tune
71
+
72
+ To retrain the model on your own dataset, use the full pipeline and tools from the [main training repo](https://github.com/andrewromanenco/hit-detector), which includes:
73
+
74
+ - 📁 Tools to extract training patches from full images
75
+ - 🧠 Training script
76
+ - 📈 Inference script
77
+
78
+ ## 📄 License
79
+
80
+ MIT
app.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import gradio as gr
from PIL import Image
from pipeline import HitDetectorPipeline

# Load the model once at import time so every request reuses the same weights.
pipe = HitDetectorPipeline("model.pt")


def detect(image: Image.Image) -> Image.Image:
    """Gradio callback: return *image* with detected holes highlighted."""
    return pipe(image)


demo = gr.Interface(
    fn=detect,
    inputs=gr.Image(type="pil"),
    outputs=gr.Image(type="pil"),
    title="Hit Detector",
)
# 0.0.0.0:7860 is the standard bind for Hugging Face Spaces containers.
demo.launch(server_name="0.0.0.0", server_port=7860)
example.png ADDED
input.png ADDED
model.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+
4
class SimpleCNN(nn.Module):
    """Small CNN binary classifier for single-channel square patches.

    Two conv+ReLU+maxpool stages (1 -> 16 -> 32 channels, spatial size
    halved twice) feed a two-layer MLP head that emits one raw score per
    input — no final activation is applied.
    """

    def __init__(self, sample_input):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 16, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(16, 32, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )

        # Probe the extractor with the sample input to size the first
        # linear layer — this lets the class adapt to any patch size.
        with torch.no_grad():
            probe = self.features(sample_input.unsqueeze(0))
            self.flattened_size = probe.view(1, -1).size(1)

        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(self.flattened_size, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
        )

    def forward(self, x):
        """Map a (N, 1, H, W) batch to (N, 1) raw scores."""
        return self.classifier(self.features(x))
pipeline.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from PIL import Image
3
+ from torchvision import transforms
4
+ from model import SimpleCNN
5
+
6
# Side length (px) of the square windows the model was trained on.
PATCH_SIZE = 24


def hex_to_rgb(hex_color):
    """Convert an '#RRGGBB' string (leading '#' optional) to an (r, g, b) tuple."""
    value = int(hex_color.strip("#"), 16)
    return ((value >> 16) & 0xFF, (value >> 8) & 0xFF, value & 0xFF)
11
+
12
def load_model(model_path):
    """Build a SimpleCNN sized for PATCH_SIZE patches and restore its weights.

    Weights are loaded onto the CPU and the model is switched to eval mode.
    Returns a (model, PATCH_SIZE) pair so callers learn the window size.
    """
    probe = torch.randn(1, PATCH_SIZE, PATCH_SIZE)
    net = SimpleCNN(probe)
    state = torch.load(model_path, map_location=torch.device("cpu"))
    net.load_state_dict(state)
    net.eval()
    return net, PATCH_SIZE
18
+
19
def run_inference(
    model: torch.nn.Module,
    image: Image.Image,
    original: Image.Image,
    color: tuple,
    opacity: int,
    target_label: int,
    patch_size: int,
    stride: int = 4
):
    """Scan *image* with a sliding window and highlight matching patches.

    Args:
        model: classifier producing one raw score for a (1, 1, P, P) input.
        image: grayscale image patches are cropped from.
        original: full-color image the highlight overlay is composited onto.
        color: (r, g, b) fill for highlighted windows.
        opacity: alpha value (0-255) for the highlight fill.
        target_label: windows whose predicted label equals this are marked.
        patch_size: side length P of the square sliding window.
        stride: step in pixels between consecutive windows.

    Returns:
        An RGB PIL.Image: *original* with translucent squares over detections.
    """
    transform = transforms.ToTensor()

    width, height = image.size
    total_patches = ((width - patch_size) // stride + 1) * ((height - patch_size) // stride + 1)
    # Fix: an image smaller than one patch yields zero windows — return the
    # image unannotated instead of dividing by zero in the progress math.
    if total_patches <= 0:
        return original.convert("RGB")

    overlay = Image.new("RGBA", original.size, (0, 0, 0, 0))

    done = 0
    last_percent_reported = -1
    for y in range(0, height - patch_size + 1, stride):
        for x in range(0, width - patch_size + 1, stride):
            patch = image.crop((x, y, x + patch_size, y + patch_size))
            tensor = transform(patch).unsqueeze(0)  # shape (1, 1, P, P)
            with torch.no_grad():
                pred = model(tensor)
            # NOTE(review): 0.9 is compared against the raw model output —
            # SimpleCNN has no final sigmoid — presumably an empirically
            # tuned cutoff; confirm against the training repo.
            predicted_label = int(pred.item() > 0.9)

            if predicted_label == target_label:
                # Stamp a translucent square onto the overlay at this window.
                patch_overlay = Image.new("RGBA", (patch_size, patch_size), color + (opacity,))
                overlay.paste(patch_overlay, (x, y), patch_overlay)

            done += 1
            percent = int(done / total_patches * 100)
            if percent != last_percent_reported:
                print(f"\rProgress: {percent:3d}% ", end="", flush=True)
                last_percent_reported = percent

    print("\nDone.")
    blended = Image.alpha_composite(original.convert("RGBA"), overlay)
    return blended.convert("RGB")
59
+
60
+
61
class HitDetectorPipeline:
    """Callable wrapper around the hit-detector model: PIL image in,
    annotated PIL image out.
    """

    def __init__(self, model_path="model.pt", color="#FF0000", opacity=128, target_label=1):
        # Defaults must stay exactly as-is — external callers rely on them.
        self.model, self.patch_size = load_model(model_path)
        self.color = hex_to_rgb(color)
        self.opacity = opacity
        self.target_label = target_label

    def __call__(self, image: Image.Image) -> Image.Image:
        """Detect hits on *image* and return an annotated RGB copy."""
        # The model sees a grayscale copy; highlights are drawn on RGB.
        return run_inference(
            model=self.model,
            image=image.convert("L"),
            original=image.convert("RGB"),
            color=self.color,
            opacity=self.opacity,
            target_label=self.target_label,
            patch_size=self.patch_size,
        )
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ torch>=2.7.0,<3.0.0
2
+ torchvision>=0.22.0,<0.23.0
3
+ Pillow
4
+ tqdm>=4.67.1,<5.0.0
5
+ scikit-learn>=1.6.1,<2.0.0
6
+ opencv-python>=4.11.0.86,<5.0.0.0
test_pipeline.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
"""Smoke test: run the pipeline on input.png and write output.png."""
from PIL import Image
from pipeline import HitDetectorPipeline

pipe = HitDetectorPipeline("model.pt")

source = Image.open("input.png")
annotated = pipe(source)
annotated.save("output.png")
print("✅ Output saved to output.png")