/**
 * Model Loading - Load and initialize ONNX models
 */

import * as ort from 'https://cdn.jsdelivr.net/npm/onnxruntime-web@1.17.0/dist/esm/ort.min.js';
import { MODEL_PATHS, ONNX_CONFIG, LABEL_MAPPINGS } from '../config.js';
import { setModel, appState } from '../state.js';
import { updateLoadingStatus } from '../ui/loading.js';

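/**
 * Initialize ONNX Runtime: configure the WebAssembly backend (location of the
 * .wasm files and thread count). Call this before creating any InferenceSession,
 * because these settings are read when the backend initializes.
 */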
export async function initONNXRuntime() {
    ort.env.wasm.wasmPaths = ONNX_CONFIG.wasmPaths;
    ort.env.wasm.numThreads = ONNX_CONFIG.numThreads;
}
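
// Hypothetical helper (not part of the original module): a minimal sketch of
// how a value for ONNX_CONFIG.numThreads might be chosen at runtime.
// Multi-threaded WASM in onnxruntime-web needs SharedArrayBuffer, which browsers
// only expose on cross-origin-isolated pages; otherwise the runtime falls back
// to a single thread. The default cap of 4 threads is an illustrative
// assumption, not a library default.
function chooseWasmThreads(isCrossOriginIsolated, hardwareConcurrency, cap = 4) {
    // Without cross-origin isolation, more than one thread cannot be used.
    if (!isCrossOriginIsolated) return 1;
    // Treat a missing or invalid core count as a single core.
    const cores = Number.isInteger(hardwareConcurrency) && hardwareConcurrency > 0
        ? hardwareConcurrency
        : 1;
    return Math.max(1, Math.min(cores, cap));
}
// Example wiring in a browser (assumption):
//   chooseWasmThreads(self.crossOriginIsolated, navigator.hardwareConcurrency)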

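/**
 * Load label mappings (the `id2label` map) from the grader model's config.json.
 * If loading fails, LABEL_MAPPINGS keeps its default values.
 */
export async function loadLabelMappings() {
    try {
        const response = await fetch('./models/grader_model_compressed/config.json');
        // fetch() only rejects on network errors, so check the HTTP status explicitly.
        if (!response.ok) {
            throw new Error(`HTTP ${response.status} while fetching grader config`);
        }
        const graderData = await response.json();
        LABEL_MAPPINGS.grader = graderData.id2label || {};

        console.log('Label mappings loaded:', LABEL_MAPPINGS);
    } catch (error) {
        console.warn('Could not load label mappings, using defaults:', error);
    }
}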
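
// Hypothetical helper (not part of the original module): a minimal sketch of
// how the `id2label` map loaded above might be used to turn a grader model's
// output into a label. It assumes `logits` is a flat array-like of raw class
// scores (for example the `data` of an output tensor for a batch of one) and
// that `id2label` uses stringified class indices ("0", "1", ...) as keys.
function topLabel(logits, id2label) {
    // Numerically stable softmax: subtract the max logit before exponentiating.
    const scores = Array.from(logits);
    const max = Math.max(...scores);
    const exps = scores.map((x) => Math.exp(x - max));
    const sum = exps.reduce((a, b) => a + b, 0);
    const probs = exps.map((e) => e / sum);

    // Argmax over the class probabilities.
    let best = 0;
    for (let i = 1; i < probs.length; i++) {
        if (probs[i] > probs[best]) best = i;
    }

    // Numeric keys are converted to strings, so id2label[1] finds "1".
    // Fall back to the raw index when no mapping is available.
    return { index: best, label: id2label[best] ?? String(best), confidence: probs[best] };
}
// Example (assumption): topLabel(output.data, LABEL_MAPPINGS.grader)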

/**
 * Load all ONNX models in sequence, reporting progress to the loading screen.
 * The classifier, quality (poor/good), and grader models are required; the
 * YOLO detection model is optional and is skipped with a warning if it fails.
 */
export async function loadAllModels() {
    try {
        updateLoadingStatus('Loading Classifier Model...', 25);
        const classifier = await ort.InferenceSession.create(MODEL_PATHS.classifier);
        setModel('classifier', classifier);

        updateLoadingStatus('Loading Quality Assessment Model...', 50);
        const poorGood = await ort.InferenceSession.create(MODEL_PATHS.poorGood);
        setModel('poorGood', poorGood);

        updateLoadingStatus('Loading Grader Model...', 75);
        const grader = await ort.InferenceSession.create(MODEL_PATHS.grader);
        setModel('grader', grader);

        // YOLO is treated as optional: a failure here is logged, not fatal.
        updateLoadingStatus('Loading YOLO Detection Model...', 90);
        try {
            const yolo = await ort.InferenceSession.create(MODEL_PATHS.yolo);
            setModel('yolo', yolo);
        } catch (error) {
            console.warn('YOLO model not available:', error);
        }

        updateLoadingStatus('Models loaded successfully!', 100);
    } catch (error) {
        // Keep the original error (and its stack) attached as the cause.
        throw new Error(`Model loading failed: ${error.message}`, { cause: error });
    }
}
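
// Hypothetical alternative (not part of the original module): a table-driven
// sketch of the same sequential loading loop. `specs`, `createSession`, and
// `onProgress` are illustrative names. In this file, `createSession` would be
// `(path) => ort.InferenceSession.create(path)`, `onProgress` would be
// `updateLoadingStatus`, and each returned session would be passed to setModel.
// Required models abort the whole load on failure; optional ones (like YOLO
// above) are skipped with a warning.
async function loadModelsFromSpecs(specs, createSession, onProgress) {
    const loaded = {};
    for (const { name, path, displayName, progress, optional = false } of specs) {
        onProgress(`Loading ${displayName}...`, progress);
        try {
            loaded[name] = await createSession(path);
        } catch (error) {
            if (!optional) {
                throw new Error(`Model loading failed (${name}): ${error.message}`, { cause: error });
            }
            console.warn(`${displayName} not available:`, error);
        }
    }
    onProgress('Models loaded successfully!', 100);
    return loaded;
}
// Injecting `createSession` keeps the loop independent of onnxruntime-web,
// so it can be exercised with a fake loader.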

/**
 * Get model status for UI display
 */
export function getModelStatus() {
    return {
        classifier: !!appState.models.classifier,
        poorGood: !!appState.models.poorGood,
        grader: !!appState.models.grader,
        yolo: !!appState.models.yolo
    };
}
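
// Hypothetical helper (not part of the original module): a sketch of how the
// object returned by getModelStatus() might be turned into a one-line summary
// for the loading screen. The output format is an illustrative assumption.
function summarizeModelStatus(status) {
    const names = Object.keys(status);
    // A model counts as missing when its status flag is false.
    const missing = names.filter((name) => !status[name]);
    const loadedCount = names.length - missing.length;
    const base = `${loadedCount}/${names.length} models loaded`;
    return missing.length ? `${base} (missing: ${missing.join(', ')})` : base;
}
// Example (assumption): summarizeModelStatus(getModelStatus())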