/**
 * Logistic sigmoid.
 *
 * @param {number} z - Logit.
 * @returns {number} Value in [0, 1].
 */
const sigmoid = (z) => 1 / (1 + Math.exp(-z));

/**
 * Compute `bias + Σ weights[i] * features[i]` over the overlapping prefix of
 * the two vectors, so a length mismatch never reads an `undefined` weight.
 *
 * @param {ArrayLike<number>} weights
 * @param {number} bias
 * @param {ArrayLike<number>} features
 * @returns {number}
 */
function linearScore(weights, bias, features) {
  const n = Math.min(features.length, weights.length);
  let score = bias;
  for (let i = 0; i < n; i++) {
    score += weights[i] * features[i];
  }
  return score;
}

/**
 * Coerce a value to a number, mapping NaN/±Infinity (e.g. from a missing
 * field) to 0 so it cannot poison the model.
 *
 * @param {*} value
 * @returns {number}
 */
const finiteOrZero = (value) => {
  const num = Number(value);
  return Number.isFinite(num) ? num : 0;
};

/**
 * Per-tag logistic-regression classifier trained online from user feedback
 * and persisted to `localStorage` under the key `clipTaggerModel`.
 *
 * Each tag has its own weight vector and scalar bias. Predictions are
 * independent sigmoid probabilities per tag (not a softmax across tags).
 */
class LocalClassifier {
  constructor() {
    /** Per-tag weight vectors, keyed by tag (string keys after `loadModel`). */
    this.weights = new Map();
    /** Per-tag bias terms, keyed by tag. */
    this.biases = new Map();
    this.learningRate = 0.01;
    this.featureDim = 512;
    this.isInitialized = false;
  }

  /**
   * Set the feature dimensionality and mark the model as ready.
   *
   * `featureDim` is the length given to newly created tag weight vectors and
   * to vectors built by `extractSimpleFeatures`.
   *
   * @param {number} [featureDim=512] - Feature-vector length.
   */
  initialize(featureDim = 512) {
    this.featureDim = featureDim;
    this.isInitialized = true;
  }

  /**
   * Apply one stochastic-gradient-descent step (binary cross-entropy) to the
   * classifier for `tag`.
   *
   * `'positive'` and `'custom'` feedback train toward 1, and `'negative'`
   * trains toward 0. Any other feedback value is ignored. A tag seen for the
   * first time gets small random weights and a zero bias. Input that yields a
   * non-finite logit (NaN/Infinity features) is skipped with a warning rather
   * than written into the parameters.
   *
   * @param {ArrayLike<number>} features - Feature vector. Only the first
   *   `min(features.length, weights.length)` entries are used.
   * @param {*} tag - Tag whose classifier is updated.
   * @param {string} feedback - `'positive'`, `'negative'` or `'custom'`.
   */
  trainOnFeedback(features, tag, feedback) {
    if (!this.isInitialized) {
      // Keep the currently configured dimensionality instead of silently
      // resetting it to the default.
      this.initialize(this.featureDim);
    }

    let target;
    switch (feedback) {
      case 'positive':
      case 'custom': // 'custom' feedback is trained as a positive example.
        target = 1.0;
        break;
      case 'negative':
        target = 0.0;
        break;
      default:
        return; // Unknown feedback type: nothing to learn.
    }

    if (!this.weights.has(tag)) {
      this.weights.set(
        tag,
        Array.from({ length: this.featureDim }, () => (Math.random() - 0.5) * 0.01),
      );
      this.biases.set(tag, 0);
    }

    // The weight array is updated in place; the Map keeps the same reference.
    const weights = this.weights.get(tag);
    const bias = this.biases.get(tag);

    const logit = linearScore(weights, bias, features);
    if (!Number.isFinite(logit)) {
      // Updating with NaN/Infinity would permanently corrupt this tag.
      console.warn(`LocalClassifier: skipped update for tag "${String(tag)}" (non-finite input)`);
      return;
    }

    // d(log-loss)/d(logit) for a sigmoid output.
    const error = sigmoid(logit) - target;
    const step = this.learningRate * error;

    const n = Math.min(features.length, weights.length);
    for (let i = 0; i < n; i++) {
      weights[i] -= step * features[i];
    }
    this.biases.set(tag, bias - step);
  }

  /**
   * Probability that `tag` applies to `features`.
   *
   * @param {ArrayLike<number>} features - Feature vector. Entries beyond the
   *   weight vector's length are ignored.
   * @param {*} tag - Tag to score.
   * @returns {number|null} Sigmoid output, or `null` if the tag was never trained.
   */
  predict(features, tag) {
    if (!this.weights.has(tag)) {
      return null;
    }
    return sigmoid(linearScore(this.weights.get(tag), this.biases.get(tag), features));
  }

  /**
   * Score every known tag among `candidateTags`.
   *
   * @param {ArrayLike<number>} features - Feature vector.
   * @param {Iterable<*>} candidateTags - Tags to score. Untrained tags are omitted.
   * @returns {{tag: *, confidence: number}[]} Results sorted by descending confidence.
   */
  predictAll(features, candidateTags) {
    const predictions = [];
    for (const tag of candidateTags) {
      const confidence = this.predict(features, tag);
      if (confidence !== null) {
        predictions.push({ tag, confidence });
      }
    }
    return predictions.sort((a, b) => b.confidence - a.confidence);
  }

