Keras
vit
roadwork-miner / run.py
MichaelEdward's picture
Upload folder using huggingface_hub
f99a095 verified
#!/usr/bin/env python3
"""
Roadwork-miner: same style as https://huggingface.co/natix-network-org/roadwork.

Binary roadwork classifier that combines all weights in this repo into one
callable model. For each image:
  * a ViT image-classification pipeline (natix-network-org/roadwork) yields a
    "Roadwork" probability,
  * a YOLO model loaded from best.pt yields a class probability for roadwork,
  * the normalised 224x224 image plus those two probabilities are fed to a
    Keras model (final_output.keras), whose output is thresholded to 0 or 1.

Usage (load once, then predict):
    from run import load_model
    model = load_model()            # load once (local folder or repo_id)
    label = model(image)            # 0 or 1 (PIL Image or numpy array)
    label = model("path/to.jpg")    # or pass a path to an image file
    # From the Hugging Face Hub:
    model = load_model(repo_id="MichaelEdward/roadwork-miner")

CLI:
    python run.py <path_to_image>
"""
import sys
from pathlib import Path
# Folder containing this script; local weight files are resolved against it.
DIR = Path(__file__).resolve().parent
# Side length (pixels) that every input image is resized to.
IMG_SIZE = 224
# Local weight files, used by load_pipeline() when no explicit paths are given:
# YOLO weights and the Keras fusion model.
BEST_PT = DIR.joinpath("best.pt")
FINAL_MODEL = DIR.joinpath("final_output.keras")
# Hugging Face repo of the ViT classifier whose score feeds the fusion model.
VIT_REPO = "natix-network-org/roadwork"
def _image_to_keras_input(image):
    """Turn one image into the batched float32 tensor the Keras model expects.

    Args:
        image: a PIL Image, or anything ``np.asarray`` accepts (e.g. a
            (224, 224, 3) array). Non-PIL input is cast to uint8 first.

    Returns:
        np.ndarray of shape (1, IMG_SIZE, IMG_SIZE, 3), dtype float32, scaled
        to [0, 1] and then standardised per channel with the widely used
        ImageNet mean/std constants.
    """
    import numpy as np
    from PIL import Image

    pil = image if isinstance(image, Image.Image) else Image.fromarray(np.asarray(image).astype(np.uint8))
    if pil.size != (IMG_SIZE, IMG_SIZE):
        pil = pil.resize((IMG_SIZE, IMG_SIZE))
    if pil.mode != "RGB":
        pil = pil.convert("RGB")

    channel_mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
    channel_std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
    scaled = np.array(pil, dtype=np.float32) / 255.0
    # Add the leading batch axis: (H, W, 3) -> (1, H, W, 3).
    return ((scaled - channel_mean) / channel_std)[np.newaxis, ...]
def _get_vit_prob(pipe, image_pil):
out = pipe(image_pil)
if not isinstance(out, list):
out = [out]
for item in out:
if item.get("label") == "Roadwork":
return item["score"]
return 0.0
def _get_yolo_prob(yolo_model, image_pil, roadwork_idx=1):
import numpy as np
r = yolo_model.predict(source=image_pil, verbose=False, device="cpu")
if not r or not hasattr(r[0], "probs") or r[0].probs is None:
return 0.0
p = r[0].probs.data
if hasattr(p, "cpu"):
p = p.cpu().numpy()
else:
p = np.asarray(p)
if p.ndim > 1:
p = p.ravel()
idx = min(roadwork_idx, len(p) - 1)
return float(p[idx])
def load_pipeline(best_pt=None, final_model_path=None):
    """Load the ViT pipeline, the YOLO model and the Keras fusion model.

    Args:
        best_pt: path to the YOLO weights; defaults to ``BEST_PT``.
        final_model_path: path to the Keras model; defaults to ``FINAL_MODEL``.

    Returns:
        dict with keys "vit" (transformers image-classification pipeline for
        ``VIT_REPO``), "yolo" (Ultralytics ``YOLO`` model) and "model" (the
        loaded Keras model) — the structure ``predict`` expects.

    Raises:
        FileNotFoundError: if either weight file is missing. The check runs
            before the heavy framework imports so a bad path fails fast.
    """
    best_pt = Path(best_pt) if best_pt else BEST_PT
    final_model_path = Path(final_model_path) if final_model_path else FINAL_MODEL
    if not final_model_path.exists():
        raise FileNotFoundError(f"final_output.keras not found at {final_model_path}")
    if not best_pt.exists():
        raise FileNotFoundError(f"best.pt not found at {best_pt}")

    # Imported lazily: TensorFlow / transformers / ultralytics are slow to import.
    from tensorflow import keras
    from transformers import AutoImageProcessor, AutoModelForImageClassification, pipeline
    from ultralytics import YOLO

    # str() for compatibility with Keras versions that only accept string paths.
    model = keras.models.load_model(str(final_model_path))
    vit = pipeline(
        "image-classification",
        model=AutoModelForImageClassification.from_pretrained(VIT_REPO),
        # `image_processor` is the current argument; `feature_extractor` is a legacy alias.
        image_processor=AutoImageProcessor.from_pretrained(VIT_REPO, use_fast=True),
        device=-1,  # -1 selects CPU in transformers pipelines
    )
    yolo = YOLO(str(best_pt))
    return {"vit": vit, "yolo": yolo, "model": model}
