csrnet-keras-crowd-counting / INTEGRATION_GUIDE.md

CSRNet Model Integration Guide

Quick Integration Steps

1. Copy Required Files

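Copy these two files into your project, keeping the same relative paths (the loader below uses them as defaults):

```
models/Model.json            # model architecture (Keras JSON)
weights/model_A.weights.h5   # trained weights
```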
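Before wiring the model into an application, it can help to fail fast when either file is missing. Below is a minimal standard-library sketch. `check_model_files` and `REQUIRED_FILES` are hypothetical names, and the relative paths are the two listed above.

```python
from pathlib import Path

# Hypothetical constant: mirrors the two files listed above.
REQUIRED_FILES = ("models/Model.json", "weights/model_A.weights.h5")


def check_model_files(project_root="."):
    """Return the required model files that are missing under project_root."""
    root = Path(project_root)
    return [rel for rel in REQUIRED_FILES if not (root / rel).is_file()]
```

Call `check_model_files()` at startup and abort with a clear message if the returned list is not empty.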

2. Install Dependencies

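```bash
pip install tensorflow keras numpy pillow opencv-python h5py flask
```

`flask` is only needed for the API example in step 4, and `opencv-python` is not required by the snippets in this guide.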
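Several of these packages are imported under a different name than the one passed to pip (`pillow` is imported as `PIL`, `opencv-python` as `cv2`), which trips up quick "is it installed?" checks. Here is a small standard-library sketch that reports missing modules without importing them. The mapping is written by hand for the packages this guide uses; it is not an official registry.

```python
import importlib.util

# pip package name -> top-level module name used in `import` statements.
PIP_TO_MODULE = {
    "tensorflow": "tensorflow",
    "keras": "keras",
    "numpy": "numpy",
    "pillow": "PIL",
    "opencv-python": "cv2",
    "h5py": "h5py",
    "flask": "flask",  # only needed for the Flask API example
}


def missing_packages(mapping=PIP_TO_MODULE):
    """Return the pip names whose import module cannot be found."""
    # find_spec locates a top-level module without executing it.
    return [pip_name for pip_name, module in mapping.items()
            if importlib.util.find_spec(module) is None]


print("Missing:", missing_packages() or "none")
```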

3. Integration Code

```python
import numpy as np
from PIL import Image
from keras.models import model_from_json


def load_csrnet_model(model_path='models/Model.json',
                      weights_path='weights/model_A.weights.h5'):
    """Load the CSRNet architecture from JSON and attach the trained weights."""
    with open(model_path, 'r') as f:
        model = model_from_json(f.read())
    model.load_weights(weights_path)
    return model


def preprocess_image(image):
    """Load an image (path or file-like object) as a normalized (1, H, W, 3) batch."""
    im = Image.open(image).convert('RGB')        # drop alpha / palette modes
    im = np.array(im, dtype=np.float32) / 255.0  # scale to [0, 1]

    # Normalize each channel with the ImageNet mean/std
    im[:, :, 0] = (im[:, :, 0] - 0.485) / 0.229
    im[:, :, 1] = (im[:, :, 1] - 0.456) / 0.224
    im[:, :, 2] = (im[:, :, 2] - 0.406) / 0.225

    return np.expand_dims(im, axis=0)            # add the batch dimension


def predict_crowd_count(model, image):
    """Return (estimated count, density map) for a single image."""
    batch = preprocess_image(image)
    density_map = model.predict(batch, verbose=0)
    # The count is the sum of the density map; round instead of truncating
    count = int(round(float(np.sum(density_map))))
    return count, density_map


# Usage example
model = load_csrnet_model()
count, density_map = predict_crowd_count(model, 'test_image.jpg')
print(f"Predicted crowd count: {count}")
```
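The per-channel normalization in `preprocess_image` can also be written with NumPy broadcasting. Summing a density map gives a fractional value, so rounding is safer than truncating with `int()`. Below is a minimal sketch, assuming the same ImageNet mean/std constants as above. `normalize_imagenet` and `density_to_count` are hypothetical helper names, and the random arrays only stand in for a real image and a real model output. It produces `float32`, which is also Keras's default float type.

```python
import numpy as np

# ImageNet statistics, same values as in preprocess_image above.
MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def normalize_imagenet(rgb_uint8):
    """Scale an (H, W, 3) uint8 RGB array to [0, 1], normalize, add a batch axis."""
    im = rgb_uint8.astype(np.float32) / 255.0
    # (H, W, 3) - (3,) broadcasts over the last (channel) axis.
    im = (im - MEAN) / STD
    return im[np.newaxis, ...]  # shape (1, H, W, 3)


def density_to_count(density_map):
    """Integrate a density map into a count, rounding instead of truncating."""
    return int(round(float(np.sum(density_map))))


rng = np.random.default_rng(0)
fake_image = rng.integers(0, 256, size=(4, 6, 3), dtype=np.uint8)
batch = normalize_imagenet(fake_image)
print(batch.shape)                                     # (1, 4, 6, 3)
print(density_to_count(np.full((1, 2, 2, 1), 0.7)))    # sum 2.8 -> 3
```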

4. API Integration (Flask Example)

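```python
# app.py: assumes load_csrnet_model and predict_crowd_count from step 3 are
# defined (or imported) in this file; guard or remove the step 3 usage example
# so it does not run on import.
from flask import Flask, request, jsonify
from PIL import UnidentifiedImageError

app = Flask(__name__)
model = load_csrnet_model()  # load once at startup, not per request


@app.route('/predict', methods=['POST'])
def predict():
    file = request.files.get('image')
    if file is None or file.filename == '':
        return jsonify({'error': "missing 'image' file field"}), 400

    # PIL reads the upload stream directly, so no temp file is needed
    # (a shared temp.jpg would also race between concurrent requests)
    try:
        count, _ = predict_crowd_count(model, file.stream)
    except UnidentifiedImageError:
        return jsonify({'error': 'file is not a readable image'}), 400
    return jsonify({'crowd_count': count})


if __name__ == '__main__':
    app.run(port=5000)
```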
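Some clients send the image as a base64 string inside a JSON body instead of a multipart upload. In that case, the server has to decode the string before PIL can open it. Below is a minimal standard-library sketch. The helper name `decode_base64_image` and the data-URL handling are assumptions about what such a client would send, not part of this repo.

```python
import base64
import binascii


def decode_base64_image(payload):
    """Return raw image bytes from a base64 string or a base64 data URL.

    Accepts plain base64 ("iVBORw0...") or a data URL
    ("data:image/png;base64,iVBORw0..."). Raises ValueError on bad input.
    """
    if payload.startswith("data:"):
        header, sep, payload = payload.partition(",")
        if not sep or not header.endswith(";base64"):
            raise ValueError("expected a base64-encoded data URL")
    try:
        # validate=True rejects characters outside the base64 alphabet
        return base64.b64decode(payload, validate=True)
    except binascii.Error as exc:
        raise ValueError("invalid base64 image payload") from exc


raw = decode_base64_image("data:image/png;base64," + base64.b64encode(b"\x89PNG").decode())
print(raw)  # b'\x89PNG'
```

The returned bytes can be wrapped in `io.BytesIO(...)` and passed to `Image.open`, which accepts file-like objects.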

5. Web Integration (JavaScript)

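```javascript
async function predictCrowdCount(imageFile) {
    const formData = new FormData();
    formData.append('image', imageFile);  // field name must match what the server reads ('image')

    const response = await fetch('http://localhost:5000/predict', {
        method: 'POST',
        body: formData  // let the browser set the multipart Content-Type header
    });

    if (!response.ok) {
        throw new Error(`Prediction failed: HTTP ${response.status}`);
    }

    const result = await response.json();
    console.log('Crowd count:', result.crowd_count);
    return result.crowd_count;
}
```

If the page is not served from http://localhost:5000 itself, the browser treats this as a cross-origin request, and the Flask server must send CORS headers (for example, via the flask-cors extension).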

Important Notes

⚠️ Model Accuracy: the bundled weights were trained for only 5 epochs × 20 steps, so treat their counts as rough estimates. For production use, either:

  • retrain for 50+ epochs × 200+ steps, or
  • download pre-trained weights from the original repo.
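The gap can be put in numbers. With a fixed number of steps per epoch, the total number of weight updates is epochs × steps, so the suggested schedule performs 100 times as many updates as the current one. Below is a small sketch of that arithmetic. The batch size and the 300-image dataset are hypothetical inputs, since the guide does not state them.

```python
import math


def training_budget(epochs, steps_per_epoch, batch_size=1):
    """Return (weight updates, images seen) for a fixed-step training schedule."""
    updates = epochs * steps_per_epoch
    return updates, updates * batch_size


def steps_for_full_pass(num_images, batch_size=1):
    """Steps needed to see every training image once per epoch."""
    return math.ceil(num_images / batch_size)


current = training_budget(5, 20)
suggested = training_budget(50, 200)
print(current, suggested)              # (100, 100) (10000, 10000)
print(suggested[0] // current[0])      # 100
print(steps_for_full_pass(300, 8))     # hypothetical 300 images, batch 8 -> 38
```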

File Structure

```
your_project/
├── models/
│   └── Model.json
├── weights/
│   └── model_A.weights.h5
└── app.py
```
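The default arguments of `load_csrnet_model` are relative paths, so they resolve against the current working directory rather than against `app.py`. Starting the server from another directory then fails to find the files. One workaround, sketched under the assumption that `app.py` sits at the project root as in the tree above, is to build absolute paths from the script's own location. `model_paths` is a hypothetical helper.

```python
from pathlib import Path


def model_paths(base_dir):
    """Return absolute (architecture, weights) paths under base_dir."""
    base = Path(base_dir).resolve()
    return (base / "models" / "Model.json",
            base / "weights" / "model_A.weights.h5")


# In app.py, anchor to the file's own directory rather than the working directory.
# __file__ is undefined in some interactive shells, so fall back to the CWD there.
BASE_DIR = Path(__file__).resolve().parent if "__file__" in globals() else Path.cwd()
json_path, weights_path = model_paths(BASE_DIR)
print(json_path)
```

The results can then be passed on as `load_csrnet_model(str(json_path), str(weights_path))`.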

Performance Tips

  • Load the model once at startup and keep it in memory; never reload it per request
  • Use a GPU for faster inference (this requires a GPU-enabled TensorFlow installation)
  • Downscale very large images before prediction to limit memory use and latency
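For the resizing tip, one simple policy is to cap the longer side and keep the aspect ratio. In the original CSRNet design, the VGG-16 front end downsamples by a factor of 8, so rounding each side down to a multiple of 8 keeps the density map aligned with the input grid. Below is a sketch of the size computation only. The 1024-pixel cap is an arbitrary example value, and aggressive downscaling may make small, distant heads harder to detect.

```python
def target_size(width, height, max_side=1024, multiple=8):
    """Return (width, height) with the longer edge capped at max_side.

    Aspect ratio is preserved up to rounding; each side is rounded down to a
    multiple of `multiple`, with a floor of `multiple` pixels.
    """
    longest = max(width, height)

    def snap(side):
        if longest > max_side:
            side = side * max_side // longest  # integer arithmetic avoids float drift
        return max(multiple, side // multiple * multiple)

    return snap(width), snap(height)


print(target_size(4000, 3000))  # (1024, 768)
print(target_size(640, 480))    # (640, 480)
```

The resulting tuple can be passed to PIL's `Image.resize((w, h))` before preprocessing.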