---
base_model: microsoft/resnet-18
license: mit
tags:
- image-classification
- pytorch
- computer-vision
- fall-detection
---
# Fall Detection Model (ResNet-18 Fine-tuned)
This model is a ResNet-18 fine-tuned for binary image classification: it predicts whether an image shows a fall.
## Model Details
- **Base Model:** `microsoft/resnet-18`
- **Dataset:** `hiennguyen9874/fall-detection-dataset`
- **Task:** Binary image classification (fall/no_fall)
- **Classes:**
  - `0`: `no_fall`
  - `1`: `fall`
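The class list above corresponds to the `id2label` / `label2id` dictionaries stored on a `transformers` model config. The sketch below writes them out in plain Python and adds a hypothetical `resolve_label` helper (not part of this repository or of `transformers`) that falls back to the documented names if the loaded config only carries the generic placeholders `LABEL_0` / `LABEL_1`, which `transformers` uses when no `id2label` was set during fine-tuning. Whether this repository's config needs the fallback is an assumption; you can check by printing `model.config.id2label`.

```python
# Documented class mapping for this model (from the Model Details list).
ID2LABEL = {0: "no_fall", 1: "fall"}
# Inverse mapping, e.g. for encoding string labels as integer targets.
LABEL2ID = {label: idx for idx, label in ID2LABEL.items()}


def resolve_label(config_id2label: dict, class_id: int) -> str:
    """Return a readable label for class_id (hypothetical helper).

    Falls back to the documented mapping when the config only exposes the
    generic placeholder name "LABEL_<id>" (or lacks the id entirely).
    """
    label = config_id2label.get(class_id, f"LABEL_{class_id}")
    if label == f"LABEL_{class_id}":
        return ID2LABEL[class_id]
    return label


# Config with generic placeholder names:
print(resolve_label({0: "LABEL_0", 1: "LABEL_1"}, 1))  # fall
# Config that already carries the documented names:
print(resolve_label({0: "no_fall", 1: "fall"}, 0))  # no_fall
```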
## How to Use
### 1. Load the Model and Image Processor
```python
from transformers import AutoModelForImageClassification, AutoImageProcessor
from PIL import Image
import torch
# Use a GPU if one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
repo_id = "popkek00/fall_detection_model"  # Hugging Face Hub repository ID of this model
model = AutoModelForImageClassification.from_pretrained(repo_id).to(device)
image_processor = AutoImageProcessor.from_pretrained(repo_id)
model.eval() # Set model to evaluation mode
```
### 2. Prepare an Image for Inference
```python
# Replace this dummy image with your own PIL image, for example:
#   example_image = Image.open("path/to/image.jpg").convert("RGB")
# Image.open also accepts a file-like object such as io.BytesIO, so an image
# downloaded from a URL can be wrapped in BytesIO first.
# For demonstration, create a solid red 224x224 RGB image.
example_image = Image.new("RGB", (224, 224), color="red")
# Process the image
inputs = image_processor(images=example_image, return_tensors="pt")
pixel_values = inputs["pixel_values"].to(device)
```
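`image_processor` converts the PIL image into a normalized float tensor. Assuming this repository uses the same preprocessing settings as `microsoft/resnet-18` (rescale 8-bit values by 1/255, then standardize each channel with the ImageNet mean and standard deviation; confirm with `image_processor.image_mean` and `image_processor.image_std`), the per-channel arithmetic for a single pixel looks roughly like this. The resize and crop steps are omitted, and `normalize_pixel` is a hypothetical illustration, not a library function.

```python
# Assumed preprocessing constants (standard ImageNet statistics); check them
# against image_processor.image_mean / image_processor.image_std.
IMAGE_MEAN = (0.485, 0.456, 0.406)
IMAGE_STD = (0.229, 0.224, 0.225)
RESCALE_FACTOR = 1 / 255


def normalize_pixel(rgb):
    """Rescale an 8-bit RGB pixel to [0, 1], then standardize each channel."""
    return tuple(
        (value * RESCALE_FACTOR - mean) / std
        for value, mean, std in zip(rgb, IMAGE_MEAN, IMAGE_STD)
    )


# Every pixel of the dummy red image is (255, 0, 0).
red = normalize_pixel((255, 0, 0))
print([round(c, 4) for c in red])  # [2.2489, -2.0357, -1.8044]
```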
### 3. Get Predictions
```python
with torch.no_grad():  # Disable gradient tracking for inference
    outputs = model(pixel_values=pixel_values)
logits = outputs.logits  # Shape: (batch_size, num_labels)
probabilities = torch.softmax(logits, dim=-1)
predicted_class_id = probabilities[0].argmax().item()  # First (and only) image in the batch
# Get the human-readable label from the model's config
predicted_label = model.config.id2label[predicted_class_id]
confidence = probabilities[0, predicted_class_id].item() * 100
print(f"Predicted label: {predicted_label} (Confidence: {confidence:.2f}%)")
```
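For a single image, the softmax, argmax, and confidence steps above reduce to simple arithmetic on the two logits. The plain-Python sketch below mirrors that computation using made-up logit values (not real model outputs) and the documented class mapping; `softmax` and `predict_from_logits` are illustrative helpers, which may be handy for sanity-checking predictions or porting the post-processing elsewhere.

```python
import math

ID2LABEL = {0: "no_fall", 1: "fall"}  # Documented class mapping


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def predict_from_logits(logits):
    """Return (label, confidence in percent) for one image's logits."""
    probabilities = softmax(logits)
    class_id = max(range(len(probabilities)), key=probabilities.__getitem__)
    return ID2LABEL[class_id], probabilities[class_id] * 100


# Hypothetical logits for one image, ordered [no_fall, fall]
label, confidence = predict_from_logits([1.0, 3.0])
print(f"Predicted label: {label} (Confidence: {confidence:.2f}%)")
# Predicted label: fall (Confidence: 88.08%)
```

Because softmax is monotonic, taking the argmax of the probabilities gives the same class as taking the argmax of the raw logits; the softmax is only needed for the confidence value.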
---