Spaces:
Running
Running
Implement adaptive local classifier that learns from user feedback
Browse files- src/App.css +13 -0
- src/App.jsx +78 -5
- src/localClassifier.js +205 -0
src/App.css
CHANGED
|
@@ -147,6 +147,19 @@ header p {
|
|
| 147 |
opacity: 0.6;
|
| 148 |
}
|
| 149 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 150 |
.tag-controls {
|
| 151 |
display: flex;
|
| 152 |
gap: 0.25rem;
|
|
|
|
| 147 |
opacity: 0.6;
|
| 148 |
}
|
| 149 |
|
| 150 |
+
.tag.local {
|
| 151 |
+
background: linear-gradient(45deg, #9b59b6, #8e44ad);
|
| 152 |
+
}
|
| 153 |
+
|
| 154 |
+
.tag.blended {
|
| 155 |
+
background: linear-gradient(45deg, #f39c12, #e67e22);
|
| 156 |
+
}
|
| 157 |
+
|
| 158 |
+
.source-indicator {
|
| 159 |
+
margin-left: 0.5rem;
|
| 160 |
+
font-size: 0.8em;
|
| 161 |
+
}
|
| 162 |
+
|
| 163 |
.tag-controls {
|
| 164 |
display: flex;
|
| 165 |
gap: 0.25rem;
|
src/App.jsx
CHANGED
|
@@ -1,6 +1,7 @@
|
|
| 1 |
import { useState, useRef, useEffect } from 'react'
|
| 2 |
import CLAPProcessor from './clapProcessor'
|
| 3 |
import UserFeedbackStore from './userFeedbackStore'
|
|
|
|
| 4 |
import './App.css'
|
| 5 |
|
| 6 |
function App() {
|
|
@@ -18,11 +19,16 @@ function App() {
|
|
| 18 |
const chunksRef = useRef([])
|
| 19 |
const clapProcessorRef = useRef(null)
|
| 20 |
const feedbackStoreRef = useRef(null)
|
|
|
|
| 21 |
|
| 22 |
useEffect(() => {
|
| 23 |
const initializeStore = async () => {
|
| 24 |
feedbackStoreRef.current = new UserFeedbackStore()
|
| 25 |
await feedbackStoreRef.current.initialize()
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
loadCustomTags()
|
| 27 |
}
|
| 28 |
initializeStore()
|
|
@@ -107,13 +113,53 @@ function App() {
|
|
| 107 |
const generatedTags = await clapProcessorRef.current.processAudio(audioBuffer)
|
| 108 |
|
| 109 |
// Store basic audio info for later use
|
| 110 |
-
|
| 111 |
sampleRate: audioBuffer.sampleRate,
|
| 112 |
duration: audioBuffer.duration,
|
| 113 |
numberOfChannels: audioBuffer.numberOfChannels
|
| 114 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 115 |
|
| 116 |
-
setTags(
|
| 117 |
} catch (err) {
|
| 118 |
console.error('Error processing audio:', err)
|
| 119 |
setError('Failed to process audio. Using fallback tags.')
|
|
@@ -139,6 +185,17 @@ function App() {
|
|
| 139 |
feedback,
|
| 140 |
audioHash
|
| 141 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 142 |
} catch (error) {
|
| 143 |
console.error('Error saving tag feedback:', error)
|
| 144 |
}
|
|
@@ -151,7 +208,8 @@ function App() {
|
|
| 151 |
label: newTag.trim(),
|
| 152 |
confidence: 1.0,
|
| 153 |
userFeedback: 'custom',
|
| 154 |
-
isCustom: true
|
|
|
|
| 155 |
}
|
| 156 |
|
| 157 |
setTags(prev => [...prev, customTag])
|
|
@@ -159,6 +217,18 @@ function App() {
|
|
| 159 |
try {
|
| 160 |
await feedbackStoreRef.current.saveCustomTag(newTag.trim())
|
| 161 |
await feedbackStoreRef.current.saveTagFeedback(newTag.trim(), 'custom', audioHash)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 162 |
loadCustomTags()
|
| 163 |
} catch (error) {
|
| 164 |
console.error('Error saving custom tag:', error)
|
|
@@ -236,8 +306,11 @@ function App() {
|
|
| 236 |
<div className="tags">
|
| 237 |
{tags.map((tag, index) => (
|
| 238 |
<div key={index} className={`tag-item ${tag.userFeedback ? 'has-feedback' : ''}`}>
|
| 239 |
-
<span className={`tag ${tag.isCustom ? 'custom' : ''} ${tag.userFeedback === 'negative' ? 'negative' : ''}`}>
|
| 240 |
{tag.label} ({Math.round(tag.confidence * 100)}%)
|
|
|
|
|
|
|
|
|
|
| 241 |
</span>
|
| 242 |
{!tag.isCustom && (
|
| 243 |
<div className="tag-controls">
|
|
|
|
| 1 |
import { useState, useRef, useEffect } from 'react'
|
| 2 |
import CLAPProcessor from './clapProcessor'
|
| 3 |
import UserFeedbackStore from './userFeedbackStore'
|
| 4 |
+
import LocalClassifier from './localClassifier'
|
| 5 |
import './App.css'
|
| 6 |
|
| 7 |
function App() {
|
|
|
|
| 19 |
const chunksRef = useRef([])
|
| 20 |
const clapProcessorRef = useRef(null)
|
| 21 |
const feedbackStoreRef = useRef(null)
|
| 22 |
+
const localClassifierRef = useRef(null)
|
| 23 |
|
| 24 |
useEffect(() => {
|
| 25 |
const initializeStore = async () => {
|
| 26 |
feedbackStoreRef.current = new UserFeedbackStore()
|
| 27 |
await feedbackStoreRef.current.initialize()
|
| 28 |
+
|
| 29 |
+
localClassifierRef.current = new LocalClassifier()
|
| 30 |
+
localClassifierRef.current.loadModel()
|
| 31 |
+
|
| 32 |
loadCustomTags()
|
| 33 |
}
|
| 34 |
initializeStore()
|
|
|
|
| 113 |
const generatedTags = await clapProcessorRef.current.processAudio(audioBuffer)
|
| 114 |
|
| 115 |
// Store basic audio info for later use
|
| 116 |
+
const features = {
|
| 117 |
sampleRate: audioBuffer.sampleRate,
|
| 118 |
duration: audioBuffer.duration,
|
| 119 |
numberOfChannels: audioBuffer.numberOfChannels
|
| 120 |
+
}
|
| 121 |
+
setAudioFeatures(features)
|
| 122 |
+
|
| 123 |
+
// Apply local classifier adjustments
|
| 124 |
+
let finalTags = generatedTags.map(tag => ({ ...tag, userFeedback: null }))
|
| 125 |
+
|
| 126 |
+
if (localClassifierRef.current) {
|
| 127 |
+
const simpleFeatures = localClassifierRef.current.extractSimpleFeatures(features)
|
| 128 |
+
const allPossibleTags = [...generatedTags.map(t => t.label), ...customTags]
|
| 129 |
+
const localPredictions = localClassifierRef.current.predictAll(simpleFeatures, allPossibleTags)
|
| 130 |
+
|
| 131 |
+
// Merge CLAP predictions with local classifier predictions
|
| 132 |
+
const mergedTags = new Map()
|
| 133 |
+
|
| 134 |
+
// Add CLAP tags
|
| 135 |
+
for (const tag of generatedTags) {
|
| 136 |
+
mergedTags.set(tag.label, { ...tag, source: 'clap' })
|
| 137 |
+
}
|
| 138 |
+
|
| 139 |
+
// Add or adjust with local predictions
|
| 140 |
+
for (const pred of localPredictions) {
|
| 141 |
+
if (mergedTags.has(pred.tag)) {
|
| 142 |
+
// Blend CLAP and local predictions
|
| 143 |
+
const existing = mergedTags.get(pred.tag)
|
| 144 |
+
existing.confidence = (existing.confidence + pred.confidence) / 2
|
| 145 |
+
existing.source = 'blended'
|
| 146 |
+
} else if (pred.confidence > 0.6) {
|
| 147 |
+
// Add high-confidence local predictions
|
| 148 |
+
mergedTags.set(pred.tag, {
|
| 149 |
+
label: pred.tag,
|
| 150 |
+
confidence: pred.confidence,
|
| 151 |
+
source: 'local',
|
| 152 |
+
userFeedback: null
|
| 153 |
+
})
|
| 154 |
+
}
|
| 155 |
+
}
|
| 156 |
+
|
| 157 |
+
finalTags = Array.from(mergedTags.values())
|
| 158 |
+
.sort((a, b) => b.confidence - a.confidence)
|
| 159 |
+
.slice(0, 8) // Keep top 8 tags
|
| 160 |
+
}
|
| 161 |
|
| 162 |
+
setTags(finalTags)
|
| 163 |
} catch (err) {
|
| 164 |
console.error('Error processing audio:', err)
|
| 165 |
setError('Failed to process audio. Using fallback tags.')
|
|
|
|
| 185 |
feedback,
|
| 186 |
audioHash
|
| 187 |
)
|
| 188 |
+
|
| 189 |
+
// Train local classifier on this feedback
|
| 190 |
+
if (localClassifierRef.current && audioFeatures) {
|
| 191 |
+
const simpleFeatures = localClassifierRef.current.extractSimpleFeatures(audioFeatures)
|
| 192 |
+
localClassifierRef.current.trainOnFeedback(
|
| 193 |
+
simpleFeatures,
|
| 194 |
+
updatedTags[tagIndex].label,
|
| 195 |
+
feedback
|
| 196 |
+
)
|
| 197 |
+
localClassifierRef.current.saveModel()
|
| 198 |
+
}
|
| 199 |
} catch (error) {
|
| 200 |
console.error('Error saving tag feedback:', error)
|
| 201 |
}
|
|
|
|
| 208 |
label: newTag.trim(),
|
| 209 |
confidence: 1.0,
|
| 210 |
userFeedback: 'custom',
|
| 211 |
+
isCustom: true,
|
| 212 |
+
source: 'custom'
|
| 213 |
}
|
| 214 |
|
| 215 |
setTags(prev => [...prev, customTag])
|
|
|
|
| 217 |
try {
|
| 218 |
await feedbackStoreRef.current.saveCustomTag(newTag.trim())
|
| 219 |
await feedbackStoreRef.current.saveTagFeedback(newTag.trim(), 'custom', audioHash)
|
| 220 |
+
|
| 221 |
+
// Train local classifier on custom tag
|
| 222 |
+
if (localClassifierRef.current && audioFeatures) {
|
| 223 |
+
const simpleFeatures = localClassifierRef.current.extractSimpleFeatures(audioFeatures)
|
| 224 |
+
localClassifierRef.current.trainOnFeedback(
|
| 225 |
+
simpleFeatures,
|
| 226 |
+
newTag.trim(),
|
| 227 |
+
'custom'
|
| 228 |
+
)
|
| 229 |
+
localClassifierRef.current.saveModel()
|
| 230 |
+
}
|
| 231 |
+
|
| 232 |
loadCustomTags()
|
| 233 |
} catch (error) {
|
| 234 |
console.error('Error saving custom tag:', error)
|
|
|
|
| 306 |
<div className="tags">
|
| 307 |
{tags.map((tag, index) => (
|
| 308 |
<div key={index} className={`tag-item ${tag.userFeedback ? 'has-feedback' : ''}`}>
|
| 309 |
+
<span className={`tag ${tag.isCustom ? 'custom' : ''} ${tag.userFeedback === 'negative' ? 'negative' : ''} ${tag.source || 'clap'}`}>
|
| 310 |
{tag.label} ({Math.round(tag.confidence * 100)}%)
|
| 311 |
+
{tag.source === 'local' && <span className="source-indicator">🧠</span>}
|
| 312 |
+
{tag.source === 'blended' && <span className="source-indicator">⚡</span>}
|
| 313 |
+
{tag.source === 'custom' && <span className="source-indicator">✨</span>}
|
| 314 |
</span>
|
| 315 |
{!tag.isCustom && (
|
| 316 |
<div className="tag-controls">
|
src/localClassifier.js
ADDED
|
@@ -0,0 +1,205 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
class LocalClassifier {
|
| 2 |
+
constructor() {
|
| 3 |
+
this.weights = new Map(); // tag -> weight vector
|
| 4 |
+
this.biases = new Map(); // tag -> bias
|
| 5 |
+
this.learningRate = 0.01;
|
| 6 |
+
this.featureDim = 512; // CLAP embedding dimension
|
| 7 |
+
this.isInitialized = false;
|
| 8 |
+
}
|
| 9 |
+
|
| 10 |
+
initialize(featureDim = 512) {
|
| 11 |
+
this.featureDim = featureDim;
|
| 12 |
+
this.isInitialized = true;
|
| 13 |
+
}
|
| 14 |
+
|
| 15 |
+
// Simple logistic regression training
|
| 16 |
+
trainOnFeedback(features, tag, feedback) {
|
| 17 |
+
if (!this.isInitialized) {
|
| 18 |
+
this.initialize();
|
| 19 |
+
}
|
| 20 |
+
|
| 21 |
+
// Convert feedback to target value
|
| 22 |
+
let target;
|
| 23 |
+
switch (feedback) {
|
| 24 |
+
case 'positive':
|
| 25 |
+
target = 1.0;
|
| 26 |
+
break;
|
| 27 |
+
case 'negative':
|
| 28 |
+
target = 0.0;
|
| 29 |
+
break;
|
| 30 |
+
case 'custom':
|
| 31 |
+
target = 1.0;
|
| 32 |
+
break;
|
| 33 |
+
default:
|
| 34 |
+
return; // Skip unknown feedback
|
| 35 |
+
}
|
| 36 |
+
|
| 37 |
+
// Initialize weights for new tag
|
| 38 |
+
if (!this.weights.has(tag)) {
|
| 39 |
+
this.weights.set(tag, new Array(this.featureDim).fill(0).map(() =>
|
| 40 |
+
(Math.random() - 0.5) * 0.01
|
| 41 |
+
));
|
| 42 |
+
this.biases.set(tag, 0);
|
| 43 |
+
}
|
| 44 |
+
|
| 45 |
+
const weights = this.weights.get(tag);
|
| 46 |
+
const bias = this.biases.get(tag);
|
| 47 |
+
|
| 48 |
+
// Forward pass
|
| 49 |
+
let logit = bias;
|
| 50 |
+
for (let i = 0; i < features.length; i++) {
|
| 51 |
+
logit += weights[i] * features[i];
|
| 52 |
+
}
|
| 53 |
+
|
| 54 |
+
// Sigmoid activation
|
| 55 |
+
const prediction = 1 / (1 + Math.exp(-logit));
|
| 56 |
+
|
| 57 |
+
// Compute gradient
|
| 58 |
+
const error = prediction - target;
|
| 59 |
+
|
| 60 |
+
// Update weights and bias
|
| 61 |
+
for (let i = 0; i < features.length; i++) {
|
| 62 |
+
weights[i] -= this.learningRate * error * features[i];
|
| 63 |
+
}
|
| 64 |
+
this.biases.set(tag, bias - this.learningRate * error);
|
| 65 |
+
|
| 66 |
+
// Store updated weights
|
| 67 |
+
this.weights.set(tag, weights);
|
| 68 |
+
}
|
| 69 |
+
|
| 70 |
+
// Predict confidence for a tag given features
|
| 71 |
+
predict(features, tag) {
|
| 72 |
+
if (!this.weights.has(tag)) {
|
| 73 |
+
return null; // No training data for this tag
|
| 74 |
+
}
|
| 75 |
+
|
| 76 |
+
const weights = this.weights.get(tag);
|
| 77 |
+
const bias = this.biases.get(tag);
|
| 78 |
+
|
| 79 |
+
let logit = bias;
|
| 80 |
+
for (let i = 0; i < Math.min(features.length, weights.length); i++) {
|
| 81 |
+
logit += weights[i] * features[i];
|
| 82 |
+
}
|
| 83 |
+
|
| 84 |
+
// Sigmoid activation
|
| 85 |
+
return 1 / (1 + Math.exp(-logit));
|
| 86 |
+
}
|
| 87 |
+
|
| 88 |
+
// Get all predictions for given features
|
| 89 |
+
predictAll(features, candidateTags) {
|
| 90 |
+
const predictions = [];
|
| 91 |
+
|
| 92 |
+
for (const tag of candidateTags) {
|
| 93 |
+
const confidence = this.predict(features, tag);
|
| 94 |
+
if (confidence !== null) {
|
| 95 |
+
predictions.push({ tag, confidence });
|
| 96 |
+
}
|
| 97 |
+
}
|
| 98 |
+
|
| 99 |
+
return predictions.sort((a, b) => b.confidence - a.confidence);
|
| 100 |
+
}
|
| 101 |
+
|
| 102 |
+
// Retrain on batch of feedback data
|
| 103 |
+
retrainOnBatch(feedbackData) {
|
| 104 |
+
for (const item of feedbackData) {
|
| 105 |
+
if (item.audioFeatures && item.correctedTags) {
|
| 106 |
+
// Create simple features from audio metadata
|
| 107 |
+
const features = this.extractSimpleFeatures(item.audioFeatures);
|
| 108 |
+
|
| 109 |
+
// Train on corrected tags
|
| 110 |
+
for (const tagData of item.correctedTags) {
|
| 111 |
+
this.trainOnFeedback(features, tagData.tag, tagData.feedback);
|
| 112 |
+
}
|
| 113 |
+
}
|
| 114 |
+
}
|
| 115 |
+
}
|
| 116 |
+
|
| 117 |
+
// Extract simple features from audio metadata
|
| 118 |
+
extractSimpleFeatures(audioFeatures) {
|
| 119 |
+
// Create a simple feature vector from audio metadata
|
| 120 |
+
// In a real implementation, this would use actual CLAP embeddings
|
| 121 |
+
const features = new Array(this.featureDim).fill(0);
|
| 122 |
+
|
| 123 |
+
if (audioFeatures) {
|
| 124 |
+
// Use basic audio properties to create pseudo-features
|
| 125 |
+
features[0] = audioFeatures.duration / 60; // Duration in minutes
|
| 126 |
+
features[1] = audioFeatures.sampleRate / 48000; // Normalized sample rate
|
| 127 |
+
features[2] = audioFeatures.numberOfChannels; // Number of channels
|
| 128 |
+
|
| 129 |
+
// Fill remaining with small random values based on hash of properties
|
| 130 |
+
const seed = this.simpleHash(JSON.stringify(audioFeatures));
|
| 131 |
+
for (let i = 3; i < this.featureDim; i++) {
|
| 132 |
+
features[i] = this.seededRandom(seed + i) * 0.1;
|
| 133 |
+
}
|
| 134 |
+
}
|
| 135 |
+
|
| 136 |
+
return features;
|
| 137 |
+
}
|
| 138 |
+
|
| 139 |
+
// Simple hash function for seeded random
|
| 140 |
+
simpleHash(str) {
|
| 141 |
+
let hash = 0;
|
| 142 |
+
for (let i = 0; i < str.length; i++) {
|
| 143 |
+
const char = str.charCodeAt(i);
|
| 144 |
+
hash = ((hash << 5) - hash) + char;
|
| 145 |
+
hash = hash & hash; // Convert to 32-bit integer
|
| 146 |
+
}
|
| 147 |
+
return Math.abs(hash);
|
| 148 |
+
}
|
| 149 |
+
|
| 150 |
+
// Seeded random number generator
|
| 151 |
+
seededRandom(seed) {
|
| 152 |
+
const x = Math.sin(seed) * 10000;
|
| 153 |
+
return x - Math.floor(x);
|
| 154 |
+
}
|
| 155 |
+
|
| 156 |
+
// Save model to localStorage
|
| 157 |
+
saveModel() {
|
| 158 |
+
const modelData = {
|
| 159 |
+
weights: Object.fromEntries(this.weights),
|
| 160 |
+
biases: Object.fromEntries(this.biases),
|
| 161 |
+
featureDim: this.featureDim,
|
| 162 |
+
learningRate: this.learningRate
|
| 163 |
+
};
|
| 164 |
+
|
| 165 |
+
localStorage.setItem('clipTaggerModel', JSON.stringify(modelData));
|
| 166 |
+
}
|
| 167 |
+
|
| 168 |
+
// Load model from localStorage
|
| 169 |
+
loadModel() {
|
| 170 |
+
const saved = localStorage.getItem('clipTaggerModel');
|
| 171 |
+
if (saved) {
|
| 172 |
+
try {
|
| 173 |
+
const modelData = JSON.parse(saved);
|
| 174 |
+
this.weights = new Map(Object.entries(modelData.weights));
|
| 175 |
+
this.biases = new Map(Object.entries(modelData.biases));
|
| 176 |
+
this.featureDim = modelData.featureDim || 512;
|
| 177 |
+
this.learningRate = modelData.learningRate || 0.01;
|
| 178 |
+
this.isInitialized = true;
|
| 179 |
+
return true;
|
| 180 |
+
} catch (error) {
|
| 181 |
+
console.error('Error loading model:', error);
|
| 182 |
+
}
|
| 183 |
+
}
|
| 184 |
+
return false;
|
| 185 |
+
}
|
| 186 |
+
|
| 187 |
+
// Get model statistics
|
| 188 |
+
getModelStats() {
|
| 189 |
+
return {
|
| 190 |
+
trainedTags: this.weights.size,
|
| 191 |
+
featureDim: this.featureDim,
|
| 192 |
+
learningRate: this.learningRate,
|
| 193 |
+
tags: Array.from(this.weights.keys())
|
| 194 |
+
};
|
| 195 |
+
}
|
| 196 |
+
|
| 197 |
+
// Clear the model
|
| 198 |
+
clearModel() {
|
| 199 |
+
this.weights.clear();
|
| 200 |
+
this.biases.clear();
|
| 201 |
+
localStorage.removeItem('clipTaggerModel');
|
| 202 |
+
}
|
| 203 |
+
}
|
| 204 |
+
|
| 205 |
+
export default LocalClassifier;
|