/**
 * WebGPU Inference Wrapper
 * Provides a clean interface between the app and the VL model
 */
import { VLModel, clearModelCache, getCacheInfo, setDebug } from './vl-model.js';
import { getModelConfig } from './config.js';

// Expose debug toggle on window for browser console access.
// Guarded because `window` does not exist in Web Workers (or Node); an
// unguarded assignment there throws a ReferenceError and prevents the whole
// module from loading.
if (typeof window !== 'undefined') {
  window.setDebug = setDebug;
}

// Re-export cache utilities
export { clearModelCache, getCacheInfo, setDebug };

/**
 * Stateful wrapper around a single VLModel instance loaded on the 'webgpu'
 * device. Tracks which model id is loaded and whether a load is in progress.
 */
export class WebGPUInference {
  constructor() {
    // VLModel instance currently held (loaded, loading, or null).
    this.model = null;
    // Id of the model whose load completed successfully, else null.
    this.currentModelId = null;
    // True while loadModel() is in progress; used to reject concurrent loads.
    this.isLoading = false;
    // True once a load completed and the model can be used by generate().
    this.isReady = false;
  }

  /**
   * Load a model, replacing (and disposing) any previously loaded one.
   * Returns immediately if `modelId` is already loaded and ready.
   *
   * @param {string} modelId - Model ID from config
   * @param {object} [options] - Loading options
   * @param {function} [options.progressCallback] - Progress callback, forwarded to VLModel#load
   * @returns {Promise<void>}
   * @throws {Error} If a load is already in progress, if `modelId` has no
   *   configuration, if the model is disposed while loading, or the error
   *   raised by the underlying model load. On any failure the wrapper is left
   *   with no model loaded.
   */
  async loadModel(modelId, options = {}) {
    if (this.isLoading) {
      throw new Error('Model is already loading');
    }

    if (this.currentModelId === modelId && this.isReady) {
      return;
    }

    this.isLoading = true;
    this.isReady = false;

    try {
      const modelConfig = getModelConfig(modelId);
      if (!modelConfig) {
        throw new Error(`Model configuration not found: ${modelId}`);
      }

      // Dispose old model if exists
      if (this.model) {
        this.model.dispose();
        this.model = null;
      }

      // Keep a local reference so we can detect dispose() being called on
      // this wrapper while the load below is still awaiting.
      const model = new VLModel();
      this.model = model;

      // Load the model with quantization from config (null entries when unset)
      await model.load(modelConfig.path, {
        device: 'webgpu',
        quantization: modelConfig.quantization || { decoder: null, visionEncoder: null },
        progressCallback: options.progressCallback,
      });

      // dispose() ran mid-load: don't report a model as ready when
      // this.model no longer points at it.
      if (this.model !== model) {
        throw new Error('Model was disposed while loading');
      }

      this.currentModelId = modelId;
      this.isReady = true;
    } catch (error) {
      // Whatever is still referenced here is released before being dropped.
      // That is either a partially loaded new model, or the previous model
      // when the failure happened before it was disposed (e.g. unknown
      // modelId). Cleanup errors are logged, not thrown, so the caller sees
      // the original load error.
      if (this.model) {
        try {
          this.model.dispose();
        } catch (disposeError) {
          console.warn('Failed to dispose model after load error:', disposeError);
        }
      }
      this.model = null;
      this.currentModelId = null;
      this.isReady = false;
      throw error;
    } finally {
      this.isLoading = false;
    }
  }

  /**
   * Generate a response from messages
   * @param {Array} messages - Array of message objects with role and content
   * @param {object} [options] - Generation options
   * @param {number} [options.maxNewTokens=512] - Token budget (falsy values fall back to 512)
   * @param {function} [options.onToken] - Token callback for streaming
   * @returns {Promise} Whatever VLModel#generate resolves to
   * @throws {Error} If no model is loaded
   */
  async generate(messages, options = {}) {
    if (!this.isReady || !this.model) {
      throw new Error('Model not loaded. Please load a model first.');
    }

    // Convert app message format to VL model format
    const { vlMessages, images, messageImageMap } = this.convertMessages(messages);

    // Generate response
    return await this.model.generate(vlMessages, {
      maxNewTokens: options.maxNewTokens || 512,
      images,
      messageImageMap,
      onToken: options.onToken,
    });
  }

  /**
   * Convert app message format to VL model format.
   *
   * - String content is passed through unchanged.
   * - Array content concatenates its `text` parts (no separator) and collects
   *   its `image` parts. Messages with only whitespace text and no images are
   *   dropped.
   * - Any other content is coerced with String() (null/undefined become '').
   *
   * @param {Array} messages - App messages ({ role, content })
   * @returns {{vlMessages: Array, images: Array, messageImageMap: Map}}
   *   `images` holds every image in order of appearance. `messageImageMap`
   *   maps an index into `vlMessages` to that message's images.
   */
  convertMessages(messages) {
    const vlMessages = [];
    const images = [];
    const messageImageMap = new Map();

    for (const message of messages) {
      const { role, content } = message;

      if (typeof content === 'string') {
        vlMessages.push({ role, content });
      } else if (Array.isArray(content)) {
        let textContent = '';
        const messageImages = [];

        for (const item of content) {
          if (item.type === 'text') {
            // Skip missing values so "undefined" never leaks into the prompt.
            if (item.value != null) {
              textContent += item.value;
            }
          } else if (item.type === 'image') {
            messageImages.push(item.value);
            images.push(item.value);
          }
        }

        if (textContent.trim() || messageImages.length > 0) {
          if (messageImages.length > 0) {
            // Key is the index this message is about to occupy in vlMessages.
            messageImageMap.set(vlMessages.length, messageImages);
          }
          vlMessages.push({ role, content: textContent });
        }
      } else {
        vlMessages.push({ role, content: String(content || '') });
      }
    }

    return { vlMessages, images, messageImageMap };
  }

  /**
   * Check if a model is loaded
   * @returns {boolean}
   */
  isModelLoaded() {
    return this.isReady;
  }

  /**
   * Get current model ID
   * @returns {string|null}
   */
  getCurrentModelId() {
    return this.currentModelId;
  }

  /**
   * Clear the image embedding cache of the current model (no-op if none).
   */
  clearImageCache() {
    if (this.model) {
      this.model.clearImageCache();
    }
  }

  /**
   * Dispose the model and free resources. Safe to call when nothing is loaded.
   */
  dispose() {
    if (this.model) {
      this.model.dispose();
      this.model = null;
    }
    this.currentModelId = null;
    this.isReady = false;
  }
}

// Singleton instance, created lazily by getWebGPUInference()
let webgpuInstance = null;

/**
 * Get the WebGPU inference singleton
 * @returns {WebGPUInference}
 */
export function getWebGPUInference() {
  if (!webgpuInstance) {
    webgpuInstance = new WebGPUInference();
  }
  return webgpuInstance;
}