File size: 4,884 Bytes
accf76b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
/**
 * WebGPU Inference Wrapper
 * Provides a clean interface between the app and the VL model
 */

import { VLModel, clearModelCache, getCacheInfo, setDebug } from './vl-model.js';
import { getModelConfig } from './config.js';

// Expose debug toggle on window for browser console access
// (e.g. `setDebug(true)` from DevTools). NOTE(review): assumes a browser
// environment where `window` exists — confirm this module is never loaded
// in a worker or Node context.
window.setDebug = setDebug;

// Re-export cache utilities so consumers can import them from this wrapper
// instead of depending on vl-model.js directly.
export { clearModelCache, getCacheInfo, setDebug };

export class WebGPUInference {
  constructor() {
    /** @type {VLModel|null} Active model instance, or null when none loaded. */
    this.model = null;
    /** @type {string|null} Config ID of the currently loaded model. */
    this.currentModelId = null;
    /** @type {boolean} True while a load is in flight (guards reentrancy). */
    this.isLoading = false;
    /** @type {boolean} True once a load has completed successfully. */
    this.isReady = false;
  }

  /**
   * Load a model, disposing any previously loaded one first.
   * No-op if the requested model is already loaded and ready.
   * @param {string} modelId - Model ID from config
   * @param {object} [options] - Loading options
   * @param {function} [options.progressCallback] - Progress callback
   * @throws {Error} If a load is already in progress, or the model ID has
   *   no configuration, or the underlying load fails.
   */
  async loadModel(modelId, options = {}) {
    if (this.isLoading) {
      throw new Error('Model is already loading');
    }

    // Requested model is already active — nothing to do.
    if (this.currentModelId === modelId && this.isReady) {
      return;
    }

    this.isLoading = true;
    this.isReady = false;

    try {
      const modelConfig = getModelConfig(modelId);
      if (!modelConfig) {
        throw new Error(`Model configuration not found: ${modelId}`);
      }

      // Dispose old model if exists
      if (this.model) {
        this.model.dispose();
        this.model = null;
      }

      // Create new model instance
      this.model = new VLModel();

      // Load the model with quantization from config.
      // `??` (not `||`) so an explicit config value is never silently replaced.
      await this.model.load(modelConfig.path, {
        device: 'webgpu',
        quantization: modelConfig.quantization ?? { decoder: null, visionEncoder: null },
        progressCallback: options.progressCallback,
      });

      this.currentModelId = modelId;
      this.isReady = true;

    } catch (error) {
      // FIX: dispose the partially-constructed model instead of just dropping
      // the reference, so resources allocated during a failed load are freed
      // (matches the dispose-before-replace pattern above). Cleanup failures
      // are swallowed so the original load error propagates.
      try {
        this.model?.dispose();
      } catch {
        // ignore cleanup failure; the load error below is what matters
      }
      this.model = null;
      this.currentModelId = null;
      this.isReady = false;
      throw error;
    } finally {
      this.isLoading = false;
    }
  }

  /**
   * Generate a response from messages.
   * @param {Array<Object>} messages - Array of message objects with role and content
   * @param {object} [options] - Generation options
   * @param {number} [options.maxNewTokens=512] - Token budget for the reply
   * @param {function} [options.onToken] - Token callback for streaming
   * @returns {Promise<string>} Generated response
   * @throws {Error} If no model is loaded.
   */
  async generate(messages, options = {}) {
    if (!this.isReady || !this.model) {
      throw new Error('Model not loaded. Please load a model first.');
    }

    // Convert app message format to VL model format
    const { vlMessages, images, messageImageMap } = this.convertMessages(messages);

    // `??` instead of `||`: only null/undefined fall back to the default.
    return await this.model.generate(vlMessages, {
      maxNewTokens: options.maxNewTokens ?? 512,
      images: images,
      messageImageMap: messageImageMap,
      onToken: options.onToken,
    });
  }

  /**
   * Convert app message format to VL model format.
   *
   * String content passes through unchanged. Array content is flattened:
   * `text` items are concatenated, `image` items are collected both into the
   * flat `images` list and into `messageImageMap` keyed by the message's
   * index in `vlMessages`. Array messages that yield neither text nor images
   * are dropped. Any other content is stringified (null/undefined -> '').
   * @param {Array<Object>} messages - App messages
   * @returns {{vlMessages: Array, images: Array<string>, messageImageMap: Map}}
   */
  convertMessages(messages) {
    const vlMessages = [];
    const images = [];
    const messageImageMap = new Map();

    for (const message of messages) {
      const { role, content } = message;

      if (typeof content === 'string') {
        vlMessages.push({ role, content });
      } else if (Array.isArray(content)) {
        let textContent = '';
        const messageImages = [];

        for (const item of content) {
          if (item.type === 'text') {
            textContent += item.value;
          } else if (item.type === 'image') {
            messageImages.push(item.value);
            images.push(item.value);
          }
        }

        // Skip messages that contributed nothing; otherwise record the
        // image association under the index this message will occupy.
        if (textContent.trim() || messageImages.length > 0) {
          if (messageImages.length > 0) {
            messageImageMap.set(vlMessages.length, messageImages);
          }
          vlMessages.push({ role, content: textContent });
        }
      } else {
        vlMessages.push({ role, content: String(content || '') });
      }
    }

    return { vlMessages, images, messageImageMap };
  }

  /**
   * Check if a model is loaded.
   * @returns {boolean}
   */
  isModelLoaded() {
    return this.isReady;
  }

  /**
   * Get current model ID.
   * @returns {string|null}
   */
  getCurrentModelId() {
    return this.currentModelId;
  }

  /**
   * Clear the image embedding cache (no-op when no model is loaded).
   */
  clearImageCache() {
    this.model?.clearImageCache();
  }

  /**
   * Dispose the model and free resources. Safe to call repeatedly.
   */
  dispose() {
    this.model?.dispose();
    this.model = null;
    this.currentModelId = null;
    this.isReady = false;
  }
}

// Lazily-created module-level singleton.
let webgpuInstance = null;

/**
 * Get the WebGPU inference singleton, creating it on first access.
 * @returns {WebGPUInference}
 */
export function getWebGPUInference() {
  webgpuInstance ??= new WebGPUInference();
  return webgpuInstance;
}