import React, { useState } from 'react';
import * as tflite from '@tensorflow/tfjs-tflite';

/** Maximum number of selected files kept for benchmarking. */
const MAX_IMAGES = 10;

/**
 * Convert RGBA pixel bytes (the layout returned by `CanvasRenderingContext2D.getImageData`)
 * into a tightly packed RGB Float32Array with every channel scaled to [0, 1].
 * The alpha channel is dropped.
 *
 * @param {Uint8ClampedArray} rgba - Pixel bytes in R, G, B, A order.
 * @returns {Float32Array} Array of length (rgba.length / 4) * 3.
 */
function rgbaToNormalizedRgb(rgba) {
  const pixelCount = rgba.length / 4;
  // Three output values per pixel. Sizing this by pixelCount alone would let
  // typed-array out-of-bounds writes silently drop two thirds of the image.
  const rgb = new Float32Array(pixelCount * 3);
  for (let src = 0, dst = 0; src < rgba.length; src += 4) {
    rgb[dst++] = rgba[src] / 255; // R
    rgb[dst++] = rgba[src + 1] / 255; // G
    rgb[dst++] = rgba[src + 2] / 255; // B
  }
  return rgb;
}

/**
 * Release tensor memory held by `value`. Accepts a single tensor-like object (anything
 * with a `dispose()` method), an array of them, or a plain object whose values are
 * tensors. Written defensively because the exact return type of `model.predict` is not
 * visible from this file. `null`/`undefined` are ignored.
 *
 * @param {*} value - Tensor, array of tensors, or name -> tensor map.
 */
function disposeTensors(value) {
  if (value === null || value === undefined) return;
  if (typeof value.dispose === 'function') {
    value.dispose();
  } else if (Array.isArray(value)) {
    value.forEach(disposeTensors);
  } else if (typeof value === 'object') {
    Object.values(value).forEach(disposeTensors);
  }
}

/**
 * Decode an image File, resize it to `width` x `height` on an offscreen canvas, and
 * build a float tensor of shape [1, height, width, 3] with RGB values in [0, 1].
 *
 * @param {File} imageFile - Image selected by the user.
 * @param {number} width - Target width in pixels.
 * @param {number} height - Target height in pixels.
 * @returns {Promise<*>} Resolves with the input tensor. Rejects if the file cannot be
 *   read or decoded, or if tensor construction throws.
 */
function preprocessImage(imageFile, width, height) {
  return new Promise((resolve, reject) => {
    const img = new Image();
    const reader = new FileReader();

    reader.onerror = () => reject(reader.error || new Error(`Failed to read ${imageFile.name}`));
    reader.onload = () => {
      img.src = reader.result;
    };

    img.onerror = () => reject(new Error(`Failed to decode image ${imageFile.name}`));
    img.onload = () => {
      // Errors thrown inside an event handler would otherwise leave this promise
      // pending forever; route them to reject instead.
      try {
        const canvas = document.createElement('canvas');
        canvas.width = width;
        canvas.height = height;
        const context = canvas.getContext('2d');
        context.drawImage(img, 0, 0, width, height);
        const { data } = context.getImageData(0, 0, width, height);

        // NOTE(review): @tensorflow/tfjs-tflite does not appear to export a `Tensor`
        // constructor; input tensors are normally created with `tf.tensor4d(...)` from
        // @tensorflow/tfjs-core. Confirm and switch the import if needed.
        resolve(new tflite.Tensor(rgbaToNormalizedRgb(data), [1, height, width, 3]));
      } catch (error) {
        reject(error);
      }
    };

    reader.readAsDataURL(imageFile);
  });
}

/**
 * Benchmark UI for a TFLite object-detection model.
 *
 * The user loads the model, selects up to MAX_IMAGES images and starts the benchmark.
 * Every image is preprocessed once. `model.predict` then runs `repetitions` times per
 * image. Only the `predict` call is timed, so the displayed average excludes decoding
 * and resizing.
 *
 * All props are optional; the defaults reproduce the original hard-coded settings.
 *
 * @param {object} [props]
 * @param {string} [props.modelUrl='./model.tflite'] - URL passed to `tflite.loadTFLiteModel`.
 * @param {number} [props.inputWidth=320] - Model input width in pixels.
 * @param {number} [props.inputHeight=320] - Model input height in pixels.
 * @param {number} [props.repetitions=50] - Passes over the image set.
 */
function TFLiteObjectDetection({
  modelUrl = './model.tflite',
  inputWidth = 320,
  inputHeight = 320,
  repetitions = 50,
} = {}) {
  const [averageTime, setAverageTime] = useState(null);
  const [loading, setLoading] = useState(false);
  const [images, setImages] = useState([]);
  const [model, setModel] = useState(null);

  const handleFileChange = (event) => {
    const files = Array.from(event.target.files);
    setImages(files.slice(0, MAX_IMAGES));
  };

  const loadModel = async () => {
    try {
      const loadedModel = await tflite.loadTFLiteModel(modelUrl);
      setModel(loadedModel);
      console.log('Model loaded successfully!');
    } catch (error) {
      console.error('Error loading TFLite model:', error);
    }
  };

  const runBenchmark = async () => {
    if (!model || images.length === 0) {
      alert('Please load the model and upload 10 images.');
      return;
    }

    setLoading(true);
    // Tensors are collected as they are created so `finally` can free them even if a
    // later image fails to preprocess or inference throws.
    const inputTensors = [];
    try {
      // Preprocess once, outside the timed region. Sequential so that every tensor
      // created before a failure is already tracked for disposal.
      for (const imageFile of images) {
        inputTensors.push(await preprocessImage(imageFile, inputWidth, inputHeight));
      }

      let totalInferenceTime = 0;
      for (let rep = 0; rep < repetitions; rep++) {
        console.log(`Repetition ${rep + 1} of ${repetitions}`);
        for (const inputTensor of inputTensors) {
          const startTime = performance.now();
          const output = model.predict(inputTensor);
          totalInferenceTime += performance.now() - startTime;

          // Log output for debugging (optional), then free it so repeated runs don't leak.
          console.log('Inference output:', output);
          disposeTensors(output);
        }
      }

      const runCount = repetitions * inputTensors.length;
      setAverageTime(runCount > 0 ? totalInferenceTime / runCount : null);
    } catch (error) {
      console.error('Error during inference:', error);
    } finally {
      disposeTensors(inputTensors);
      setLoading(false);
    }
  };

  return React.createElement(
    'div',
    null,
    React.createElement('h1', null, 'Object Detection Benchmark (TFLite)'),
    React.createElement('button', { onClick: loadModel, disabled: model !== null }, 'Load Model'),
    React.createElement('input', {
      type: 'file',
      multiple: true,
      accept: 'image/*',
      onChange: handleFileChange,
    }),
    React.createElement(
      'button',
      { onClick: runBenchmark, disabled: loading || !model || images.length === 0 },
      loading ? 'Running Benchmark...' : 'Start Benchmark'
    ),
    React.createElement(
      'div',
      null,
      averageTime !== null
        ? React.createElement(
            'h2',
            null,
            `Average Inference Time: ${averageTime.toFixed(2)} ms`
          )
        : null
    ),
    React.createElement(
      'ul',
      null,
      images.map((img, index) => React.createElement('li', { key: index }, img.name))
    )
  );
}

export default TFLiteObjectDetection;