# Vehicle-Damage-Detection / models/damage_classifier.py
# Last updated by neerajkalyank in commit 8879365 ("Update models/damage_classifier.py").
from transformers import pipeline
from PIL import Image
# Build the image-classification pipeline a single time, when this module is
# imported, so every call to detect_damage() reuses the same loaded model
# instead of constructing a new one per request (important for Streamlit
# performance).
_damage_classifier = pipeline(
    task="image-classification",
    model="beingamit99/car_damage_detection",
)
def detect_damage(pil_image: Image.Image) -> dict:
    """Classify the type of damage visible in a vehicle image.

    Images in any mode other than RGB are converted to RGB first. The image
    is then passed to the module-level ``_damage_classifier`` pipeline, and
    the first prediction it returns is reported.

    Args:
        pil_image: The vehicle photo to classify.

    Returns:
        A dict with two keys:
          - ``"damage_type"`` (str): label of the first prediction, or
            ``"Unknown"`` when the pipeline returns no predictions.
          - ``"confidence"`` (float): that prediction's score rounded to
            3 decimal places, or ``0.0`` when there is no prediction.
    """
    rgb_image = pil_image if pil_image.mode == "RGB" else pil_image.convert("RGB")

    results = _damage_classifier(rgb_image)
    if not results:
        # The model produced nothing: fall back to a neutral placeholder.
        return {"damage_type": "Unknown", "confidence": 0.0}

    # NOTE(review): results[0] is treated as the best guess, which assumes
    # the pipeline orders predictions by descending score — confirm.
    best = results[0]
    label, score = best["label"], best["score"]
    return {"damage_type": label, "confidence": round(float(score), 3)}