  /**
   * Train on a batch of feedback records.
   *
   * Each item needs `audioFeatures` (passed to `extractSimpleFeatures`) and
   * `correctedTags`, an iterable of `{ tag, feedback }`. Items missing either
   * field are skipped.
   *
   * @param {Iterable<object>} feedbackData
   */
  retrainOnBatch(feedbackData) {
    for (const item of feedbackData) {
      if (!item.audioFeatures || !item.correctedTags) {
        continue;
      }
      const features = this.extractSimpleFeatures(item.audioFeatures);
      for (const tagData of item.correctedTags) {
        this.trainOnFeedback(features, tagData.tag, tagData.feedback);
      }
    }
  }

  /**
   * Build a deterministic feature vector of length `featureDim` from
   * clip metadata.
   *
   * Layout:
   * - [0] `duration / 60`
   * - [1] `sampleRate / 48000`
   * - [2] `numberOfChannels`
   * - [3..] pseudo-random values in [0, 0.1), seeded by a hash of
   *   `JSON.stringify(audioFeatures)`, so identical metadata gives identical
   *   vectors.
   *
   * Missing or non-numeric metadata fields become 0. A falsy `audioFeatures`
   * yields an all-zero vector.
   *
   * NOTE(review): JSON.stringify only sees own enumerable properties. If
   * callers pass a getter-based object, every input may hash identically.
   * Confirm the argument's shape.
   *
   * @param {object|null|undefined} audioFeatures
   * @returns {number[]}
   */
  extractSimpleFeatures(audioFeatures) {
    const features = new Array(this.featureDim).fill(0);
    if (!audioFeatures) {
      return features;
    }

    features[0] = finiteOrZero(audioFeatures.duration / 60);
    features[1] = finiteOrZero(audioFeatures.sampleRate / 48000);
    features[2] = finiteOrZero(audioFeatures.numberOfChannels);

    const seed = this.simpleHash(JSON.stringify(audioFeatures));
    for (let i = 3; i < this.featureDim; i++) {
      features[i] = this.seededRandom(seed + i) * 0.1;
    }
    return features;
  }

  /**
   * Non-cryptographic 32-bit string hash (`hash * 31 + charCode`).
   *
   * @param {string} str
   * @returns {number} Non-negative integer.
   */
  simpleHash(str) {
    let hash = 0;
    for (let i = 0; i < str.length; i++) {
      // `| 0` truncates to a signed 32-bit integer after each step.
      hash = ((hash << 5) - hash + str.charCodeAt(i)) | 0;
    }
    return Math.abs(hash);
  }

  /**
   * Deterministic pseudo-random value in [0, 1) derived from `seed`
   * (not suitable for anything security-related).
   *
   * @param {number} seed
   * @returns {number}
   */
  seededRandom(seed) {
    const x = Math.sin(seed) * 10000;
    return x - Math.floor(x);
  }

  /**
   * Serialize weights, biases, featureDim and learningRate to
   * `localStorage['clipTaggerModel']`.
   *
   * Errors thrown by `localStorage.setItem` (e.g. quota exceeded) propagate
   * to the caller.
   */
  saveModel() {
    const modelData = {
      weights: Object.fromEntries(this.weights),
      biases: Object.fromEntries(this.biases),
      featureDim: this.featureDim,
      learningRate: this.learningRate,
    };

    localStorage.setItem('clipTaggerModel', JSON.stringify(modelData));
  }

  /**
   * Restore a model previously written by `saveModel`.
   *
   * The payload is fully parsed before any field is assigned, so a malformed
   * payload leaves the current model untouched.
   *
   * @returns {boolean} `true` if a model was loaded. `false` if nothing was
   *   stored, the payload was invalid, or storage was inaccessible (the error
   *   is logged in the last two cases).
   */
  loadModel() {
    try {
      const saved = localStorage.getItem('clipTaggerModel');
      if (!saved) {
        return false;
      }

      const modelData = JSON.parse(saved);
      const weights = new Map(Object.entries(modelData.weights));
      const biases = new Map(Object.entries(modelData.biases));

      this.weights = weights;
      this.biases = biases;
      this.featureDim = modelData.featureDim || 512;
      this.learningRate = modelData.learningRate || 0.01;
      this.isInitialized = true;
      return true;
    } catch (error) {
      console.error('Error loading model:', error);
      return false;
    }
  }

  /**
   * Summary of the current model.
   *
   * @returns {{trainedTags: number, featureDim: number, learningRate: number, tags: Array<*>}}
   */
  getModelStats() {
    return {
      trainedTags: this.weights.size,
      featureDim: this.featureDim,
      learningRate: this.learningRate,
      tags: [...this.weights.keys()],
    };
  }

  /** Drop all learned parameters and the persisted copy in localStorage. */
  clearModel() {
    this.weights.clear();
    this.biases.clear();
    localStorage.removeItem('clipTaggerModel');
  }
}

export default LocalClassifier;