class RoadworkModel:
    """Callable wrapper around a loaded pipeline: ``model(image) -> 0 or 1``.

    Mirrors the calling style of natix-network-org/roadwork: load once, then
    call the instance with images.
    """

    def __init__(self, pipeline, threshold=0.5):
        """
        Args:
            pipeline: dict returned by ``load_pipeline`` (keys "vit", "yolo", "model").
            threshold: roadwork probability at or above which 1 is predicted.
        """
        self._pipeline = pipeline
        self.threshold = threshold

    def __repr__(self):
        return f"{type(self).__name__}(threshold={self.threshold!r})"

    def __call__(self, image):
        """Predict 0 or 1 for one image.

        Args:
            image: PIL Image, numpy array, or a path (str / Path) to an image
                file; paths are opened with ``load_image``.

        Returns:
            int label, 0 (no roadwork) or 1 (roadwork).
        """
        if isinstance(image, (str, Path)):
            image = load_image(image)
        return predict(image, self._pipeline, threshold=self.threshold)
def load_model(repo_id=None, threshold=0.5):
    """Load all weights as one callable model; then use ``model(image)``.

    Same style as reference: https://huggingface.co/natix-network-org/roadwork

    Args:
        repo_id: None to use the local folder (this repo on disk); otherwise a
            Hugging Face repo id such as "MichaelEdward/roadwork-miner" whose
            best.pt and final_output.keras are downloaded and loaded.
        threshold: roadwork probability at or above which 1 is predicted.

    Returns:
        RoadworkModel: callable ``model(image) -> 0 or 1``.

    Raises:
        FileNotFoundError: if a required weight file is missing (from load_pipeline).
    """
    if repo_id:
        from huggingface_hub import snapshot_download

        weight_files = ["best.pt", "final_output.keras"]
        # Only these two files are read below, so skip downloading the rest of the repo.
        cache = Path(snapshot_download(repo_id=repo_id, allow_patterns=weight_files))
        pipeline = load_pipeline(
            best_pt=cache / "best.pt",
            final_model_path=cache / "final_output.keras",
        )
    else:
        pipeline = load_pipeline()
    return RoadworkModel(pipeline, threshold=threshold)
def predict(image, pipeline, threshold=0.5):
    """Predict 0 (no roadwork) or 1 (roadwork) for one image.

    Args:
        image: PIL Image or array-like. A 3-D array is an (H, W, C) image, a
            2-D array is a single-channel image, and a 4-D array is treated as
            a batch of which only the first image is classified. Arrays are
            cast to uint8, so values are expected in 0..255.
        pipeline: dict from ``load_pipeline()`` with keys "vit", "yolo", "model".
        threshold: roadwork probability at or above which 1 is returned.

    Returns:
        int: 1 if the Keras model's roadwork probability >= threshold, else 0.
    """
    import numpy as np
    from PIL import Image

    if not isinstance(image, Image.Image):
        arr = np.asarray(image)
        if arr.ndim == 4:
            # Batched input (N, H, W, C): classify the first image only.
            arr = arr[0]
        # 2-D arrays are grayscale images and are used whole; they become RGB below.
        image = Image.fromarray(arr.astype(np.uint8))
    if image.size != (IMG_SIZE, IMG_SIZE):
        image = image.resize((IMG_SIZE, IMG_SIZE))
    if image.mode != "RGB":
        image = image.convert("RGB")

    p_vit = _get_vit_prob(pipeline["vit"], image)
    p_yolo = _get_yolo_prob(pipeline["yolo"], image)
    X_img = _image_to_keras_input(image)
    # The fusion model takes three inputs: the image and each model's probability as (1, 1).
    p_vit_arr = np.array([[float(p_vit)]], dtype=np.float32)
    p_yolo_arr = np.array([[float(p_yolo)]], dtype=np.float32)
    prob = pipeline["model"].predict([X_img, p_vit_arr, p_yolo_arr], verbose=0)
    roadwork_prob = float(prob[0, 0])
    return 1 if roadwork_prob >= threshold else 0
def make_demo_image(size=IMG_SIZE):
    """Build a synthetic ``size`` x ``size`` RGB gradient image in memory (no file).

    Red increases left-to-right, green top-to-bottom, and blue along the
    diagonal. Useful as a smoke-test input for ``model(image)``.

    Returns:
        PIL.Image.Image in mode "RGB".
    """
    import numpy as np
    from PIL import Image

    ramp = np.linspace(0, 1, size)
    y = ramp.reshape(size, 1)  # varies by row
    x = ramp.reshape(1, size)  # varies by column
    shape = (size, size)
    r = np.broadcast_to((0.4 + 0.2 * x).clip(0, 1), shape)
    g = np.broadcast_to((0.5 + 0.2 * y).clip(0, 1), shape)
    b = (0.45 + 0.1 * (x + y)).clip(0, 1)  # already (size, size)
    arr = (np.stack([r, g, b], axis=-1) * 255).astype(np.uint8)
    # An (H, W, 3) uint8 array is inferred as RGB; the explicit `mode=`
    # argument of Image.fromarray is deprecated in recent Pillow releases.
    return Image.fromarray(arr)
def load_image(path):
    """Open an image file and return it as an RGB PIL Image.

    Lookup order:
      1. ``path`` itself (``~`` expanded), relative to the current working
         directory when not absolute;
      2. a file with the same base name in the script directory ``DIR``
         (e.g. roadwork-miner/test_image.jpg).

    Raises:
        FileNotFoundError: if neither candidate is an existing regular file.
    """
    from PIL import Image

    raw = Path(str(path).strip()).expanduser()
    for candidate in (raw.resolve(), DIR / raw.name):
        # is_file() (not exists()) so a directory is skipped rather than crashing Image.open.
        if candidate.is_file():
            # convert() loads the pixels into a new image, so closing the file
            # handle on leaving the `with` block is safe.
            with Image.open(candidate) as img:
                return img.convert("RGB")
    raise FileNotFoundError(f"Image not found: {raw} (tried cwd and {DIR})")
def main(argv=None):
    """CLI entry point: classify one image and print its label.

    Args:
        argv: arguments excluding the program name; defaults to ``sys.argv[1:]``.

    Returns:
        The predicted label (0 or 1). Exits with status 1 on a usage error or
        when the image or a weight file cannot be found.
    """
    args = sys.argv[1:] if argv is None else list(argv)
    if not args:
        print("Usage: python run.py <path_to_image>", file=sys.stderr)
        sys.exit(1)
    try:
        input_image = load_image(args[0].strip())
        # Missing weights are reported the same way as a missing image.
        pipeline = load_pipeline()
    except FileNotFoundError as e:
        print(e, file=sys.stderr)
        sys.exit(1)
    label = predict(input_image, pipeline, threshold=0.5)
    print(f"Prediction: {label} (0=no roadwork, 1=roadwork)")
    return label


if __name__ == "__main__":
    main()