diff --git a/src/app/api/auth/hf/callback/route.ts b/src/app/api/auth/hf/callback/route.ts new file mode 100644 index 0000000000000000000000000000000000000000..2e353f3149ff687b697cb541e787a42442de501a --- /dev/null +++ b/src/app/api/auth/hf/callback/route.ts @@ -0,0 +1,112 @@ +import { NextRequest, NextResponse } from 'next/server'; +import { cookies } from 'next/headers'; + +const TOKEN_ENDPOINT = 'https://huggingface.co/oauth/token'; +const USERINFO_ENDPOINT = 'https://huggingface.co/oauth/userinfo'; +const STATE_COOKIE = 'hf_oauth_state'; + +function htmlResponse(script: string) { + return new NextResponse( + ``, + { + headers: { 'Content-Type': 'text/html; charset=utf-8' }, + }, + ); +} + +export async function GET(request: NextRequest) { + const clientId = process.env.HF_OAUTH_CLIENT_ID || process.env.NEXT_PUBLIC_HF_OAUTH_CLIENT_ID; + const clientSecret = process.env.HF_OAUTH_CLIENT_SECRET; + + if (!clientId || !clientSecret) { + return NextResponse.json({ error: 'OAuth application is not configured' }, { status: 500 }); + } + + const { searchParams } = new URL(request.url); + const code = searchParams.get('code'); + const incomingState = searchParams.get('state'); + + const cookieStore = cookies(); + const storedState = cookieStore.get(STATE_COOKIE)?.value; + + cookieStore.delete(STATE_COOKIE); + + const origin = request.nextUrl.origin; + + if (!code || !incomingState || !storedState || incomingState !== storedState) { + const script = ` + window.opener && window.opener.postMessage({ + type: 'HF_OAUTH_ERROR', + payload: { message: 'Invalid or expired OAuth state.' } + }, '${origin}'); + window.close(); + `; + return htmlResponse(script.trim()); + } + + const redirectUri = process.env.HF_OAUTH_REDIRECT_URI || process.env.NEXT_PUBLIC_HF_OAUTH_REDIRECT_URI || `${origin}/api/auth/hf/callback`; + + try { + const tokenResponse = await fetch(TOKEN_ENDPOINT, { + method: 'POST', + headers: { + 'Content-Type': 'application/x-www-form-urlencoded', + }, + body: new URLSearchParams({ + grant_type: 'authorization_code', + code, + redirect_uri: redirectUri, + client_id: clientId, + client_secret: clientSecret, + }), + }); + + if (!tokenResponse.ok) { + const errorPayload = await tokenResponse.json().catch(() => ({})); + throw new Error(errorPayload?.error_description || 'Failed to exchange code for token'); + } + + const tokenData = await tokenResponse.json(); + const accessToken = tokenData?.access_token; + if (!accessToken) { + throw new Error('Access token missing in response'); + } + + const userResponse = await fetch(USERINFO_ENDPOINT, { + headers: { + Authorization: `Bearer ${accessToken}`, + }, + }); + + if (!userResponse.ok) { + throw new Error('Failed to fetch user info'); + } + + const profile = await userResponse.json(); + const namespace = profile?.preferred_username || profile?.name || 'user'; + + const script = ` + window.opener && window.opener.postMessage({ + type: 'HF_OAUTH_SUCCESS', + payload: { + token: ${JSON.stringify(accessToken)}, + namespace: ${JSON.stringify(namespace)}, + } + }, '${origin}'); + window.close(); + `; + + return htmlResponse(script.trim()); + } catch (error: any) { + const message = error?.message || 'OAuth flow failed'; + const script = ` + window.opener && window.opener.postMessage({ + type: 'HF_OAUTH_ERROR', + payload: { message: ${JSON.stringify(message)} } + }, '${origin}'); + window.close(); + `; + + return htmlResponse(script.trim()); + } +} diff --git a/src/app/api/auth/hf/login/route.ts b/src/app/api/auth/hf/login/route.ts new file mode 100644 index 0000000000000000000000000000000000000000..22c252217d8b94f9db7a495892a79193df05786a --- /dev/null +++ b/src/app/api/auth/hf/login/route.ts @@ -0,0 +1,36 @@ +import { randomUUID } from 'crypto'; +import { NextRequest, NextResponse } from 'next/server'; + +const HF_AUTHORIZE_URL = 'https://huggingface.co/oauth/authorize'; +const STATE_COOKIE = 'hf_oauth_state'; + +export async function GET(request: NextRequest) { + const clientId = process.env.HF_OAUTH_CLIENT_ID || process.env.NEXT_PUBLIC_HF_OAUTH_CLIENT_ID; + if (!clientId) { + return NextResponse.json({ error: 'OAuth client ID not configured' }, { status: 500 }); + } + + const state = randomUUID(); + const origin = request.nextUrl.origin; + const redirectUri = process.env.HF_OAUTH_REDIRECT_URI || process.env.NEXT_PUBLIC_HF_OAUTH_REDIRECT_URI || `${origin}/api/auth/hf/callback`; + + const authorizeUrl = new URL(HF_AUTHORIZE_URL); + authorizeUrl.searchParams.set('response_type', 'code'); + authorizeUrl.searchParams.set('client_id', clientId); + authorizeUrl.searchParams.set('redirect_uri', redirectUri); + authorizeUrl.searchParams.set('scope', 'openid profile read-repos'); + authorizeUrl.searchParams.set('state', state); + + const response = NextResponse.redirect(authorizeUrl.toString(), { status: 302 }); + response.cookies.set({ + name: STATE_COOKIE, + value: state, + httpOnly: true, + sameSite: 'lax', + secure: process.env.NODE_ENV === 'production', + maxAge: 60 * 5, + path: '/', + }); + + return response; +} diff --git a/src/app/api/auth/hf/validate/route.ts b/src/app/api/auth/hf/validate/route.ts new file mode 100644 index 0000000000000000000000000000000000000000..32dc41fb4d3a7e82d8434ce577aa9e563c349203 --- /dev/null +++ b/src/app/api/auth/hf/validate/route.ts @@ -0,0 +1,22 @@ +import { NextRequest, NextResponse } from 'next/server'; +import { whoAmI } from '@huggingface/hub'; + +export async function POST(request: NextRequest) { + try { + const body = await request.json().catch(() => ({})); + const token = (body?.token || '').trim(); + + if (!token) { + return NextResponse.json({ error: 'Token is required' }, { status: 400 }); + } + + const info = await whoAmI({ accessToken: token }); + return NextResponse.json({ + name: info?.name || info?.username || 'user', + email: info?.email || null, + orgs: info?.orgs || [], + }); + } catch (error: any) { + return NextResponse.json({ error: error?.message || 'Invalid token' }, { status: 401 }); + } +} diff --git a/src/app/api/auth/route.ts b/src/app/api/auth/route.ts new file mode 100644 index 0000000000000000000000000000000000000000..1dc229739fbbeaabf307e3be544dd7e2bc8ab66f --- /dev/null +++ b/src/app/api/auth/route.ts @@ -0,0 +1,6 @@ +import { NextResponse } from 'next/server'; + +export async function GET() { + // if this gets hit, auth has already been verified + return NextResponse.json({ isAuthenticated: true }); +} diff --git a/src/app/api/caption/get/route.ts b/src/app/api/caption/get/route.ts new file mode 100644 index 0000000000000000000000000000000000000000..4f8d2818318805f97a80370e1a9cfc584cd9dc26 --- /dev/null +++ b/src/app/api/caption/get/route.ts @@ -0,0 +1,46 @@ +/* eslint-disable */ +import { NextRequest, NextResponse } from 'next/server'; +import fs from 'fs'; +import path from 'path'; +import { getDatasetsRoot } from '@/server/settings'; + +export async function POST(request: NextRequest) { + + const body = await request.json(); + const { imgPath } = body; + console.log('Received POST request for caption:', imgPath); + try { + // Decode the path + const filepath = imgPath; + console.log('Decoded image path:', filepath); + + // caption name is the filepath without extension but with .txt + const captionPath = filepath.replace(/\.[^/.]+$/, '') + '.txt'; + + // Get allowed directories + const allowedDir = await getDatasetsRoot(); + + // Security check: Ensure path is in allowed directory + const isAllowed = filepath.startsWith(allowedDir) && !filepath.includes('..'); + + if (!isAllowed) { + console.warn(`Access denied: ${filepath} not in ${allowedDir}`); + return new NextResponse('Access denied', { status: 403 }); + } + + // Check if file exists + if (!fs.existsSync(captionPath)) { + // send back blank string if caption file does not exist + return new NextResponse(''); + } + + // Read caption file + const caption = fs.readFileSync(captionPath, 'utf-8'); + + // Return caption + return new NextResponse(caption); + } catch (error) { + console.error('Error getting caption:', error); + return new NextResponse('Error getting caption', { status: 500 }); + } +} diff --git a/src/app/api/datasets/create/route.tsx b/src/app/api/datasets/create/route.tsx new file mode 100644 index 0000000000000000000000000000000000000000..e005d058f3423db41f4830b69a1d51c7872d1351 --- /dev/null +++ b/src/app/api/datasets/create/route.tsx @@ -0,0 +1,25 @@ +import { NextResponse } from 'next/server'; +import fs from 'fs'; +import path from 'path'; +import { getDatasetsRoot } from '@/server/settings'; + +export async function POST(request: Request) { + try { + const body = await request.json(); + let { name } = body; + // clean name by making lower case, removing special characters, and replacing spaces with underscores + name = name.toLowerCase().replace(/[^a-z0-9]+/g, '_'); + + let datasetsPath = await getDatasetsRoot(); + let datasetPath = path.join(datasetsPath, name); + + // if folder doesnt exist, create it + if (!fs.existsSync(datasetPath)) { + fs.mkdirSync(datasetPath); + } + + return NextResponse.json({ success: true, name: name, path: datasetPath }); + } catch (error) { + return NextResponse.json({ error: 'Failed to create dataset' }, { status: 500 }); + } +} diff --git a/src/app/api/datasets/delete/route.tsx b/src/app/api/datasets/delete/route.tsx new file mode 100644 index 0000000000000000000000000000000000000000..9a1d970ee415c9d040596854ce74ad5401859259 --- /dev/null +++ b/src/app/api/datasets/delete/route.tsx @@ -0,0 +1,24 @@ +import { NextResponse } from 'next/server'; +import fs from 'fs'; +import path from 'path'; +import { getDatasetsRoot } from '@/server/settings'; + +export async function POST(request: Request) { + try { + const body = await request.json(); + const { name } = body; + let datasetsPath = await getDatasetsRoot(); + let datasetPath = path.join(datasetsPath, name); + + // if folder doesnt exist, ignore + if (!fs.existsSync(datasetPath)) { + return NextResponse.json({ success: true }); + } + + // delete it and return success + fs.rmdirSync(datasetPath, { recursive: true }); + return NextResponse.json({ success: true }); + } catch (error) { + return NextResponse.json({ error: 'Failed to create dataset' }, { status: 500 }); + } +} diff --git a/src/app/api/datasets/list/route.ts b/src/app/api/datasets/list/route.ts new file mode 100644 index 0000000000000000000000000000000000000000..dc829c65f3cab2829221f85341967fc1b52a921c --- /dev/null +++ b/src/app/api/datasets/list/route.ts @@ -0,0 +1,25 @@ +import { NextResponse } from 'next/server'; +import fs from 'fs'; +import { getDatasetsRoot } from '@/server/settings'; + +export async function GET() { + try { + let datasetsPath = await getDatasetsRoot(); + + // if folder doesnt exist, create it + if (!fs.existsSync(datasetsPath)) { + fs.mkdirSync(datasetsPath); + } + + // find all the folders in the datasets folder + let folders = fs + .readdirSync(datasetsPath, { withFileTypes: true }) + .filter(dirent => dirent.isDirectory()) + .filter(dirent => !dirent.name.startsWith('.')) + .map(dirent => dirent.name); + + return NextResponse.json(folders); + } catch (error) { + return NextResponse.json({ error: 'Failed to fetch datasets' }, { status: 500 }); + } +} diff --git a/src/app/api/datasets/listImages/route.ts b/src/app/api/datasets/listImages/route.ts new file mode 100644 index 0000000000000000000000000000000000000000..06dca84ae780c7fddb200fc6de422b7a42e309ea --- /dev/null +++ b/src/app/api/datasets/listImages/route.ts @@ -0,0 +1,61 @@ +import { NextResponse } from 'next/server'; +import fs from 'fs'; +import path from 'path'; +import { getDatasetsRoot } from '@/server/settings'; + +export async function POST(request: Request) { + const datasetsPath = await getDatasetsRoot(); + const body = await request.json(); + const { datasetName } = body; + const datasetFolder = path.join(datasetsPath, datasetName); + + try { + // Check if folder exists + if (!fs.existsSync(datasetFolder)) { + return NextResponse.json({ error: `Folder '${datasetName}' not found` }, { status: 404 }); + } + + // Find all images recursively + const imageFiles = findImagesRecursively(datasetFolder); + + // Format response + const result = imageFiles.map(imgPath => ({ + img_path: imgPath, + })); + + return NextResponse.json({ images: result }); + } catch (error) { + console.error('Error finding images:', error); + return NextResponse.json({ error: 'Failed to process request' }, { status: 500 }); + } +} + +/** + * Recursively finds all image files in a directory and its subdirectories + * @param dir Directory to search + * @returns Array of absolute paths to image files + */ +function findImagesRecursively(dir: string): string[] { + const imageExtensions = ['.png', '.jpg', '.jpeg', '.webp', '.mp4', '.avi', '.mov', '.mkv', '.wmv', '.m4v', '.flv']; + let results: string[] = []; + + const items = fs.readdirSync(dir); + + for (const item of items) { + const itemPath = path.join(dir, item); + const stat = fs.statSync(itemPath); + + if (stat.isDirectory() && item !== '_controls' && !item.startsWith('.')) { + // If it's a directory, recursively search it + results = results.concat(findImagesRecursively(itemPath)); + } else { + // If it's a file, check if it's an image + const ext = path.extname(itemPath).toLowerCase(); + if (imageExtensions.includes(ext)) { + results.push(itemPath); + } + } + } + + return results; +} diff --git a/src/app/api/datasets/upload/route.ts b/src/app/api/datasets/upload/route.ts new file mode 100644 index 0000000000000000000000000000000000000000..51aff81fd3bf4b091f10a1df9f2da887910f4753 --- /dev/null +++ b/src/app/api/datasets/upload/route.ts @@ -0,0 +1,57 @@ +// src/app/api/datasets/upload/route.ts +import { NextRequest, NextResponse } from 'next/server'; +import { writeFile, mkdir } from 'fs/promises'; +import { join } from 'path'; +import { getDatasetsRoot } from '@/server/settings'; + +export async function POST(request: NextRequest) { + try { + const datasetsPath = await getDatasetsRoot(); + if (!datasetsPath) { + return NextResponse.json({ error: 'Datasets path not found' }, { status: 500 }); + } + const formData = await request.formData(); + const files = formData.getAll('files'); + const datasetName = formData.get('datasetName') as string; + + if (!files || files.length === 0) { + return NextResponse.json({ error: 'No files provided' }, { status: 400 }); + } + + // Create upload directory if it doesn't exist + const uploadDir = join(datasetsPath, datasetName); + await mkdir(uploadDir, { recursive: true }); + + const savedFiles: string[] = []; + + // Process files sequentially to avoid overwhelming the system + for (let i = 0; i < files.length; i++) { + const file = files[i] as any; + const bytes = await file.arrayBuffer(); + const buffer = Buffer.from(bytes); + + // Clean filename and ensure it's unique + const fileName = file.name.replace(/[^a-zA-Z0-9.-]/g, '_'); + const filePath = join(uploadDir, fileName); + + await writeFile(filePath, buffer); + savedFiles.push(fileName); + } + + return NextResponse.json({ + message: 'Files uploaded successfully', + files: savedFiles, + }); + } catch (error) { + console.error('Upload error:', error); + return NextResponse.json({ error: 'Error uploading files' }, { status: 500 }); + } +} + +// Increase payload size limit (default is 4mb) +export const config = { + api: { + bodyParser: false, + responseLimit: '50mb', + }, +}; diff --git a/src/app/api/files/[...filePath]/route.ts b/src/app/api/files/[...filePath]/route.ts new file mode 100644 index 0000000000000000000000000000000000000000..46eb5c4ab08b9c02ba4ff8d0fe7f6dc2cd15442a --- /dev/null +++ b/src/app/api/files/[...filePath]/route.ts @@ -0,0 +1,116 @@ +/* eslint-disable */ +import { NextRequest, NextResponse } from 'next/server'; +import fs from 'fs'; +import path from 'path'; +import { getDatasetsRoot, getTrainingFolder } from '@/server/settings'; + +export async function GET(request: NextRequest, { params }: { params: { filePath: string } }) { + const { filePath } = await params; + try { + // Decode the path + const decodedFilePath = decodeURIComponent(filePath); + + // Get allowed directories + const datasetRoot = await getDatasetsRoot(); + const trainingRoot = await getTrainingFolder(); + const allowedDirs = [datasetRoot, trainingRoot]; + + // Security check: Ensure path is in allowed directory + const isAllowed = + allowedDirs.some(allowedDir => decodedFilePath.startsWith(allowedDir)) && !decodedFilePath.includes('..'); + + if (!isAllowed) { + console.warn(`Access denied: ${decodedFilePath} not in ${allowedDirs.join(', ')}`); + return new NextResponse('Access denied', { status: 403 }); + } + + // Check if file exists + if (!fs.existsSync(decodedFilePath)) { + console.warn(`File not found: ${decodedFilePath}`); + return new NextResponse('File not found', { status: 404 }); + } + + // Get file info + const stat = fs.statSync(decodedFilePath); + if (!stat.isFile()) { + return new NextResponse('Not a file', { status: 400 }); + } + + // Get filename for Content-Disposition + const filename = path.basename(decodedFilePath); + + // Determine content type + const ext = path.extname(decodedFilePath).toLowerCase(); + const contentTypeMap: { [key: string]: string } = { + '.jpg': 'image/jpeg', + '.jpeg': 'image/jpeg', + '.png': 'image/png', + '.gif': 'image/gif', + '.webp': 'image/webp', + '.svg': 'image/svg+xml', + '.bmp': 'image/bmp', + '.safetensors': 'application/octet-stream', + '.zip': 'application/zip', + // Videos + '.mp4': 'video/mp4', + '.avi': 'video/x-msvideo', + '.mov': 'video/quicktime', + '.mkv': 'video/x-matroska', + '.wmv': 'video/x-ms-wmv', + '.m4v': 'video/x-m4v', + '.flv': 'video/x-flv' + }; + + const contentType = contentTypeMap[ext] || 'application/octet-stream'; + + // Get range header for partial content support + const range = request.headers.get('range'); + + // Common headers for better download handling + const commonHeaders = { + 'Content-Type': contentType, + 'Accept-Ranges': 'bytes', + 'Cache-Control': 'public, max-age=86400', + 'Content-Disposition': `attachment; filename="${encodeURIComponent(filename)}"`, + 'X-Content-Type-Options': 'nosniff', + }; + + if (range) { + // Parse range header + const parts = range.replace(/bytes=/, '').split('-'); + const start = parseInt(parts[0], 10); + const end = parts[1] ? parseInt(parts[1], 10) : Math.min(start + 10 * 1024 * 1024, stat.size - 1); // 10MB chunks + const chunkSize = end - start + 1; + + const fileStream = fs.createReadStream(decodedFilePath, { + start, + end, + highWaterMark: 64 * 1024, // 64KB buffer + }); + + return new NextResponse(fileStream as any, { + status: 206, + headers: { + ...commonHeaders, + 'Content-Range': `bytes ${start}-${end}/${stat.size}`, + 'Content-Length': String(chunkSize), + }, + }); + } else { + // For full file download, read directly without streaming wrapper + const fileStream = fs.createReadStream(decodedFilePath, { + highWaterMark: 64 * 1024, // 64KB buffer + }); + + return new NextResponse(fileStream as any, { + headers: { + ...commonHeaders, + 'Content-Length': String(stat.size), + }, + }); + } + } catch (error) { + console.error('Error serving file:', error); + return new NextResponse('Internal Server Error', { status: 500 }); + } +} diff --git a/src/app/api/gpu/route.ts b/src/app/api/gpu/route.ts new file mode 100644 index 0000000000000000000000000000000000000000..8b11dbb0e6d8e8de0f191bb1e78bb8687376881a --- /dev/null +++ b/src/app/api/gpu/route.ts @@ -0,0 +1,121 @@ +import { NextResponse } from 'next/server'; +import { exec } from 'child_process'; +import { promisify } from 'util'; +import os from 'os'; + +const execAsync = promisify(exec); + +export async function GET() { + try { + // Get platform + const platform = os.platform(); + const isWindows = platform === 'win32'; + + // Check if nvidia-smi is available + const hasNvidiaSmi = await checkNvidiaSmi(isWindows); + + if (!hasNvidiaSmi) { + return NextResponse.json({ + hasNvidiaSmi: false, + gpus: [], + error: 'nvidia-smi not found or not accessible', + }); + } + + // Get GPU stats + const gpuStats = await getGpuStats(isWindows); + + return NextResponse.json({ + hasNvidiaSmi: true, + gpus: gpuStats, + }); + } catch (error) { + console.error('Error fetching NVIDIA GPU stats:', error); + return NextResponse.json( + { + hasNvidiaSmi: false, + gpus: [], + error: `Failed to fetch GPU stats: ${error instanceof Error ? error.message : String(error)}`, + }, + { status: 500 }, + ); + } +} + +async function checkNvidiaSmi(isWindows: boolean): Promise { + try { + if (isWindows) { + // Check if nvidia-smi is available on Windows + // It's typically located in C:\Program Files\NVIDIA Corporation\NVSMI\nvidia-smi.exe + // but we'll just try to run it directly as it may be in PATH + await execAsync('nvidia-smi -L'); + } else { + // Linux/macOS check + await execAsync('which nvidia-smi'); + } + return true; + } catch (error) { + return false; + } +} + +async function getGpuStats(isWindows: boolean) { + // Command is the same for both platforms, but the path might be different + const command = + 'nvidia-smi --query-gpu=index,name,driver_version,temperature.gpu,utilization.gpu,utilization.memory,memory.total,memory.free,memory.used,power.draw,power.limit,clocks.current.graphics,clocks.current.memory,fan.speed --format=csv,noheader,nounits'; + + // Execute command + const { stdout } = await execAsync(command); + + // Parse CSV output + const gpus = stdout + .trim() + .split('\n') + .map(line => { + const [ + index, + name, + driverVersion, + temperature, + gpuUtil, + memoryUtil, + memoryTotal, + memoryFree, + memoryUsed, + powerDraw, + powerLimit, + clockGraphics, + clockMemory, + fanSpeed, + ] = line.split(', ').map(item => item.trim()); + + return { + index: parseInt(index), + name, + driverVersion, + temperature: parseInt(temperature), + utilization: { + gpu: parseInt(gpuUtil), + memory: parseInt(memoryUtil), + }, + memory: { + total: parseInt(memoryTotal), + free: parseInt(memoryFree), + used: parseInt(memoryUsed), + }, + power: { + draw: parseFloat(powerDraw), + limit: parseFloat(powerLimit), + }, + clocks: { + graphics: parseInt(clockGraphics), + memory: parseInt(clockMemory), + }, + fan: { + speed: parseInt(fanSpeed) || 0, // Some GPUs might not report fan speed, default to 0 + }, + }; + }); + + return gpus; +} diff --git a/src/app/api/hf-hub/route.ts b/src/app/api/hf-hub/route.ts new file mode 100644 index 0000000000000000000000000000000000000000..afdfb64c599b6fbad3c832d7450176ba3ca2b2c0 --- /dev/null +++ b/src/app/api/hf-hub/route.ts @@ -0,0 +1,165 @@ +import { NextRequest, NextResponse } from 'next/server'; +import { whoAmI, createRepo, uploadFiles, datasetInfo } from '@huggingface/hub'; +import { readdir, stat } from 'fs/promises'; +import path from 'path'; + +export async function POST(request: NextRequest) { + try { + const body = await request.json(); + const { action, token, namespace, datasetName, datasetPath, datasetId } = body; + + if (!token) { + return NextResponse.json({ error: 'HF token is required' }, { status: 400 }); + } + + switch (action) { + case 'whoami': + try { + const user = await whoAmI({ accessToken: token }); + return NextResponse.json({ user }); + } catch (error) { + return NextResponse.json({ error: 'Invalid token or network error' }, { status: 401 }); + } + + case 'createDataset': + try { + if (!namespace || !datasetName) { + return NextResponse.json({ error: 'Namespace and dataset name required' }, { status: 400 }); + } + + const repoId = `datasets/${namespace}/${datasetName}`; + + // Create repository + await createRepo({ + repo: repoId, + accessToken: token, + private: false, + }); + + return NextResponse.json({ success: true, repoId }); + } catch (error: any) { + if (error.message?.includes('already exists')) { + return NextResponse.json({ success: true, repoId: `${namespace}/${datasetName}`, exists: true }); + } + return NextResponse.json({ error: error.message || 'Failed to create dataset' }, { status: 500 }); + } + + case 'uploadDataset': + try { + if (!namespace || !datasetName || !datasetPath) { + return NextResponse.json({ error: 'Missing required parameters' }, { status: 400 }); + } + + const repoId = `datasets/${namespace}/${datasetName}`; + + // Check if directory exists + try { + await stat(datasetPath); + } catch { + return NextResponse.json({ error: 'Dataset path does not exist' }, { status: 400 }); + } + + // Read files from directory and upload them + const files = await readdir(datasetPath); + const filesToUpload = []; + + for (const fileName of files) { + const filePath = path.join(datasetPath, fileName); + const fileStats = await stat(filePath); + + if (fileStats.isFile()) { + filesToUpload.push({ + path: fileName, + content: new URL(`file://${filePath}`) + }); + } + } + + if (filesToUpload.length > 0) { + await uploadFiles({ + repo: repoId, + accessToken: token, + files: filesToUpload, + }); + } + + return NextResponse.json({ success: true, repoId }); + } catch (error: any) { + console.error('Upload error:', error); + return NextResponse.json({ error: error.message || 'Failed to upload dataset' }, { status: 500 }); + } + + case 'listFiles': + try { + if (!datasetPath) { + return NextResponse.json({ error: 'Dataset path required' }, { status: 400 }); + } + + const files = await readdir(datasetPath, { withFileTypes: true }); + const imageExtensions = ['.jpg', '.jpeg', '.png', '.webp', '.bmp']; + + const imageFiles = files + .filter(file => file.isFile()) + .filter(file => imageExtensions.some(ext => file.name.toLowerCase().endsWith(ext))) + .map(file => ({ + name: file.name, + path: path.join(datasetPath, file.name), + })); + + const captionFiles = files + .filter(file => file.isFile()) + .filter(file => file.name.endsWith('.txt')) + .map(file => ({ + name: file.name, + path: path.join(datasetPath, file.name), + })); + + return NextResponse.json({ + images: imageFiles, + captions: captionFiles, + total: imageFiles.length + }); + } catch (error: any) { + return NextResponse.json({ error: error.message || 'Failed to list files' }, { status: 500 }); + } + + case 'validateDataset': + try { + if (!datasetId) { + return NextResponse.json({ error: 'Dataset ID required' }, { status: 400 }); + } + + // Try to get dataset info to validate it exists and is accessible + const dataset = await datasetInfo({ + name: datasetId, + accessToken: token, + }); + + return NextResponse.json({ + exists: true, + dataset: { + id: dataset.id, + author: dataset.author, + downloads: dataset.downloads, + likes: dataset.likes, + private: dataset.private, + } + }); + } catch (error: any) { + if (error.message?.includes('404') || error.message?.includes('not found')) { + return NextResponse.json({ exists: false }, { status: 200 }); + } + if (error.message?.includes('401') || error.message?.includes('403')) { + return NextResponse.json({ error: 'Dataset not accessible with current token' }, { status: 403 }); + } + return NextResponse.json({ error: error.message || 'Failed to validate dataset' }, { status: 500 }); + } + + default: + return NextResponse.json({ error: 'Invalid action' }, { status: 400 }); + } + } catch (error: any) { + console.error('HF Hub API error:', error); + return NextResponse.json({ error: error.message || 'Internal server error' }, { status: 500 }); + } +} \ No newline at end of file diff --git a/src/app/api/hf-jobs/route.ts b/src/app/api/hf-jobs/route.ts new file mode 100644 index 0000000000000000000000000000000000000000..12fe64374cc3552584f8f9fdbe2948fa47996b62 --- /dev/null +++ b/src/app/api/hf-jobs/route.ts @@ -0,0 +1,761 @@ +import { NextRequest, NextResponse } from 'next/server'; +import { spawn } from 'child_process'; +import { writeFile } from 'fs/promises'; +import path from 'path'; +import { tmpdir } from 'os'; + +export async function POST(request: NextRequest) { + try { + const body = await request.json(); + const { action, token, hardware, namespace, jobConfig, datasetRepo } = body; + + switch (action) { + case 'checkStatus': + try { + if (!token || !jobConfig?.hf_job_id) { + return NextResponse.json({ error: 'Token and job ID required' }, { status: 400 }); + } + + const jobStatus = await checkHFJobStatus(token, jobConfig.hf_job_id); + return NextResponse.json({ status: jobStatus }); + } catch (error: any) { + console.error('Job status check error:', error); + return NextResponse.json({ error: error.message }, { status: 500 }); + } + + case 'generateScript': + try { + const uvScript = generateUVScript({ + jobConfig, + datasetRepo, + namespace, + token: token || 'YOUR_HF_TOKEN', + }); + + return NextResponse.json({ + script: uvScript, + filename: `train_${jobConfig.config.name.replace(/[^a-zA-Z0-9]/g, '_')}.py` + }); + } catch (error: any) { + return NextResponse.json({ error: error.message }, { status: 500 }); + } + + case 'submitJob': + try { + if (!token || !hardware) { + return NextResponse.json({ error: 'Token and hardware required' }, { status: 400 }); + } + + // Generate UV script + const uvScript = generateUVScript({ + jobConfig, + datasetRepo, + namespace, + token, + }); + + // Write script to temporary file + const scriptPath = path.join(tmpdir(), `train_${Date.now()}.py`); + await writeFile(scriptPath, uvScript); + + // Submit HF job using uv run + const jobId = await submitHFJobUV(token, hardware, scriptPath); + + return NextResponse.json({ + success: true, + jobId, + message: `Job submitted successfully with ID: ${jobId}` + }); + } catch (error: any) { + console.error('Job submission error:', error); + return NextResponse.json({ error: error.message }, { status: 500 }); + } + + default: + return NextResponse.json({ error: 'Invalid action' }, { status: 400 }); + } + } catch (error: any) { + console.error('HF Jobs API error:', error); + return NextResponse.json({ error: error.message }, { status: 500 }); + } +} + +function generateUVScript({ jobConfig, datasetRepo, namespace, token }: { + jobConfig: any; + datasetRepo: string; + namespace: string; + token: string; +}) { + const config = jobConfig.config; + const process = config.process[0]; + + return `# /// script +# dependencies = [ +# "torch>=2.0.0", +# "torchvision", +# "torchao==0.10.0", +# "safetensors", +# "diffusers @ git+https://github.com/huggingface/diffusers@7a2b78bf0f788d311cc96b61e660a8e13e3b1e63", +# "transformers==4.52.4", +# "lycoris-lora==1.8.3", +# "flatten_json", +# "pyyaml", +# "oyaml", +# "tensorboard", +# "kornia", +# "invisible-watermark", +# "einops", +# "accelerate", +# "toml", +# "albumentations==1.4.15", +# "albucore==0.0.16", +# "pydantic", +# "omegaconf", +# "k-diffusion", +# "open_clip_torch", +# "timm", +# "prodigyopt", +# "controlnet_aux==0.0.10", +# "python-dotenv", +# "bitsandbytes", +# "hf_transfer", +# "lpips", +# "pytorch_fid", +# "optimum-quanto==0.2.4", +# "sentencepiece", +# "huggingface_hub", +# "peft", +# "python-slugify", +# "opencv-python-headless", +# "pytorch-wavelets==1.3.0", +# "matplotlib==3.10.1", +# "setuptools==69.5.1", +# "datasets==4.0.0", +# "pyarrow==20.0.0", +# "pillow", +# "ftfy", +# ] +# /// + +import os +import sys +import subprocess +import argparse +import oyaml as yaml +from datasets import load_dataset +from huggingface_hub import HfApi, create_repo, upload_folder, snapshot_download +import tempfile +import shutil +import glob +from PIL import Image + +def setup_ai_toolkit(): + """Clone and setup ai-toolkit repository""" + repo_dir = "ai-toolkit" + if not os.path.exists(repo_dir): + print("Cloning ai-toolkit repository...") + subprocess.run( + ["git", "clone", "https://github.com/ostris/ai-toolkit.git", repo_dir], + check=True + ) + sys.path.insert(0, os.path.abspath(repo_dir)) + return repo_dir + +def download_dataset(dataset_repo: str, local_path: str): + """Download dataset from HF Hub as files""" + print(f"Downloading dataset from {dataset_repo}...") + + # Create local dataset directory + os.makedirs(local_path, exist_ok=True) + + # Use snapshot_download to get the dataset files directly + from huggingface_hub import snapshot_download + + try: + # First try to download as a structured dataset + dataset = load_dataset(dataset_repo, split="train") + + # Download images and captions from structured dataset + for i, item in enumerate(dataset): + # Save image + if "image" in item: + image_path = os.path.join(local_path, f"image_{i:06d}.jpg") + image = item["image"] + + # Convert RGBA to RGB if necessary (for JPEG compatibility) + if image.mode == 'RGBA': + # Create a white background and paste the RGBA image on it + background = Image.new('RGB', image.size, (255, 255, 255)) + background.paste(image, mask=image.split()[-1]) # Use alpha channel as mask + image = background + elif image.mode not in ['RGB', 'L']: + # Convert any other mode to RGB + image = image.convert('RGB') + + image.save(image_path, 'JPEG') + + # Save caption + if "text" in item: + caption_path = os.path.join(local_path, f"image_{i:06d}.txt") + with open(caption_path, "w", encoding="utf-8") as f: + f.write(item["text"]) + + print(f"Downloaded {len(dataset)} items to {local_path}") + + except Exception as e: + print(f"Failed to load as structured dataset: {e}") + print("Attempting to download raw files...") + + # Download the dataset repository as files + temp_repo_path = snapshot_download(repo_id=dataset_repo, repo_type="dataset") + + # Copy all image and text files to the local path + import glob + import shutil + + print(f"Downloaded repo to: {temp_repo_path}") + print(f"Contents: {os.listdir(temp_repo_path)}") + + # Find all image files + image_extensions = ['*.jpg', '*.jpeg', '*.png', '*.webp', '*.bmp', '*.JPG', '*.JPEG', '*.PNG'] + image_files = [] + for ext in image_extensions: + pattern = os.path.join(temp_repo_path, "**", ext) + found_files = glob.glob(pattern, recursive=True) + image_files.extend(found_files) + print(f"Pattern {pattern} found {len(found_files)} files") + + # Find all text files + text_files = glob.glob(os.path.join(temp_repo_path, "**", "*.txt"), recursive=True) + + print(f"Found {len(image_files)} image files and {len(text_files)} text files") + + # Copy image files + for i, img_file in enumerate(image_files): + dest_path = os.path.join(local_path, f"image_{i:06d}.jpg") + + # Load and convert image if needed + try: + with Image.open(img_file) as image: + if image.mode == 'RGBA': + background = Image.new('RGB', image.size, (255, 255, 255)) + background.paste(image, mask=image.split()[-1]) + image = background + elif image.mode not in ['RGB', 'L']: + image = image.convert('RGB') + + image.save(dest_path, 'JPEG') + except Exception as img_error: + print(f"Error processing image {img_file}: {img_error}") + continue + + # Copy text files (captions) + for i, txt_file in enumerate(text_files[:len(image_files)]): # Match number of images + dest_path = os.path.join(local_path, f"image_{i:06d}.txt") + try: + shutil.copy2(txt_file, dest_path) + except Exception as txt_error: + print(f"Error copying text file {txt_file}: {txt_error}") + continue + + print(f"Downloaded {len(image_files)} images and {len(text_files)} captions to {local_path}") + +def create_config(dataset_path: str, output_path: str): + """Create training configuration""" + import json + + # Load config from JSON string and fix boolean/null values for Python + config_str = """${JSON.stringify(jobConfig, null, 2)}""" + config_str = config_str.replace('true', 'True').replace('false', 'False').replace('null', 'None') + config = eval(config_str) + + # Update paths for cloud environment + config["config"]["process"][0]["datasets"][0]["folder_path"] = dataset_path + config["config"]["process"][0]["training_folder"] = output_path + + # Remove sqlite_db_path as it's not needed for cloud training + if "sqlite_db_path" in config["config"]["process"][0]: + del config["config"]["process"][0]["sqlite_db_path"] + + # Also change trainer type from ui_trainer to standard trainer to avoid UI dependencies + if config["config"]["process"][0]["type"] == "ui_trainer": + config["config"]["process"][0]["type"] = "sd_trainer" + + return config + +def upload_results(output_path: str, model_name: str, namespace: str, token: str, config: dict): + """Upload trained model to HF Hub with README generation and proper file organization""" + import tempfile + import shutil + import glob + import re + import yaml + from datetime import datetime + from huggingface_hub import create_repo, upload_file, HfApi + + try: + repo_id = f"{namespace}/{model_name}" + + # Create repository + create_repo(repo_id=repo_id, token=token, exist_ok=True) + + print(f"Uploading model to {repo_id}...") + + # Create temporary directory for organized upload + with tempfile.TemporaryDirectory() as temp_upload_dir: + api = HfApi() + + # 1. Find and upload model files to root directory + safetensors_files = glob.glob(os.path.join(output_path, "**", "*.safetensors"), recursive=True) + json_files = glob.glob(os.path.join(output_path, "**", "*.json"), recursive=True) + txt_files = glob.glob(os.path.join(output_path, "**", "*.txt"), recursive=True) + + uploaded_files = [] + + # Upload .safetensors files to root + for file_path in safetensors_files: + filename = os.path.basename(file_path) + print(f"Uploading {filename} to repository root...") + api.upload_file( + path_or_fileobj=file_path, + path_in_repo=filename, + repo_id=repo_id, + token=token + ) + uploaded_files.append(filename) + + # Upload relevant JSON config files to root (skip metadata.json and other internal files) + config_files_uploaded = [] + for file_path in json_files: + filename = os.path.basename(file_path) + # Only upload important config files, skip internal metadata + if any(keyword in filename.lower() for keyword in ['config', 'adapter', 'lora', 'model']): + print(f"Uploading {filename} to repository root...") + api.upload_file( + path_or_fileobj=file_path, + path_in_repo=filename, + repo_id=repo_id, + token=token + ) + uploaded_files.append(filename) + config_files_uploaded.append(filename) + + # 2. Handle sample images + samples_uploaded = [] + samples_dir = os.path.join(output_path, "samples") + if os.path.isdir(samples_dir): + print("Uploading sample images...") + # Create samples directory in repo + for filename in os.listdir(samples_dir): + if filename.lower().endswith(('.jpg', '.jpeg', '.png', '.webp')): + file_path = os.path.join(samples_dir, filename) + repo_path = f"samples/{filename}" + api.upload_file( + path_or_fileobj=file_path, + path_in_repo=repo_path, + repo_id=repo_id, + token=token + ) + samples_uploaded.append(repo_path) + + # 3. Generate and upload README.md + readme_content = generate_model_card_readme( + repo_id=repo_id, + config=config, + model_name=model_name, + samples_dir=samples_dir if os.path.isdir(samples_dir) else None, + uploaded_files=uploaded_files + ) + + # Create README.md file and upload to root + readme_path = os.path.join(temp_upload_dir, "README.md") + with open(readme_path, "w", encoding="utf-8") as f: + f.write(readme_content) + + print("Uploading README.md to repository root...") + api.upload_file( + path_or_fileobj=readme_path, + path_in_repo="README.md", + repo_id=repo_id, + token=token + ) + + print(f"Model uploaded successfully to https://huggingface.co/{repo_id}") + print(f"Files uploaded: {len(uploaded_files)} model files, {len(samples_uploaded)} samples, README.md") + + except Exception as e: + print(f"Failed to upload model: {e}") + raise e + +def generate_model_card_readme(repo_id: str, config: dict, model_name: str, samples_dir: str = None, uploaded_files: list = None) -> str: + """Generate README.md content for the model card based on AI Toolkit's implementation""" + import re + import yaml + import os + + try: + # Extract configuration details + process_config = config.get("config", {}).get("process", [{}])[0] + model_config = process_config.get("model", {}) + train_config = process_config.get("train", {}) + sample_config = process_config.get("sample", {}) + + # Gather model info + base_model = model_config.get("name_or_path", "unknown") + trigger_word = process_config.get("trigger_word") + arch = model_config.get("arch", "") + + # Determine license based on base model + if "FLUX.1-schnell" in base_model: + license_info = {"license": "apache-2.0"} + elif "FLUX.1-dev" in base_model: + license_info = { + "license": "other", + "license_name": "flux-1-dev-non-commercial-license", + "license_link": "https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md" + } + else: + license_info = {"license": "creativeml-openrail-m"} + + # Generate tags based on model architecture + tags = ["text-to-image"] + + if "xl" in arch.lower(): + tags.append("stable-diffusion-xl") + if "flux" in arch.lower(): + tags.append("flux") + if "lumina" in arch.lower(): + tags.append("lumina2") + if "sd3" in arch.lower() or "v3" in arch.lower(): + tags.append("sd3") + + # Add LoRA-specific tags + tags.extend(["lora", "diffusers", "template:sd-lora", "ai-toolkit"]) + + # Generate widgets from sample images and prompts + widgets = [] + if samples_dir and os.path.isdir(samples_dir): + sample_prompts = sample_config.get("samples", []) + if not sample_prompts: + # Fallback to old format + sample_prompts = [{"prompt": p} for p in sample_config.get("prompts", [])] + + # Get sample image files + sample_files = [] + if os.path.isdir(samples_dir): + for filename in os.listdir(samples_dir): + if filename.lower().endswith(('.jpg', '.jpeg', '.png', '.webp')): + # Parse filename pattern: timestamp__steps_index.jpg + match = re.search(r"__(\d+)_(\d+)\.jpg$", filename) + if match: + steps, index = int(match.group(1)), int(match.group(2)) + # Only use samples from final training step + final_steps = train_config.get("steps", 1000) + if steps == final_steps: + sample_files.append((index, f"samples/{filename}")) + + # Sort by index and create widgets + sample_files.sort(key=lambda x: x[0]) + + for i, prompt_obj in enumerate(sample_prompts): + prompt = prompt_obj.get("prompt", "") if isinstance(prompt_obj, dict) else str(prompt_obj) + if i < len(sample_files): + _, image_path = sample_files[i] + widgets.append({ + "text": prompt, + "output": {"url": image_path} + }) + + # Determine torch dtype based on model + dtype = "torch.bfloat16" if "flux" in arch.lower() else "torch.float16" + + # Find the main safetensors file for usage example + main_safetensors = f"{model_name}.safetensors" + if uploaded_files: + safetensors_files = [f for f in uploaded_files if f.endswith('.safetensors')] + if safetensors_files: + main_safetensors = safetensors_files[0] + + # Construct YAML frontmatter + frontmatter = { + "tags": tags, + "base_model": base_model, + **license_info + } + + if widgets: + frontmatter["widget"] = widgets + + if trigger_word: + frontmatter["instance_prompt"] = trigger_word + + # Get first prompt for usage example + usage_prompt = trigger_word or "a beautiful landscape" + if widgets: + usage_prompt = widgets[0]["text"] + elif trigger_word: + usage_prompt = trigger_word + + # Construct README content + trigger_section = f"You should use \`{trigger_word}\` to trigger the image generation." if trigger_word else "No trigger words defined." + + # Build YAML frontmatter string + frontmatter_yaml = yaml.dump(frontmatter, default_flow_style=False, allow_unicode=True, sort_keys=False).strip() + + readme_content = f"""--- +{frontmatter_yaml} +--- + +# {model_name} + +Model trained with [AI Toolkit by Ostris](https://github.com/ostris/ai-toolkit) + + + +## Trigger words + +{trigger_section} + +## Download model and use it with ComfyUI, AUTOMATIC1111, SD.Next, Invoke AI, etc. + +Weights for this model are available in Safetensors format. + +[Download]({repo_id}/tree/main) them in the Files & versions tab. + +## Use it with the [🧨 diffusers library](https://github.com/huggingface/diffusers) + +\`\`\`py +from diffusers import AutoPipelineForText2Image +import torch + +pipeline = AutoPipelineForText2Image.from_pretrained('{base_model}', torch_dtype={dtype}).to('cuda') +pipeline.load_lora_weights('{repo_id}', weight_name='{main_safetensors}') +image = pipeline('{usage_prompt}').images[0] +image.save("my_image.png") +\`\`\` + +For more details, including weighting, merging and fusing LoRAs, check the [documentation on loading LoRAs in diffusers](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters) + +""" + return readme_content + + except Exception as e: + print(f"Error generating README: {e}") + # Fallback simple README + return f"""# {model_name} + +Model trained with [AI Toolkit by Ostris](https://github.com/ostris/ai-toolkit) + +## Download model + +Weights for this model are available in Safetensors format. + +[Download]({repo_id}/tree/main) them in the Files & versions tab. +""" + +def main(): + # Setup environment - token comes from HF Jobs secrets + if "HF_TOKEN" not in os.environ: + raise ValueError("HF_TOKEN environment variable not set") + + # Install system dependencies for headless operation + print("Installing system dependencies...") + try: + subprocess.run(["apt-get", "update"], check=True, capture_output=True) + subprocess.run([ + "apt-get", "install", "-y", + "libgl1-mesa-glx", + "libglib2.0-0", + "libsm6", + "libxext6", + "libxrender-dev", + "libgomp1", + "ffmpeg" + ], check=True, capture_output=True) + print("System dependencies installed successfully") + except subprocess.CalledProcessError as e: + print(f"Failed to install system dependencies: {e}") + print("Continuing without system dependencies...") + + # Setup ai-toolkit + toolkit_dir = setup_ai_toolkit() + + # Create temporary directories + with tempfile.TemporaryDirectory() as temp_dir: + dataset_path = os.path.join(temp_dir, "dataset") + output_path = os.path.join(temp_dir, "output") + + # Download dataset + download_dataset("${datasetRepo}", dataset_path) + + # Create config + config = create_config(dataset_path, output_path) + config_path = os.path.join(temp_dir, "config.yaml") + + with open(config_path, "w") as f: + yaml.dump(config, f, default_flow_style=False) + + # Run training + print("Starting training...") + os.chdir(toolkit_dir) + + subprocess.run([ + sys.executable, "run.py", + config_path + ], check=True) + + print("Training completed!") + + # Upload results + model_name = f"${jobConfig.config.name}-lora" + upload_results(output_path, model_name, "${namespace}", os.environ["HF_TOKEN"], config) + +if __name__ == "__main__": + main() +`; +} + +async function submitHFJobUV(token: string, hardware: string, scriptPath: string): Promise { + return new Promise((resolve, reject) => { + // Ensure token is available + if (!token) { + reject(new Error('HF_TOKEN is required')); + return; + } + + console.log('Setting up environment with HF_TOKEN for job submission'); + console.log(`Command: hf jobs uv run --flavor ${hardware} --timeout 5h --secrets HF_TOKEN --detach ${scriptPath}`); + + // Use hf jobs uv run command with timeout and detach to get job ID + const childProcess = spawn('hf', [ + 'jobs', 'uv', 'run', + '--flavor', hardware, + '--timeout', '5h', + '--secrets', 'HF_TOKEN', + '--detach', + scriptPath + ], { + env: { + ...process.env, + HF_TOKEN: token + } + }); + + let output = ''; + let error = ''; + + childProcess.stdout.on('data', (data) => { + const text = data.toString(); + output += text; + console.log('HF Jobs stdout:', text); + }); + + childProcess.stderr.on('data', (data) => { + const text = data.toString(); + error += text; + console.log('HF Jobs stderr:', text); + }); + + childProcess.on('close', (code) => { + console.log('HF Jobs process closed with code:', code); + console.log('Full output:', output); + console.log('Full error:', error); + + if (code === 0) { + // With --detach flag, the output should be just the job ID + const fullText = (output + ' ' + error).trim(); + + // Updated patterns to handle variable-length hex job IDs (16-24+ characters) + const jobIdPatterns = [ + /Job started with ID:\s*([a-f0-9]{16,})/i, // "Job started with ID: 68b26b73767540db9fc726ac" + /job\s+([a-f0-9]{16,})/i, // "job 68b26b73767540db9fc726ac" + /Job ID:\s*([a-f0-9]{16,})/i, // "Job ID: 68b26b73767540db9fc726ac" + /created\s+job\s+([a-f0-9]{16,})/i, // "created job 68b26b73767540db9fc726ac" + /submitted.*?job\s+([a-f0-9]{16,})/i, // "submitted ... job 68b26b73767540db9fc726ac" + /https:\/\/huggingface\.co\/jobs\/[^\/]+\/([a-f0-9]{16,})/i, // URL pattern + /([a-f0-9]{20,})/i, // Fallback: any 20+ char hex string + ]; + + let jobId = 'unknown'; + + for (const pattern of jobIdPatterns) { + const match = fullText.match(pattern); + if (match && match[1] && match[1] !== 'started') { + jobId = match[1]; + console.log(`Extracted job ID using pattern: ${pattern.toString()} -> ${jobId}`); + break; + } + } + + resolve(jobId); + } else { + reject(new Error(error || output || 'Failed to submit job')); + } + }); + + childProcess.on('error', (err) => { + console.error('HF Jobs process error:', err); + reject(new Error(`Process error: ${err.message}`)); + }); + }); +} + +async function checkHFJobStatus(token: string, jobId: string): Promise { + return new Promise((resolve, reject) => { + console.log(`Checking HF Job status for: ${jobId}`); + + const childProcess = spawn('hf', [ + 'jobs', 'inspect', jobId + ], { + env: { + ...process.env, + HF_TOKEN: token + } + }); + + let output = ''; + let error = ''; + + childProcess.stdout.on('data', (data) => { + const text = data.toString(); + output += text; + }); + + childProcess.stderr.on('data', (data) => { + const text = data.toString(); + error += text; + }); + + childProcess.on('close', (code) => { + if (code === 0) { + try { + // Parse the JSON output from hf jobs inspect + const jobInfo = JSON.parse(output); + if (Array.isArray(jobInfo) && jobInfo.length > 0) { + const job = jobInfo[0]; + resolve({ + id: job.id, + status: job.status?.stage || 'UNKNOWN', + message: job.status?.message, + created_at: job.created_at, + flavor: job.flavor, + url: job.url, + }); + } else { + reject(new Error('Invalid job info response')); + } + } catch (parseError: any) { + console.error('Failed to parse job status:', parseError, output); + reject(new Error('Failed to parse job status')); + } + } else { + reject(new Error(error || output || 'Failed to check job status')); + } + }); + + childProcess.on('error', (err) => { + console.error('HF Jobs inspect process error:', err); + reject(new Error(`Process error: ${err.message}`)); + }); + }); +} \ No newline at end of file diff --git a/src/app/api/img/[...imagePath]/route.ts b/src/app/api/img/[...imagePath]/route.ts new file mode 100644 index 0000000000000000000000000000000000000000..80fc727216dd6a64e402385078725443234e636a --- /dev/null +++ b/src/app/api/img/[...imagePath]/route.ts @@ -0,0 +1,78 @@ +/* eslint-disable */ +import { NextRequest, NextResponse } from 'next/server'; +import fs from 'fs'; +import path from 'path'; +import { getDatasetsRoot, getTrainingFolder, getDataRoot } from '@/server/settings'; + +export async function GET(request: NextRequest, { params }: { params: { imagePath: string } }) { + const { imagePath } = await params; + try { + // Decode the path + const filepath = decodeURIComponent(imagePath); + + // Get allowed directories + const datasetRoot = await getDatasetsRoot(); + const trainingRoot = await getTrainingFolder(); + const dataRoot = await getDataRoot(); + + const allowedDirs = [datasetRoot, trainingRoot, dataRoot]; + + // Security check: Ensure path is in allowed directory + const isAllowed = allowedDirs.some(allowedDir => filepath.startsWith(allowedDir)) && !filepath.includes('..'); + + if (!isAllowed) { + console.warn(`Access denied: ${filepath} not in ${allowedDirs.join(', ')}`); + return new NextResponse('Access denied', { status: 403 }); + } + + // Check if file exists + if (!fs.existsSync(filepath)) { + console.warn(`File not found: ${filepath}`); + return new NextResponse('File not found', { status: 404 }); + } + + // Get file info + const stat = fs.statSync(filepath); + if (!stat.isFile()) { + return new NextResponse('Not a file', { status: 400 }); + } + + // Determine content type + const ext = path.extname(filepath).toLowerCase(); + const contentTypeMap: { [key: string]: string } = { + // Images + '.jpg': 'image/jpeg', + '.jpeg': 'image/jpeg', + '.png': 'image/png', + '.gif': 'image/gif', + '.webp': 'image/webp', + '.svg': 'image/svg+xml', + '.bmp': 'image/bmp', + // Videos + '.mp4': 'video/mp4', + '.avi': 'video/x-msvideo', + '.mov': 'video/quicktime', + '.mkv': 'video/x-matroska', + '.wmv': 'video/x-ms-wmv', + '.m4v': 'video/x-m4v', + '.flv': 'video/x-flv' + }; + + const contentType = contentTypeMap[ext] || 'application/octet-stream'; + + // Read file as buffer + const fileBuffer = fs.readFileSync(filepath); + + // Return file with appropriate headers + return new NextResponse(fileBuffer, { + headers: { + 'Content-Type': contentType, + 'Content-Length': String(stat.size), + 'Cache-Control': 'public, max-age=86400', + }, + }); + } catch (error) { + console.error('Error serving image:', error); + return new NextResponse('Internal Server Error', { status: 500 }); + } +} diff --git a/src/app/api/img/caption/route.ts b/src/app/api/img/caption/route.ts new file mode 100644 index 0000000000000000000000000000000000000000..df4235f99986dedf253b45b802537b4b559b43ca --- /dev/null +++ b/src/app/api/img/caption/route.ts @@ -0,0 +1,29 @@ +import { NextResponse } from 'next/server'; +import fs from 'fs'; +import { getDatasetsRoot } from '@/server/settings'; + +export async function POST(request: Request) { + try { + const body = await request.json(); + const { imgPath, caption } = body; + let datasetsPath = await getDatasetsRoot(); + // make sure the dataset path is in the image path + if (!imgPath.startsWith(datasetsPath)) { + return NextResponse.json({ error: 'Invalid image path' }, { status: 400 }); + } + + // if img doesnt exist, ignore + if (!fs.existsSync(imgPath)) { + return NextResponse.json({ error: 'Image does not exist' }, { status: 404 }); + } + + // check for caption + const captionPath = imgPath.replace(/\.[^/.]+$/, '') + '.txt'; + // save caption to file + fs.writeFileSync(captionPath, caption); + + return NextResponse.json({ success: true }); + } catch (error) { + return NextResponse.json({ error: 'Failed to create dataset' }, { status: 500 }); + } +} diff --git a/src/app/api/img/delete/route.ts b/src/app/api/img/delete/route.ts new file mode 100644 index 0000000000000000000000000000000000000000..d4d968f8eab6f6b1d9c988c3fd86aee2d6c2fe4f --- /dev/null +++ b/src/app/api/img/delete/route.ts @@ -0,0 +1,34 @@ +import { NextResponse } from 'next/server'; +import fs from 'fs'; +import { getDatasetsRoot } from '@/server/settings'; + +export async function POST(request: Request) { + try { + const body = await request.json(); + const { imgPath } = body; + let datasetsPath = await getDatasetsRoot(); + // make sure the dataset path is in the image path + if (!imgPath.startsWith(datasetsPath)) { + return NextResponse.json({ error: 'Invalid image path' }, { status: 400 }); + } + + // if img doesnt exist, ignore + if (!fs.existsSync(imgPath)) { + return NextResponse.json({ success: true }); + } + + // delete it and return success + fs.unlinkSync(imgPath); + + // check for caption + const captionPath = imgPath.replace(/\.[^/.]+$/, '') + '.txt'; + if (fs.existsSync(captionPath)) { + // delete caption file + fs.unlinkSync(captionPath); + } + + return NextResponse.json({ success: true }); + } catch (error) { + return NextResponse.json({ error: 'Failed to create dataset' }, { status: 500 }); + } +} diff --git a/src/app/api/img/upload/route.ts b/src/app/api/img/upload/route.ts new file mode 100644 index 0000000000000000000000000000000000000000..56615bd06c4bfee9e7aef4b81a620d4c8c7cbcb7 --- /dev/null +++ b/src/app/api/img/upload/route.ts @@ -0,0 +1,58 @@ +// src/app/api/datasets/upload/route.ts +import { NextRequest, NextResponse } from 'next/server'; +import { writeFile, mkdir } from 'fs/promises'; +import { join } from 'path'; +import { getDataRoot } from '@/server/settings'; +import {v4 as uuidv4} from 'uuid'; + +export async function POST(request: NextRequest) { + try { + const dataRoot = await getDataRoot(); + if (!dataRoot) { + return NextResponse.json({ error: 'Data root path not found' }, { status: 500 }); + } + const imgRoot = join(dataRoot, 'images'); + + + const formData = await request.formData(); + const files = formData.getAll('files'); + + if (!files || files.length === 0) { + return NextResponse.json({ error: 'No files provided' }, { status: 400 }); + } + + // make it recursive if it doesn't exist + await mkdir(imgRoot, { recursive: true }); + const savedFiles = await Promise.all( + files.map(async (file: any) => { + const bytes = await file.arrayBuffer(); + const buffer = Buffer.from(bytes); + + const extension = file.name.split('.').pop() || 'jpg'; + + // Clean filename and ensure it's unique + const fileName = `${uuidv4()}`; // Use UUID for unique file names + const filePath = join(imgRoot, `${fileName}.${extension}`); + + await writeFile(filePath, buffer); + return filePath; + }), + ); + + return NextResponse.json({ + message: 'Files uploaded successfully', + files: savedFiles, + }); + } catch (error) { + console.error('Upload error:', error); + return NextResponse.json({ error: 'Error uploading files' }, { status: 500 }); + } +} + +// Increase payload size limit (default is 4mb) +export const config = { + api: { + bodyParser: false, + responseLimit: '50mb', + }, +}; diff --git a/src/app/api/jobs/[jobID]/delete/route.ts b/src/app/api/jobs/[jobID]/delete/route.ts new file mode 100644 index 0000000000000000000000000000000000000000..618e33f440301495c47141bff70b99b43438c4a3 --- /dev/null +++ b/src/app/api/jobs/[jobID]/delete/route.ts @@ -0,0 +1,32 @@ +import { NextRequest, NextResponse } from 'next/server'; +import { PrismaClient } from '@prisma/client'; +import { getTrainingFolder } from '@/server/settings'; +import path from 'path'; +import fs from 'fs'; + +const prisma = new PrismaClient(); + +export async function GET(request: NextRequest, { params }: { params: { jobID: string } }) { + const { jobID } = await params; + + const job = await prisma.job.findUnique({ + where: { id: jobID }, + }); + + if (!job) { + return NextResponse.json({ error: 'Job not found' }, { status: 404 }); + } + + const trainingRoot = await getTrainingFolder(); + const trainingFolder = path.join(trainingRoot, job.name); + + if (fs.existsSync(trainingFolder)) { + fs.rmdirSync(trainingFolder, { recursive: true }); + } + + await prisma.job.delete({ + where: { id: jobID }, + }); + + return NextResponse.json(job); +} diff --git a/src/app/api/jobs/[jobID]/files/route.ts b/src/app/api/jobs/[jobID]/files/route.ts new file mode 100644 index 0000000000000000000000000000000000000000..575df5e5a68cc8739aac16b55f2631d267b040fe --- /dev/null +++ b/src/app/api/jobs/[jobID]/files/route.ts @@ -0,0 +1,48 @@ +import { NextRequest, NextResponse } from 'next/server'; +import { PrismaClient } from '@prisma/client'; +import path from 'path'; +import fs from 'fs'; +import { getTrainingFolder } from '@/server/settings'; + +const prisma = new PrismaClient(); + +export async function GET(request: NextRequest, { params }: { params: { jobID: string } }) { + const { jobID } = await params; + + const job = await prisma.job.findUnique({ + where: { id: jobID }, + }); + + if (!job) { + return NextResponse.json({ error: 'Job not found' }, { status: 404 }); + } + + const trainingFolder = await getTrainingFolder(); + const jobFolder = path.join(trainingFolder, job.name); + + if (!fs.existsSync(jobFolder)) { + return NextResponse.json({ files: [] }); + } + + // find all safetensors files in the job folder + let files = fs + .readdirSync(jobFolder) + .filter(file => { + return file.endsWith('.safetensors'); + }) + .map(file => { + return path.join(jobFolder, file); + }) + .sort(); + + // get the file size for each file + const fileObjects = files.map(file => { + const stats = fs.statSync(file); + return { + path: file, + size: stats.size, + }; + }); + + return NextResponse.json({ files: fileObjects }); +} diff --git a/src/app/api/jobs/[jobID]/log/route.ts b/src/app/api/jobs/[jobID]/log/route.ts new file mode 100644 index 0000000000000000000000000000000000000000..10ccbdaac76b76ec20cead8e7f634af0d723ad8f --- /dev/null +++ b/src/app/api/jobs/[jobID]/log/route.ts @@ -0,0 +1,35 @@ +import { NextRequest, NextResponse } from 'next/server'; +import { PrismaClient } from '@prisma/client'; +import path from 'path'; +import fs from 'fs'; +import { getTrainingFolder } from '@/server/settings'; + +const prisma = new PrismaClient(); + +export async function GET(request: NextRequest, { params }: { params: { jobID: string } }) { + const { jobID } = await params; + + const job = await prisma.job.findUnique({ + where: { id: jobID }, + }); + + if (!job) { + return NextResponse.json({ error: 'Job not found' }, { status: 404 }); + } + + const trainingFolder = await getTrainingFolder(); + const jobFolder = path.join(trainingFolder, job.name); + const logPath = path.join(jobFolder, 'log.txt'); + + if (!fs.existsSync(logPath)) { + return NextResponse.json({ log: '' }); + } + let log = ''; + try { + log = fs.readFileSync(logPath, 'utf-8'); + } catch (error) { + console.error('Error reading log file:', error); + log = 'Error reading log file'; + } + return NextResponse.json({ log: log }); +} diff --git a/src/app/api/jobs/[jobID]/samples/route.ts b/src/app/api/jobs/[jobID]/samples/route.ts new file mode 100644 index 0000000000000000000000000000000000000000..2a98a6eac1a7581243aa7adfec6da5d5a40c938c --- /dev/null +++ b/src/app/api/jobs/[jobID]/samples/route.ts @@ -0,0 +1,40 @@ +import { NextRequest, NextResponse } from 'next/server'; +import { PrismaClient } from '@prisma/client'; +import path from 'path'; +import fs from 'fs'; +import { getTrainingFolder } from '@/server/settings'; + +const prisma = new PrismaClient(); + +export async function GET(request: NextRequest, { params }: { params: { jobID: string } }) { + const { jobID } = await params; + + const job = await prisma.job.findUnique({ + where: { id: jobID }, + }); + + if (!job) { + return NextResponse.json({ error: 'Job not found' }, { status: 404 }); + } + + // setup the training + const trainingFolder = await getTrainingFolder(); + + const samplesFolder = path.join(trainingFolder, job.name, 'samples'); + if (!fs.existsSync(samplesFolder)) { + return NextResponse.json({ samples: [] }); + } + + // find all img (png, jpg, jpeg) files in the samples folder + const samples = fs + .readdirSync(samplesFolder) + .filter(file => { + return file.endsWith('.png') || file.endsWith('.jpg') || file.endsWith('.jpeg') || file.endsWith('.webp'); + }) + .map(file => { + return path.join(samplesFolder, file); + }) + .sort(); + + return NextResponse.json({ samples }); +} diff --git a/src/app/api/jobs/[jobID]/start/route.ts b/src/app/api/jobs/[jobID]/start/route.ts new file mode 100644 index 0000000000000000000000000000000000000000..e26c1e499373e1aa3821f2031472ec0e0727526f --- /dev/null +++ b/src/app/api/jobs/[jobID]/start/route.ts @@ -0,0 +1,215 @@ +import { NextRequest, NextResponse } from 'next/server'; +import { PrismaClient } from '@prisma/client'; +import { TOOLKIT_ROOT } from '@/paths'; +import { spawn } from 'child_process'; +import path from 'path'; +import fs from 'fs'; +import os from 'os'; +import { getTrainingFolder, getHFToken } from '@/server/settings'; +const isWindows = process.platform === 'win32'; + +const prisma = new PrismaClient(); + +export async function GET(request: NextRequest, { params }: { params: { jobID: string } }) { + const { jobID } = await params; + + const job = await prisma.job.findUnique({ + where: { id: jobID }, + }); + + if (!job) { + return NextResponse.json({ error: 'Job not found' }, { status: 404 }); + } + + // update job status to 'running' + await prisma.job.update({ + where: { id: jobID }, + data: { + status: 'running', + stop: false, + info: 'Starting job...', + }, + }); + + // setup the training + const trainingRoot = await getTrainingFolder(); + + const trainingFolder = path.join(trainingRoot, job.name); + if (!fs.existsSync(trainingFolder)) { + fs.mkdirSync(trainingFolder, { recursive: true }); + } + + // make the config file + const configPath = path.join(trainingFolder, '.job_config.json'); + + //log to path + const logPath = path.join(trainingFolder, 'log.txt'); + + try { + // if the log path exists, move it to a folder called logs and rename it {num}_log.txt, looking for the highest num + // if the log path does not exist, create it + if (fs.existsSync(logPath)) { + const logsFolder = path.join(trainingFolder, 'logs'); + if (!fs.existsSync(logsFolder)) { + fs.mkdirSync(logsFolder, { recursive: true }); + } + + let num = 0; + while (fs.existsSync(path.join(logsFolder, `${num}_log.txt`))) { + num++; + } + + fs.renameSync(logPath, path.join(logsFolder, `${num}_log.txt`)); + } + } catch (e) { + console.error('Error moving log file:', e); + } + + // update the config dataset path + const jobConfig = JSON.parse(job.job_config); + jobConfig.config.process[0].sqlite_db_path = path.join(TOOLKIT_ROOT, 'aitk_db.db'); + + // write the config file + fs.writeFileSync(configPath, JSON.stringify(jobConfig, null, 2)); + + let pythonPath = 'python'; + // use .venv or venv if it exists + if (fs.existsSync(path.join(TOOLKIT_ROOT, '.venv'))) { + if (isWindows) { + pythonPath = path.join(TOOLKIT_ROOT, '.venv', 'Scripts', 'python.exe'); + } else { + pythonPath = path.join(TOOLKIT_ROOT, '.venv', 'bin', 'python'); + } + } else if (fs.existsSync(path.join(TOOLKIT_ROOT, 'venv'))) { + if (isWindows) { + pythonPath = path.join(TOOLKIT_ROOT, 'venv', 'Scripts', 'python.exe'); + } else { + pythonPath = path.join(TOOLKIT_ROOT, 'venv', 'bin', 'python'); + } + } + + const runFilePath = path.join(TOOLKIT_ROOT, 'run.py'); + if (!fs.existsSync(runFilePath)) { + return NextResponse.json({ error: 'run.py not found' }, { status: 500 }); + } + + const additionalEnv: any = { + AITK_JOB_ID: jobID, + CUDA_VISIBLE_DEVICES: `${job.gpu_ids}`, + IS_AI_TOOLKIT_UI: '1' + }; + + // HF_TOKEN + const hfToken = await getHFToken(); + if (hfToken && hfToken.trim() !== '') { + additionalEnv.HF_TOKEN = hfToken; + } + + // Add the --log argument to the command + const args = [runFilePath, configPath, '--log', logPath]; + + try { + let subprocess; + + if (isWindows) { + // For Windows, use 'cmd.exe' to open a new command window + subprocess = spawn('cmd.exe', ['/c', 'start', 'cmd.exe', '/k', pythonPath, ...args], { + env: { + ...process.env, + ...additionalEnv, + }, + cwd: TOOLKIT_ROOT, + windowsHide: false, + }); + } else { + // For non-Windows platforms + subprocess = spawn(pythonPath, args, { + detached: true, + stdio: ['ignore', 'pipe', 'pipe'], // Changed from 'ignore' to capture output + env: { + ...process.env, + ...additionalEnv, + }, + cwd: TOOLKIT_ROOT, + }); + } + + // Start monitoring in the background without blocking the response + const monitorProcess = async () => { + const startTime = Date.now(); + let errorOutput = ''; + let stdoutput = ''; + + if (subprocess.stderr) { + subprocess.stderr.on('data', data => { + errorOutput += data.toString(); + }); + subprocess.stdout.on('data', data => { + stdoutput += data.toString(); + // truncate to only get the last 500 characters + if (stdoutput.length > 500) { + stdoutput = stdoutput.substring(stdoutput.length - 500); + } + }); + } + + subprocess.on('exit', async code => { + const currentTime = Date.now(); + const duration = (currentTime - startTime) / 1000; + console.log(`Job ${jobID} exited with code ${code} after ${duration} seconds.`); + // wait for 5 seconds to give it time to stop itself. It id still has a status of running in the db, update it to stopped + await new Promise(resolve => setTimeout(resolve, 5000)); + const updatedJob = await prisma.job.findUnique({ + where: { id: jobID }, + }); + if (updatedJob?.status === 'running') { + let errorString = errorOutput; + if (errorString.trim() === '') { + errorString = stdoutput; + } + await prisma.job.update({ + where: { id: jobID }, + data: { + status: 'error', + info: `Error launching job: ${errorString.substring(0, 500)}`, + }, + }); + } + }); + + // Wait 30 seconds before releasing the process + await new Promise(resolve => setTimeout(resolve, 30000)); + // Detach the process for non-Windows systems + if (!isWindows && subprocess.unref) { + subprocess.unref(); + } + }; + + // Start the monitoring without awaiting it + monitorProcess().catch(err => { + console.error(`Error in process monitoring for job ${jobID}:`, err); + }); + + // Return the response immediately + return NextResponse.json(job); + } catch (error: any) { + // Handle any exceptions during process launch + console.error('Error launching process:', error); + + await prisma.job.update({ + where: { id: jobID }, + data: { + status: 'error', + info: `Error launching job: ${error?.message || 'Unknown error'}`, + }, + }); + + return NextResponse.json( + { + error: 'Failed to launch job process', + details: error?.message || 'Unknown error', + }, + { status: 500 }, + ); + } +} diff --git a/src/app/api/jobs/[jobID]/stop/route.ts b/src/app/api/jobs/[jobID]/stop/route.ts new file mode 100644 index 0000000000000000000000000000000000000000..73b352dfc55664b1b689075727f7245589523005 --- /dev/null +++ b/src/app/api/jobs/[jobID]/stop/route.ts @@ -0,0 +1,23 @@ +import { NextRequest, NextResponse } from 'next/server'; +import { PrismaClient } from '@prisma/client'; + +const prisma = new PrismaClient(); + +export async function GET(request: NextRequest, { params }: { params: { jobID: string } }) { + const { jobID } = await params; + + const job = await prisma.job.findUnique({ + where: { id: jobID }, + }); + + // update job status to 'running' + await prisma.job.update({ + where: { id: jobID }, + data: { + stop: true, + info: 'Stopping job...', + }, + }); + + return NextResponse.json(job); +} diff --git a/src/app/api/jobs/route.ts b/src/app/api/jobs/route.ts new file mode 100644 index 0000000000000000000000000000000000000000..8f0419b924cfa6724371712b279e89c666437eb6 --- /dev/null +++ b/src/app/api/jobs/route.ts @@ -0,0 +1,67 @@ +import { NextResponse } from 'next/server'; +import { PrismaClient } from '@prisma/client'; + +const prisma = new PrismaClient(); + +export async function GET(request: Request) { + const { searchParams } = new URL(request.url); + const id = searchParams.get('id'); + + try { + if (id) { + const job = await prisma.job.findUnique({ + where: { id }, + }); + return NextResponse.json(job); + } + + const jobs = await prisma.job.findMany({ + orderBy: { created_at: 'desc' }, + }); + return NextResponse.json({ jobs: jobs }); + } catch (error) { + console.error(error); + return NextResponse.json({ error: 'Failed to fetch training data' }, { status: 500 }); + } +} + +export async function POST(request: Request) { + try { + const body = await request.json(); + const { id, name, job_config, gpu_ids } = body; + + // Ensure gpu_ids is never null/undefined - provide default value + const safeGpuIds = gpu_ids || '0'; + + if (id) { + // Update existing training + const training = await prisma.job.update({ + where: { id }, + data: { + name, + gpu_ids: safeGpuIds, + job_config: JSON.stringify(job_config), + }, + }); + return NextResponse.json(training); + } else { + // Create new training + const training = await prisma.job.create({ + data: { + name, + gpu_ids: safeGpuIds, + job_config: JSON.stringify(job_config), + }, + }); + return NextResponse.json(training); + } + } catch (error: any) { + if (error.code === 'P2002') { + // Handle unique constraint violation, 409=Conflict + return NextResponse.json({ error: 'Job name already exists' }, { status: 409 }); + } + console.error(error); + // Handle other errors + return NextResponse.json({ error: 'Failed to save training data' }, { status: 500 }); + } +} diff --git a/src/app/api/settings/route.ts b/src/app/api/settings/route.ts new file mode 100644 index 0000000000000000000000000000000000000000..62528cdd0b6a7de39c7ade3e96ea9f0b1ec2a226 --- /dev/null +++ b/src/app/api/settings/route.ts @@ -0,0 +1,59 @@ +import { NextResponse } from 'next/server'; +import { PrismaClient } from '@prisma/client'; +import { defaultTrainFolder, defaultDatasetsFolder } from '@/paths'; +import { flushCache } from '@/server/settings'; + +const prisma = new PrismaClient(); + +export async function GET() { + try { + const settings = await prisma.settings.findMany(); + const settingsObject = settings.reduce((acc: any, setting) => { + acc[setting.key] = setting.value; + return acc; + }, {}); + // if TRAINING_FOLDER is not set, use default + if (!settingsObject.TRAINING_FOLDER || settingsObject.TRAINING_FOLDER === '') { + settingsObject.TRAINING_FOLDER = defaultTrainFolder; + } + // if DATASETS_FOLDER is not set, use default + if (!settingsObject.DATASETS_FOLDER || settingsObject.DATASETS_FOLDER === '') { + settingsObject.DATASETS_FOLDER = defaultDatasetsFolder; + } + return NextResponse.json(settingsObject); + } catch (error) { + return NextResponse.json({ error: 'Failed to fetch settings' }, { status: 500 }); + } +} + +export async function POST(request: Request) { + try { + const body = await request.json(); + const { HF_TOKEN, TRAINING_FOLDER, DATASETS_FOLDER } = body; + + // Upsert both settings + await Promise.all([ + prisma.settings.upsert({ + where: { key: 'HF_TOKEN' }, + update: { value: HF_TOKEN }, + create: { key: 'HF_TOKEN', value: HF_TOKEN }, + }), + prisma.settings.upsert({ + where: { key: 'TRAINING_FOLDER' }, + update: { value: TRAINING_FOLDER }, + create: { key: 'TRAINING_FOLDER', value: TRAINING_FOLDER }, + }), + prisma.settings.upsert({ + where: { key: 'DATASETS_FOLDER' }, + update: { value: DATASETS_FOLDER }, + create: { key: 'DATASETS_FOLDER', value: DATASETS_FOLDER }, + }), + ]); + + flushCache(); + + return NextResponse.json({ success: true }); + } catch (error) { + return NextResponse.json({ error: 'Failed to update settings' }, { status: 500 }); + } +} diff --git a/src/app/api/zip/route.ts b/src/app/api/zip/route.ts new file mode 100644 index 0000000000000000000000000000000000000000..fc4b946da5f6265d4d193849bf218fea41ea6e01 --- /dev/null +++ b/src/app/api/zip/route.ts @@ -0,0 +1,78 @@ +/* eslint-disable */ +import { NextRequest, NextResponse } from 'next/server'; +import fs from 'fs'; +import fsp from 'fs/promises'; +import path from 'path'; +import archiver from 'archiver'; +import { getTrainingFolder } from '@/server/settings'; + +export const runtime = 'nodejs'; // ensure Node APIs are available +export const dynamic = 'force-dynamic'; // long-running, non-cached + +type PostBody = { + zipTarget: 'samples'; //only samples for now + jobName: string; +}; + +async function resolveSafe(p: string) { + // resolve symlinks + normalize + return await fsp.realpath(p); +} + +export async function POST(request: NextRequest) { + try { + const body = (await request.json()) as PostBody; + if (!body || !body.jobName) { + return NextResponse.json({ error: 'jobName is required' }, { status: 400 }); + } + + const trainingRoot = await resolveSafe(await getTrainingFolder()); + const folderPath = await resolveSafe(path.join(trainingRoot, body.jobName, 'samples')); + const outputPath = path.resolve(trainingRoot, body.jobName, 'samples.zip'); + + // Must be a directory + let stat: fs.Stats; + try { + stat = await fsp.stat(folderPath); + } catch { + return new NextResponse('Folder not found', { status: 404 }); + } + if (!stat.isDirectory()) { + return new NextResponse('Not a directory', { status: 400 }); + } + + // delete current one if it exists + if (fs.existsSync(outputPath)) { + await fsp.unlink(outputPath); + } + + // Create write stream & archive + await new Promise((resolve, reject) => { + const output = fs.createWriteStream(outputPath); + const archive = archiver('zip', { zlib: { level: 9 } }); + + output.on('close', () => resolve()); + output.on('error', reject); + archive.on('error', reject); + + archive.pipe(output); + + // Add the directory contents (place them under the folder's base name in the zip) + const rootName = path.basename(folderPath); + archive.directory(folderPath, rootName); + + archive.finalize().catch(reject); + }); + + // Return the absolute path so your existing /api/files/[...filePath] can serve it + // Example download URL (client-side): `/api/files/${encodeURIComponent(resolvedOutPath)}` + return NextResponse.json({ + ok: true, + zipPath: outputPath, + fileName: path.basename(outputPath), + }); + } catch (err) { + console.error('Zip error:', err); + return new NextResponse('Internal Server Error', { status: 500 }); + } +} diff --git a/src/app/apple-icon.png b/src/app/apple-icon.png new file mode 100644 index 0000000000000000000000000000000000000000..595cb880e5cff0ab9605c2ef76dba8ebb7e7fc62 Binary files /dev/null and b/src/app/apple-icon.png differ diff --git a/src/app/dashboard/page.tsx b/src/app/dashboard/page.tsx new file mode 100644 index 0000000000000000000000000000000000000000..45d5596afc5831c419579afafdfe3cd515c4e3d0 --- /dev/null +++ b/src/app/dashboard/page.tsx @@ -0,0 +1,85 @@ +'use client'; + +import JobsTable from '@/components/JobsTable'; +import { TopBar, MainContent } from '@/components/layout'; +import Link from 'next/link'; +import { useAuth } from '@/contexts/AuthContext'; +import HFLoginButton from '@/components/HFLoginButton'; + +export default function Dashboard() { + const { status: authStatus, namespace } = useAuth(); + const isAuthenticated = authStatus === 'authenticated'; + + return ( + <> + +
+

Dashboard

+
+
+ + +
+
+

+ {isAuthenticated ? `Welcome back, ${namespace || 'creator'}!` : 'Welcome to Ostris AI Toolkit'} +

+

+ {isAuthenticated + ? 'You are signed in with Hugging Face and can manage jobs, datasets, and submissions.' + : 'Authenticate with Hugging Face or add a personal access token to create jobs, upload datasets, and launch training.'} +

+
+ {isAuthenticated ? ( +
+ + Create a Training Job + + + Manage Datasets + + + Settings + +
+ ) : ( +
+ + + Or manage tokens in Settings + +
+ )} +
+ +
+
+

Active Jobs

+
+ View All +
+
+ {isAuthenticated ? ( + + ) : ( +
+ Sign in with Hugging Face or add an access token in Settings to view and manage jobs. +
+ )} +
+
+ + ); +} diff --git a/src/app/datasets/[datasetName]/page.tsx b/src/app/datasets/[datasetName]/page.tsx new file mode 100644 index 0000000000000000000000000000000000000000..776eeb525fdacf0621828734ffdd79bbd21697a8 --- /dev/null +++ b/src/app/datasets/[datasetName]/page.tsx @@ -0,0 +1,190 @@ +'use client'; + +import { useEffect, useState, use, useMemo } from 'react'; +import { LuImageOff, LuLoader, LuBan } from 'react-icons/lu'; +import { FaChevronLeft } from 'react-icons/fa'; +import DatasetImageCard from '@/components/DatasetImageCard'; +import { Button } from '@headlessui/react'; +import AddImagesModal, { openImagesModal } from '@/components/AddImagesModal'; +import { TopBar, MainContent } from '@/components/layout'; +import { apiClient } from '@/utils/api'; +import FullscreenDropOverlay from '@/components/FullscreenDropOverlay'; +import { useRouter } from 'next/navigation'; +import { usingBrowserDb } from '@/utils/env'; +import { hasUserDataset } from '@/utils/storage/datasetStorage'; +import { useAuth } from '@/contexts/AuthContext'; +import HFLoginButton from '@/components/HFLoginButton'; +import Link from 'next/link'; + +export default function DatasetPage({ params }: { params: { datasetName: string } }) { + const [imgList, setImgList] = useState<{ img_path: string }[]>([]); + const usableParams = use(params as any) as { datasetName: string }; + const datasetName = usableParams.datasetName; + const [status, setStatus] = useState<'idle' | 'loading' | 'success' | 'error'>('idle'); + const router = useRouter(); + const { status: authStatus } = useAuth(); + const isAuthenticated = authStatus === 'authenticated'; + const hasDatasetEntry = !usingBrowserDb || hasUserDataset(datasetName); + const allowAccess = hasDatasetEntry && isAuthenticated; + + const refreshImageList = (dbName: string) => { + setStatus('loading'); + console.log('Fetching images for dataset:', dbName); + apiClient + .post('/api/datasets/listImages', { datasetName: dbName }) + .then((res: any) => { + const data = res.data; + console.log('Images:', data.images); + // sort + data.images.sort((a: { img_path: string }, b: { img_path: string }) => a.img_path.localeCompare(b.img_path)); + setImgList(data.images); + setStatus('success'); + }) + .catch(error => { + console.error('Error fetching images:', error); + setStatus('error'); + }); + }; + useEffect(() => { + if (!datasetName) { + return; + } + + if (!isAuthenticated) { + return; + } + + if (!hasDatasetEntry) { + setImgList([]); + setStatus('error'); + router.replace('/datasets'); + return; + } + + refreshImageList(datasetName); + }, [datasetName, hasDatasetEntry, isAuthenticated, router]); + + if (!allowAccess) { + return ( + <> + +
+ +
+
+

Dataset: {datasetName}

+
+
+
+ +
+

You need to sign in with Hugging Face or provide a valid token to view this dataset.

+
+ + + Manage authentication in Settings + +
+
+
+ + ); + } + + const PageInfoContent = useMemo(() => { + let icon = null; + let text = ''; + let subtitle = ''; + let showIt = false; + let bgColor = ''; + let textColor = ''; + let iconColor = ''; + + if (status == 'loading') { + icon = ; + text = 'Loading Images'; + subtitle = 'Please wait while we fetch your dataset images...'; + showIt = true; + bgColor = 'bg-gray-50 dark:bg-gray-800/50'; + textColor = 'text-gray-900 dark:text-gray-100'; + iconColor = 'text-gray-500 dark:text-gray-400'; + } + if (status == 'error') { + icon = ; + text = 'Error Loading Images'; + subtitle = 'There was a problem fetching the images. Please try refreshing the page.'; + showIt = true; + bgColor = 'bg-red-50 dark:bg-red-950/20'; + textColor = 'text-red-900 dark:text-red-100'; + iconColor = 'text-red-600 dark:text-red-400'; + } + if (status == 'success' && imgList.length === 0) { + icon = ; + text = 'No Images Found'; + subtitle = 'This dataset is empty. Click "Add Images" to get started.'; + showIt = true; + bgColor = 'bg-gray-50 dark:bg-gray-800/50'; + textColor = 'text-gray-900 dark:text-gray-100'; + iconColor = 'text-gray-500 dark:text-gray-400'; + } + + if (!showIt) return null; + + return ( +
+
{icon}
+

{text}

+

{subtitle}

+
+ ); + }, [status, imgList.length]); + + return ( + <> + {/* Fixed top bar */} + +
+ +
+
+

Dataset: {datasetName}

+
+
+
+ +
+
+ + {PageInfoContent} + {status === 'success' && imgList.length > 0 && ( +
+ {imgList.map(img => ( + refreshImageList(datasetName)} + /> + ))} +
+ )} +
+ + refreshImageList(datasetName)} + /> + + ); +} diff --git a/src/app/datasets/page.tsx b/src/app/datasets/page.tsx new file mode 100644 index 0000000000000000000000000000000000000000..eec8310f9ba6f38f5eca345a0b6400e754241a64 --- /dev/null +++ b/src/app/datasets/page.tsx @@ -0,0 +1,217 @@ +'use client'; + +import { useState } from 'react'; +import { Modal } from '@/components/Modal'; +import Link from 'next/link'; +import { TextInput } from '@/components/formInputs'; +import useDatasetList from '@/hooks/useDatasetList'; +import { Button } from '@headlessui/react'; +import { FaRegTrashAlt } from 'react-icons/fa'; +import { openConfirm } from '@/components/ConfirmModal'; +import { TopBar, MainContent } from '@/components/layout'; +import UniversalTable, { TableColumn } from '@/components/UniversalTable'; +import { apiClient } from '@/utils/api'; +import { useRouter } from 'next/navigation'; +import { usingBrowserDb } from '@/utils/env'; +import { addUserDataset, removeUserDataset } from '@/utils/storage/datasetStorage'; +import { useAuth } from '@/contexts/AuthContext'; +import HFLoginButton from '@/components/HFLoginButton'; + +export default function Datasets() { + const router = useRouter(); + const { datasets, status, refreshDatasets } = useDatasetList(); + const [newDatasetName, setNewDatasetName] = useState(''); + const [isNewDatasetModalOpen, setIsNewDatasetModalOpen] = useState(false); + const { status: authStatus } = useAuth(); + const isAuthenticated = authStatus === 'authenticated'; + + // Transform datasets array into rows with objects + const tableRows = datasets.map(dataset => ({ + name: dataset, + actions: dataset, // Pass full dataset name for actions + })); + + const columns: TableColumn[] = [ + { + title: 'Dataset Name', + key: 'name', + render: row => ( + + {row.name} + + ), + }, + { + title: 'Actions', + key: 'actions', + className: 'w-20 text-right', + render: row => ( + + ), + }, + ]; + + const handleDeleteDataset = (datasetName: string) => { + openConfirm({ + title: 'Delete Dataset', + message: `Are you sure you want to delete the dataset "${datasetName}"? This action cannot be undone.`, + type: 'warning', + confirmText: 'Delete', + onConfirm: () => { + apiClient + .post('/api/datasets/delete', { name: datasetName }) + .then(() => { + console.log('Dataset deleted:', datasetName); + if (usingBrowserDb) { + removeUserDataset(datasetName); + } + refreshDatasets(); + }) + .catch(error => { + console.error('Error deleting dataset:', error); + }); + }, + }); + }; + + const handleCreateDataset = async (e: React.FormEvent) => { + e.preventDefault(); + if (!isAuthenticated) { + return; + } + try { + const data = await apiClient.post('/api/datasets/create', { name: newDatasetName }).then(res => res.data); + console.log('New dataset created:', data); + if (usingBrowserDb && data?.name) { + addUserDataset(data.name, data?.path || ''); + } + refreshDatasets(); + setNewDatasetName(''); + setIsNewDatasetModalOpen(false); + } catch (error) { + console.error('Error creating new dataset:', error); + } + }; + + const openNewDatasetModal = () => { + if (!isAuthenticated) { + return; + } + openConfirm({ + title: 'New Dataset', + message: 'Enter the name of the new dataset:', + type: 'info', + confirmText: 'Create', + inputTitle: 'Dataset Name', + onConfirm: async (name?: string) => { + if (!name) { + console.error('Dataset name is required.'); + return; + } + if (!isAuthenticated) { + return; + } + try { + const data = await apiClient.post('/api/datasets/create', { name }).then(res => res.data); + console.log('New dataset created:', data); + if (usingBrowserDb && data?.name) { + addUserDataset(data.name, data?.path || ''); + } + if (data.name) { + router.push(`/datasets/${data.name}`); + } else { + refreshDatasets(); + } + } catch (error) { + console.error('Error creating new dataset:', error); + } + }, + }); + }; + + return ( + <> + +
+

Datasets

+
+
+
+ {isAuthenticated ? ( + + ) : ( + + Sign in to add datasets + + )} +
+
+ + + {isAuthenticated ? ( + + ) : ( +
+

Sign in with Hugging Face or add an access token to manage datasets.

+
+ + + Manage authentication in Settings + +
+
+ )} +
+ + setIsNewDatasetModalOpen(false)} + title="New Dataset" + size="md" + > +
+
+
+ This will create a new folder with the name below in your dataset folder. +
+
+ setNewDatasetName(value)} /> +
+ +
+ + +
+
+
+
+ + ); +} diff --git a/src/app/favicon.ico b/src/app/favicon.ico new file mode 100644 index 0000000000000000000000000000000000000000..a20b629a5996a0b62c038bf356f1e28eab9bdb99 Binary files /dev/null and b/src/app/favicon.ico differ diff --git a/src/app/globals.css b/src/app/globals.css new file mode 100644 index 0000000000000000000000000000000000000000..890dc5bc7b9125662f38d11d758350ba5a80f744 --- /dev/null +++ b/src/app/globals.css @@ -0,0 +1,72 @@ +@tailwind base; +@tailwind components; +@tailwind utilities; + +:root { + --background: #ffffff; + --foreground: #171717; +} + +@media (prefers-color-scheme: dark) { + :root { + --background: #0a0a0a; + --foreground: #ededed; + } +} + +body { + color: var(--foreground); + background: var(--background); + font-family: Arial, Helvetica, sans-serif; +} + +@layer components { + /* control */ + .aitk-react-select-container .aitk-react-select__control { + @apply flex w-full h-8 min-h-0 px-0 text-sm bg-gray-800 border border-gray-700 rounded-sm hover:border-gray-600 items-center; + } + + /* selected label */ + .aitk-react-select-container .aitk-react-select__single-value { + @apply flex-1 min-w-0 truncate text-sm text-neutral-200; + } + + /* invisible input (keeps focus & typing, never wraps) */ + .aitk-react-select-container .aitk-react-select__input-container { + @apply text-neutral-200; + } + + /* focus */ + .aitk-react-select-container .aitk-react-select__control--is-focused { + @apply ring-2 ring-gray-600 border-transparent hover:border-transparent shadow-none; + } + + /* menu */ + .aitk-react-select-container .aitk-react-select__menu { + @apply bg-gray-800 border border-gray-700; + } + + /* options */ + .aitk-react-select-container .aitk-react-select__option { + @apply text-sm text-neutral-200 bg-gray-800 hover:bg-gray-700; + } + + /* indicator separator */ + .aitk-react-select-container .aitk-react-select__indicator-separator { + @apply bg-gray-600; + } + + /* indicators */ + .aitk-react-select-container .aitk-react-select__indicators, + .aitk-react-select-container .aitk-react-select__indicator { + @apply py-0 flex items-center; + } + + /* placeholder */ + .aitk-react-select-container .aitk-react-select__placeholder { + @apply text-sm text-neutral-200; + } +} + + + diff --git a/src/app/icon.png b/src/app/icon.png new file mode 100644 index 0000000000000000000000000000000000000000..8bcfbf80f1f08f9b1f6678914370f00a105a37b2 Binary files /dev/null and b/src/app/icon.png differ diff --git a/src/app/icon.svg b/src/app/icon.svg new file mode 100644 index 0000000000000000000000000000000000000000..2689ae5393931a68144db7d92555343aeef0155c --- /dev/null +++ b/src/app/icon.svg @@ -0,0 +1,3 @@ + \ No newline at end of file diff --git a/src/app/jobs/[jobID]/page.tsx b/src/app/jobs/[jobID]/page.tsx new file mode 100644 index 0000000000000000000000000000000000000000..ae001714d36deec0f26e52f0d1542f4684a7ef7e --- /dev/null +++ b/src/app/jobs/[jobID]/page.tsx @@ -0,0 +1,147 @@ +'use client'; + +import { useState, use } from 'react'; +import { FaChevronLeft } from 'react-icons/fa'; +import { Button } from '@headlessui/react'; +import { TopBar, MainContent } from '@/components/layout'; +import useJob from '@/hooks/useJob'; +import SampleImages, {SampleImagesMenu} from '@/components/SampleImages'; +import JobOverview from '@/components/JobOverview'; +import { redirect } from 'next/navigation'; +import { useAuth } from '@/contexts/AuthContext'; +import HFLoginButton from '@/components/HFLoginButton'; +import Link from 'next/link'; +import JobActionBar from '@/components/JobActionBar'; +import JobConfigViewer from '@/components/JobConfigViewer'; +import { JobRecord } from '@/types'; + +type PageKey = 'overview' | 'samples' | 'config'; + +interface Page { + name: string; + value: PageKey; + component: React.ComponentType<{ job: JobRecord }>; + menuItem?: React.ComponentType<{ job?: JobRecord | null }> | null; + mainCss?: string; +} + +const pages: Page[] = [ + { + name: 'Overview', + value: 'overview', + component: JobOverview, + mainCss: 'pt-24', + }, + { + name: 'Samples', + value: 'samples', + component: SampleImages, + menuItem: SampleImagesMenu, + mainCss: 'pt-24', + }, + { + name: 'Config File', + value: 'config', + component: JobConfigViewer, + mainCss: 'pt-[80px] px-0 pb-0', + }, +]; + +export default function JobPage({ params }: { params: { jobID: string } }) { + const usableParams = use(params as any) as { jobID: string }; + const jobID = usableParams.jobID; + const { job, status, refreshJob } = useJob(jobID, 5000); + const [pageKey, setPageKey] = useState('overview'); + const { status: authStatus } = useAuth(); + const isAuthenticated = authStatus === 'authenticated'; + + const page = pages.find(p => p.value === pageKey); + + if (!isAuthenticated) { + return ( + <> + +
+ +
+
+

Job Details

+
+
+
+ +
+

Sign in with Hugging Face or add an access token to view job details.

+
+ + + Manage authentication in Settings + +
+
+
+ + ); + } + + return ( + <> + {/* Fixed top bar */} + +
+ +
+
+

Job: {job?.name}

+
+
+ {job && ( + { + redirect('/jobs'); + }} + /> + )} +
+ page.value === pageKey)?.mainCss}> + {status === 'loading' && job == null &&

Loading...

} + {status === 'error' && job == null &&

Error fetching job

} + {job && ( + <> + {pages.map(page => { + const Component = page.component; + return page.value === pageKey ? : null; + })} + + )} +
+
+ {pages.map(page => ( + + ))} + { + page?.menuItem && ( + <> +
+
+ + + ) + } +
+ + ); +} diff --git a/src/app/jobs/new/AdvancedJob.tsx b/src/app/jobs/new/AdvancedJob.tsx new file mode 100644 index 0000000000000000000000000000000000000000..bccc4da22a57660ae23e0882f641362f1dfd4dec --- /dev/null +++ b/src/app/jobs/new/AdvancedJob.tsx @@ -0,0 +1,146 @@ +'use client'; +import { useEffect, useState, useRef } from 'react'; +import { JobConfig } from '@/types'; +import YAML from 'yaml'; +import Editor, { OnMount } from '@monaco-editor/react'; +import type { editor } from 'monaco-editor'; +import { SettingsData } from '@/types'; +import { migrateJobConfig } from './jobConfig'; + +type Props = { + jobConfig: JobConfig; + setJobConfig: (value: any, key?: string) => void; + status: 'idle' | 'saving' | 'success' | 'error'; + handleSubmit: (event: React.FormEvent) => void; + runId: string | null; + gpuIDs: string | null; + setGpuIDs: (value: string | null) => void; + gpuList: any; + datasetOptions: any; + settings: SettingsData; +}; + +const isDev = process.env.NODE_ENV === 'development'; + +const yamlConfig: YAML.DocumentOptions & + YAML.SchemaOptions & + YAML.ParseOptions & + YAML.CreateNodeOptions & + YAML.ToStringOptions = { + indent: 2, + lineWidth: 999999999999, + defaultStringType: 'QUOTE_DOUBLE', + defaultKeyType: 'PLAIN', + directives: true, +}; + +export default function AdvancedJob({ jobConfig, setJobConfig, settings }: Props) { + const [editorValue, setEditorValue] = useState(''); + const lastJobConfigUpdateStringRef = useRef(''); + const editorRef = useRef(null); + + // Track if the editor has been mounted + const isEditorMounted = useRef(false); + + // Handler for editor mounting + const handleEditorDidMount: OnMount = editor => { + editorRef.current = editor; + isEditorMounted.current = true; + + // Initial content setup + try { + const yamlContent = YAML.stringify(jobConfig, yamlConfig); + setEditorValue(yamlContent); + lastJobConfigUpdateStringRef.current = JSON.stringify(jobConfig); + } catch (e) { + console.warn(e); + } + }; + + useEffect(() => { + const lastUpdate = lastJobConfigUpdateStringRef.current; + const currentUpdate = JSON.stringify(jobConfig); + + // Skip if no changes or editor not yet mounted + if (lastUpdate === currentUpdate || !isEditorMounted.current) { + return; + } + + try { + // Preserve cursor position and selection + const editor = editorRef.current; + if (editor) { + // Save current editor state + const position = editor.getPosition(); + const selection = editor.getSelection(); + const scrollTop = editor.getScrollTop(); + + // Update content + const yamlContent = YAML.stringify(jobConfig, yamlConfig); + + // Only update if the content is actually different + if (yamlContent !== editor.getValue()) { + // Set value directly on the editor model instead of using React state + editor.getModel()?.setValue(yamlContent); + + // Restore cursor position and selection + if (position) editor.setPosition(position); + if (selection) editor.setSelection(selection); + editor.setScrollTop(scrollTop); + } + + lastJobConfigUpdateStringRef.current = currentUpdate; + } + } catch (e) { + console.warn(e); + } + }, [jobConfig]); + + const handleChange = (value: string | undefined) => { + if (value === undefined) return; + + try { + const parsed = YAML.parse(value); + // Don't update jobConfig if the change came from the editor itself + // to avoid a circular update loop + if (JSON.stringify(parsed) !== lastJobConfigUpdateStringRef.current) { + lastJobConfigUpdateStringRef.current = JSON.stringify(parsed); + + // We have to ensure certain things are always set + try { + parsed.config.process[0].type = 'ui_trainer'; + parsed.config.process[0].sqlite_db_path = './aitk_db.db'; + parsed.config.process[0].training_folder = settings.TRAINING_FOLDER; + parsed.config.process[0].device = 'cuda'; + parsed.config.process[0].performance_log_every = 10; + } catch (e) { + console.warn(e); + } + migrateJobConfig(parsed); + setJobConfig(parsed); + } + } catch (e) { + // Don't update on parsing errors + console.warn(e); + } + }; + + return ( + <> + + + ); +} diff --git a/src/app/jobs/new/SimpleJob.tsx b/src/app/jobs/new/SimpleJob.tsx new file mode 100644 index 0000000000000000000000000000000000000000..080c383de00f4858199e0937cbca92385910a598 --- /dev/null +++ b/src/app/jobs/new/SimpleJob.tsx @@ -0,0 +1,973 @@ +'use client'; +import { useMemo, useState } from 'react'; +import { modelArchs, ModelArch, groupedModelOptions, quantizationOptions, defaultQtype } from './options'; +import { defaultDatasetConfig } from './jobConfig'; +import { GroupedSelectOption, JobConfig, SelectOption } from '@/types'; +import { objectCopy } from '@/utils/basic'; +import { TextInput, SelectInput, Checkbox, FormGroup, NumberInput } from '@/components/formInputs'; +import Card from '@/components/Card'; +import { X } from 'lucide-react'; +import AddSingleImageModal, { openAddImageModal } from '@/components/AddSingleImageModal'; +import {FlipHorizontal2, FlipVertical2} from "lucide-react"; +import HFJobsWorkflow from '@/components/HFJobsWorkflow'; + +type Props = { + jobConfig: JobConfig; + setJobConfig: (value: any, key: string) => void; + status: 'idle' | 'saving' | 'success' | 'error'; + handleSubmit: (event: React.FormEvent) => void; + runId: string | null; + gpuIDs: string | null; + setGpuIDs: (value: string | null) => void; + gpuList: any; + datasetOptions: any; + trainingBackend?: 'local' | 'hf-jobs'; + setTrainingBackend?: (backend: 'local' | 'hf-jobs') => void; + hfJobSubmitted?: boolean; + onHFJobComplete?: (jobId: string, localJobId?: string) => void; + forceHFBackend?: boolean; +}; + +const isDev = process.env.NODE_ENV === 'development'; + +export default function SimpleJob({ + jobConfig, + setJobConfig, + handleSubmit, + status, + runId, + gpuIDs, + setGpuIDs, + gpuList, + datasetOptions, + trainingBackend: parentTrainingBackend, + setTrainingBackend: parentSetTrainingBackend, + hfJobSubmitted, + onHFJobComplete, + forceHFBackend = false, +}: Props) { + const [localTrainingBackend, setLocalTrainingBackend] = useState(forceHFBackend ? 'hf-jobs' : 'local'); + const trainingBackend = parentTrainingBackend || localTrainingBackend; + const setTrainingBackend = forceHFBackend + ? (_: 'local' | 'hf-jobs') => undefined + : parentSetTrainingBackend || setLocalTrainingBackend; + const backendOptions = forceHFBackend + ? [{ value: 'hf-jobs', label: 'HF Jobs (Cloud)' }] + : [ + { value: 'local', label: 'Local GPU' }, + { value: 'hf-jobs', label: 'HF Jobs (Cloud)' }, + ]; + const modelArch = useMemo(() => { + return modelArchs.find(a => a.name === jobConfig.config.process[0].model.arch) as ModelArch; + }, [jobConfig.config.process[0].model.arch]); + + const isVideoModel = !!(modelArch?.group === 'video'); + + const numTopCards = useMemo(() => { + let count = 4; // job settings, model config, target config, save config + if (modelArch?.additionalSections?.includes('model.multistage')) { + count += 1; // add multistage card + } + if (!modelArch?.disableSections?.includes('model.quantize')) { + count += 1; // add quantization card + } + return count; + + }, [modelArch]); + + let topBarClass = 'grid grid-cols-1 md:grid-cols-2 lg:grid-cols-4 xl:grid-cols-4 gap-6'; + + if (numTopCards == 5) { + topBarClass = 'grid grid-cols-1 md:grid-cols-2 lg:grid-cols-3 xl:grid-cols-5 gap-6'; + } + if (numTopCards == 6) { + topBarClass = 'grid grid-cols-1 md:grid-cols-2 lg:grid-cols-3 xl:grid-cols-3 2xl:grid-cols-6 gap-6'; + } + + const transformerQuantizationOptions: GroupedSelectOption[] | SelectOption[] = useMemo(() => { + const hasARA = modelArch?.accuracyRecoveryAdapters && Object.keys(modelArch.accuracyRecoveryAdapters).length > 0; + if (!hasARA) { + return quantizationOptions; + } + let newQuantizationOptions = [ + { + label: 'Standard', + options: [quantizationOptions[0], quantizationOptions[1]], + }, + ]; + + // add ARAs if they exist for the model + let ARAs: SelectOption[] = []; + if (modelArch.accuracyRecoveryAdapters) { + for (const [label, value] of Object.entries(modelArch.accuracyRecoveryAdapters)) { + ARAs.push({ value, label }); + } + } + if (ARAs.length > 0) { + newQuantizationOptions.push({ + label: 'Accuracy Recovery Adapters', + options: ARAs, + }); + } + + let additionalQuantizationOptions: SelectOption[] = []; + // add the quantization options if they are not already included + for (let i = 2; i < quantizationOptions.length; i++) { + const option = quantizationOptions[i]; + additionalQuantizationOptions.push(option); + } + if (additionalQuantizationOptions.length > 0) { + newQuantizationOptions.push({ + label: 'Additional Quantization Options', + options: additionalQuantizationOptions, + }); + } + return newQuantizationOptions; + }, [modelArch]); + + return ( + <> +
+
+ + setJobConfig(value, 'config.name')} + placeholder="Enter training name" + disabled={runId !== null} + required + /> + { + setTrainingBackend(value); + }} + options={backendOptions} + disabled={forceHFBackend} + /> + {trainingBackend === 'local' && ( + setGpuIDs(value)} + options={gpuList.map((gpu: any) => ({ value: `${gpu.index}`, label: `GPU #${gpu.index}` }))} + /> + )} + { + if (value?.trim() === '') { + value = null; + } + setJobConfig(value, 'config.process[0].trigger_word'); + }} + placeholder="" + required + /> + {trainingBackend === 'hf-jobs' && ( +
+

+ {hfJobSubmitted + ? '✓ HF Job already submitted! You can modify settings and resubmit if needed.' + : '⏳ HF Job ready for submission. Submit to the cloud below.' + } +

+
+ )} +
+ + {/* Model Configuration Section */} + + { + const currentArch = modelArchs.find(a => a.name === jobConfig.config.process[0].model.arch); + if (!currentArch || currentArch.name === value) { + return; + } + // update the defaults when a model is selected + const newArch = modelArchs.find(model => model.name === value); + + // update vram setting + if (!newArch?.additionalSections?.includes('model.low_vram')) { + setJobConfig(false, 'config.process[0].model.low_vram'); + } + + // revert defaults from previous model + for (const key in currentArch.defaults) { + setJobConfig(currentArch.defaults[key][1], key); + } + + if (newArch?.defaults) { + for (const key in newArch.defaults) { + setJobConfig(newArch.defaults[key][0], key); + } + } + // set new model + setJobConfig(value, 'config.process[0].model.arch'); + + // update datasets + const hasControlPath = newArch?.additionalSections?.includes('datasets.control_path') || false; + const hasNumFrames = newArch?.additionalSections?.includes('datasets.num_frames') || false; + const controls = newArch?.controls ?? []; + const datasets = jobConfig.config.process[0].datasets.map(dataset => { + const newDataset = objectCopy(dataset); + newDataset.controls = controls; + if (!hasControlPath) { + newDataset.control_path = null; // reset control path if not applicable + } + if (!hasNumFrames) { + newDataset.num_frames = 1; // reset num_frames if not applicable + } + return newDataset; + }); + setJobConfig(datasets, 'config.process[0].datasets'); + + // update samples + const hasSampleCtrlImg = newArch?.additionalSections?.includes('sample.ctrl_img') || false; + const samples = jobConfig.config.process[0].sample.samples.map(sample => { + const newSample = objectCopy(sample); + if (!hasSampleCtrlImg) { + delete newSample.ctrl_img; // remove ctrl_img if not applicable + } + return newSample; + }); + setJobConfig(samples, 'config.process[0].sample.samples'); + }} + options={groupedModelOptions} + /> + { + if (value?.trim() === '') { + value = null; + } + setJobConfig(value, 'config.process[0].model.name_or_path'); + }} + placeholder="" + required + /> + {modelArch?.additionalSections?.includes('model.low_vram') && ( + + setJobConfig(value, 'config.process[0].model.low_vram')} + /> + + )} + + {modelArch?.disableSections?.includes('model.quantize') ? null : ( + + { + if (value === '') { + setJobConfig(false, 'config.process[0].model.quantize'); + value = defaultQtype; + } else { + setJobConfig(true, 'config.process[0].model.quantize'); + } + setJobConfig(value, 'config.process[0].model.qtype'); + }} + options={transformerQuantizationOptions} + /> + { + if (value === '') { + setJobConfig(false, 'config.process[0].model.quantize_te'); + value = defaultQtype; + } else { + setJobConfig(true, 'config.process[0].model.quantize_te'); + } + setJobConfig(value, 'config.process[0].model.qtype_te'); + }} + options={quantizationOptions} + /> + + )} + {modelArch?.additionalSections?.includes('model.multistage') && ( + + + setJobConfig(value, 'config.process[0].model.model_kwargs.train_high_noise')} + /> + setJobConfig(value, 'config.process[0].model.model_kwargs.train_low_noise')} + /> + + setJobConfig(value, 'config.process[0].train.switch_boundary_every')} + placeholder="eg. 1" + docKey={'train.switch_boundary_every'} + min={1} + required + /> + + )} + + setJobConfig(value, 'config.process[0].network.type')} + options={[ + { value: 'lora', label: 'LoRA' }, + { value: 'lokr', label: 'LoKr' }, + ]} + /> + {jobConfig.config.process[0].network?.type == 'lokr' && ( + setJobConfig(parseInt(value), 'config.process[0].network.lokr_factor')} + options={[ + { value: '-1', label: 'Auto' }, + { value: '4', label: '4' }, + { value: '8', label: '8' }, + { value: '16', label: '16' }, + { value: '32', label: '32' }, + ]} + /> + )} + {jobConfig.config.process[0].network?.type == 'lora' && ( + <> + { + console.log('onChange', value); + setJobConfig(value, 'config.process[0].network.linear'); + setJobConfig(value, 'config.process[0].network.linear_alpha'); + }} + placeholder="eg. 16" + min={0} + max={1024} + required + /> + {modelArch?.disableSections?.includes('network.conv') ? null : ( + { + console.log('onChange', value); + setJobConfig(value, 'config.process[0].network.conv'); + setJobConfig(value, 'config.process[0].network.conv_alpha'); + }} + placeholder="eg. 16" + min={0} + max={1024} + /> + )} + + )} + + + setJobConfig(value, 'config.process[0].save.dtype')} + options={[ + { value: 'bf16', label: 'BF16' }, + { value: 'fp16', label: 'FP16' }, + { value: 'fp32', label: 'FP32' }, + ]} + /> + setJobConfig(value, 'config.process[0].save.save_every')} + placeholder="eg. 250" + min={1} + required + /> + setJobConfig(value, 'config.process[0].save.max_step_saves_to_keep')} + placeholder="eg. 4" + min={1} + required + /> + +
+
+ +
+
+ setJobConfig(value, 'config.process[0].train.batch_size')} + placeholder="eg. 4" + min={1} + required + /> + setJobConfig(value, 'config.process[0].train.gradient_accumulation')} + placeholder="eg. 1" + min={1} + required + /> + setJobConfig(value, 'config.process[0].train.steps')} + placeholder="eg. 2000" + min={1} + required + /> +
+
+ setJobConfig(value, 'config.process[0].train.optimizer')} + options={[ + { value: 'adamw8bit', label: 'AdamW8Bit' }, + { value: 'adafactor', label: 'Adafactor' }, + ]} + /> + setJobConfig(value, 'config.process[0].train.lr')} + placeholder="eg. 0.0001" + min={0} + required + /> + setJobConfig(value, 'config.process[0].train.optimizer_params.weight_decay')} + placeholder="eg. 0.0001" + min={0} + required + /> +
+
+ {modelArch?.disableSections?.includes('train.timestep_type') ? null : ( + setJobConfig(value, 'config.process[0].train.timestep_type')} + options={[ + { value: 'sigmoid', label: 'Sigmoid' }, + { value: 'linear', label: 'Linear' }, + { value: 'shift', label: 'Shift' }, + { value: 'weighted', label: 'Weighted' }, + ]} + /> + )} + setJobConfig(value, 'config.process[0].train.content_or_style')} + options={[ + { value: 'balanced', label: 'Balanced' }, + { value: 'content', label: 'High Noise' }, + { value: 'style', label: 'Low Noise' }, + ]} + /> + setJobConfig(value, 'config.process[0].train.noise_scheduler')} + options={[ + { value: 'flowmatch', label: 'FlowMatch' }, + { value: 'ddpm', label: 'DDPM' }, + ]} + /> +
+
+ + setJobConfig(value, 'config.process[0].train.ema_config.use_ema')} + /> + + {jobConfig.config.process[0].train.ema_config?.use_ema && ( + setJobConfig(value, 'config.process[0].train.ema_config?.ema_decay')} + placeholder="eg. 0.99" + min={0} + /> + )} + + + { + setJobConfig(value, 'config.process[0].train.unload_text_encoder'); + if (value) { + setJobConfig(false, 'config.process[0].train.cache_text_embeddings'); + } + }} + /> + { + setJobConfig(value, 'config.process[0].train.cache_text_embeddings'); + if (value) { + setJobConfig(false, 'config.process[0].train.unload_text_encoder'); + } + }} + /> + +
+
+ + setJobConfig(value, 'config.process[0].train.diff_output_preservation')} + /> + + {jobConfig.config.process[0].train.diff_output_preservation && ( + <> + + setJobConfig(value, 'config.process[0].train.diff_output_preservation_multiplier') + } + placeholder="eg. 1.0" + min={0} + /> + setJobConfig(value, 'config.process[0].train.diff_output_preservation_class')} + placeholder="eg. woman" + /> + + )} +
+
+
+
+
+ + <> + {jobConfig.config.process[0].datasets.map((dataset, i) => ( +
+ +

Dataset {i + 1}

+
+
+ setJobConfig(value, `config.process[0].datasets[${i}].folder_path`)} + options={datasetOptions} + /> + {modelArch?.additionalSections?.includes('datasets.control_path') && ( + + setJobConfig(value == '' ? null : value, `config.process[0].datasets[${i}].control_path`) + } + options={[{ value: '', label: <>  }, ...datasetOptions]} + /> + )} + setJobConfig(value, `config.process[0].datasets[${i}].network_weight`)} + placeholder="eg. 1.0" + /> +
+
+ setJobConfig(value, `config.process[0].datasets[${i}].default_caption`)} + placeholder="eg. A photo of a cat" + /> + setJobConfig(value, `config.process[0].datasets[${i}].caption_dropout_rate`)} + placeholder="eg. 0.05" + min={0} + required + /> + {modelArch?.additionalSections?.includes('datasets.num_frames') && ( + setJobConfig(value, `config.process[0].datasets[${i}].num_frames`)} + placeholder="eg. 41" + min={1} + required + /> + )} +
+
+ + + setJobConfig(value, `config.process[0].datasets[${i}].cache_latents_to_disk`) + } + /> + setJobConfig(value, `config.process[0].datasets[${i}].is_reg`)} + /> + {modelArch?.additionalSections?.includes('datasets.do_i2v') && ( + setJobConfig(value, `config.process[0].datasets[${i}].do_i2v`)} + docKey="datasets.do_i2v" + /> + )} + + + Flip X } + checked={dataset.flip_x || false} + onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].flip_x`)} + /> + Flip Y } + checked={dataset.flip_y || false} + onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].flip_y`)} + /> + +
+
+ +
+ {[ + [256, 512, 768], + [1024, 1280, 1536], + ].map(resGroup => ( +
+ {resGroup.map(res => ( + { + const resolutions = dataset.resolution.includes(res) + ? dataset.resolution.filter(r => r !== res) + : [...dataset.resolution, res]; + setJobConfig(resolutions, `config.process[0].datasets[${i}].resolution`); + }} + /> + ))} +
+ ))} +
+
+
+
+
+ ))} + + +
+
+
+ +
+
+ setJobConfig(value, 'config.process[0].sample.sample_every')} + placeholder="eg. 250" + min={1} + required + /> + setJobConfig(value, 'config.process[0].sample.sampler')} + options={[ + { value: 'flowmatch', label: 'FlowMatch' }, + { value: 'ddpm', label: 'DDPM' }, + ]} + /> + setJobConfig(value, 'config.process[0].sample.guidance_scale')} + placeholder="eg. 1.0" + className="pt-2" + min={0} + required + /> + setJobConfig(value, 'config.process[0].sample.sample_steps')} + placeholder="eg. 1" + className="pt-2" + min={1} + required + /> +
+
+ setJobConfig(value, 'config.process[0].sample.width')} + placeholder="eg. 1024" + min={0} + required + /> + setJobConfig(value, 'config.process[0].sample.height')} + placeholder="eg. 1024" + className="pt-2" + min={0} + required + /> + {isVideoModel && ( +
+ setJobConfig(value, 'config.process[0].sample.num_frames')} + placeholder="eg. 0" + className="pt-2" + min={0} + required + /> + setJobConfig(value, 'config.process[0].sample.fps')} + placeholder="eg. 0" + className="pt-2" + min={0} + required + /> +
+ )} +
+ +
+ setJobConfig(value, 'config.process[0].sample.seed')} + placeholder="eg. 0" + min={0} + required + /> + setJobConfig(value, 'config.process[0].sample.walk_seed')} + /> +
+
+ +
+ setJobConfig(value, 'config.process[0].train.skip_first_sample')} + /> +
+
+ setJobConfig(value, 'config.process[0].train.disable_sampling')} + /> +
+
+
+
+ +
+
+ {jobConfig.config.process[0].sample.samples.map((sample, i) => ( +
+
+
+
+
+ setJobConfig(value, `config.process[0].sample.samples[${i}].prompt`)} + placeholder="Enter prompt" + required + /> +
+ + {modelArch?.additionalSections?.includes('sample.ctrl_img') && ( +
{ + openAddImageModal(imagePath => { + console.log('Selected image path:', imagePath); + if (!imagePath) return; + setJobConfig(imagePath, `config.process[0].sample.samples[${i}].ctrl_img`); + }); + }} + > + {!sample.ctrl_img && ( +
Add Control Image
+ )} +
+ )} +
+
+
+
+ +
+
+
+ ))} + +
+
+ + {status === 'success' &&

Training saved successfully!

} + {status === 'error' &&

Error saving training. Please try again.

} +
+ + {trainingBackend === 'hf-jobs' && ( +
+ { + console.log('HF Job submitted:', jobId, 'Local job ID:', localJobId); + if (onHFJobComplete) { + onHFJobComplete(jobId, localJobId); + } + }} + /> +
+ )} + + + + ); +} diff --git a/src/app/jobs/new/jobConfig.ts b/src/app/jobs/new/jobConfig.ts new file mode 100644 index 0000000000000000000000000000000000000000..df257bb985dad2eaada5d2913ab1e6347cf36ec1 --- /dev/null +++ b/src/app/jobs/new/jobConfig.ts @@ -0,0 +1,167 @@ +import { JobConfig, DatasetConfig } from '@/types'; + +export const defaultDatasetConfig: DatasetConfig = { + folder_path: '/path/to/images/folder', + control_path: null, + mask_path: null, + mask_min_value: 0.1, + default_caption: '', + caption_ext: 'txt', + caption_dropout_rate: 0.05, + cache_latents_to_disk: false, + is_reg: false, + network_weight: 1, + resolution: [512, 768, 1024], + controls: [], + shrink_video_to_frames: true, + num_frames: 1, + do_i2v: true, + flip_x: false, + flip_y: false, +}; + +export const defaultJobConfig: JobConfig = { + job: 'extension', + config: { + name: 'my_first_lora_v1', + process: [ + { + type: 'ui_trainer', + training_folder: 'output', + sqlite_db_path: './aitk_db.db', + device: 'cuda', + trigger_word: null, + performance_log_every: 10, + network: { + type: 'lora', + linear: 32, + linear_alpha: 32, + conv: 16, + conv_alpha: 16, + lokr_full_rank: true, + lokr_factor: -1, + network_kwargs: { + ignore_if_contains: [], + }, + }, + save: { + dtype: 'bf16', + save_every: 250, + max_step_saves_to_keep: 4, + save_format: 'diffusers', + push_to_hub: false, + }, + datasets: [defaultDatasetConfig], + train: { + batch_size: 1, + bypass_guidance_embedding: true, + steps: 3000, + gradient_accumulation: 1, + train_unet: true, + train_text_encoder: false, + gradient_checkpointing: true, + noise_scheduler: 'flowmatch', + optimizer: 'adamw8bit', + timestep_type: 'sigmoid', + content_or_style: 'balanced', + optimizer_params: { + weight_decay: 1e-4, + }, + unload_text_encoder: false, + cache_text_embeddings: false, + lr: 0.0001, + ema_config: { + use_ema: false, + ema_decay: 0.99, + }, + skip_first_sample: false, + disable_sampling: false, + dtype: 'bf16', + diff_output_preservation: false, + diff_output_preservation_multiplier: 1.0, + diff_output_preservation_class: 'person', + switch_boundary_every: 1, + }, + model: { + name_or_path: 'ostris/Flex.1-alpha', + quantize: true, + qtype: 'qfloat8', + quantize_te: true, + qtype_te: 'qfloat8', + arch: 'flex1', + low_vram: false, + model_kwargs: {}, + }, + sample: { + sampler: 'flowmatch', + sample_every: 250, + width: 1024, + height: 1024, + samples: [ + { + prompt: 'woman with red hair, playing chess at the park, bomb going off in the background' + }, + { + prompt: 'a woman holding a coffee cup, in a beanie, sitting at a cafe', + }, + { + prompt: 'a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini', + }, + { + prompt: 'a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background', + }, + { + prompt: 'a bear building a log cabin in the snow covered mountains', + }, + { + prompt: 'woman playing the guitar, on stage, singing a song, laser lights, punk rocker', + }, + { + prompt: 'hipster man with a beard, building a chair, in a wood shop', + }, + { + prompt: 'photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop', + }, + { + prompt: "a man holding a sign that says, 'this is a sign'", + }, + { + prompt: 'a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle', + }, + ], + neg: '', + seed: 42, + walk_seed: true, + guidance_scale: 4, + sample_steps: 25, + num_frames: 1, + fps: 1, + }, + }, + ], + }, + meta: { + name: '[name]', + version: '1.0', + }, +}; + +export const migrateJobConfig = (jobConfig: JobConfig): JobConfig => { + // upgrade prompt strings to samples + if ( + jobConfig?.config?.process && + jobConfig.config.process[0]?.sample && + Array.isArray(jobConfig.config.process[0].sample.prompts) && + jobConfig.config.process[0].sample.prompts.length > 0 + ) { + let newSamples = []; + for (const prompt of jobConfig.config.process[0].sample.prompts) { + newSamples.push({ + prompt: prompt, + }); + } + jobConfig.config.process[0].sample.samples = newSamples; + delete jobConfig.config.process[0].sample.prompts; + } + return jobConfig; +}; diff --git a/src/app/jobs/new/options.ts b/src/app/jobs/new/options.ts new file mode 100644 index 0000000000000000000000000000000000000000..71fdc9d8e767d2cbc078475d32b37e6996948199 --- /dev/null +++ b/src/app/jobs/new/options.ts @@ -0,0 +1,441 @@ +import { GroupedSelectOption, SelectOption } from '@/types'; + +type Control = 'depth' | 'line' | 'pose' | 'inpaint'; + +type DisableableSections = 'model.quantize' | 'train.timestep_type' | 'network.conv'; +type AdditionalSections = + | 'datasets.control_path' + | 'datasets.do_i2v' + | 'sample.ctrl_img' + | 'datasets.num_frames' + | 'model.multistage' + | 'model.low_vram'; +type ModelGroup = 'image' | 'instruction' | 'video'; + +export interface ModelArch { + name: string; + label: string; + group: ModelGroup; + controls?: Control[]; + isVideoModel?: boolean; + defaults?: { [key: string]: any }; + disableSections?: DisableableSections[]; + additionalSections?: AdditionalSections[]; + accuracyRecoveryAdapters?: { [key: string]: string }; +} + +const defaultNameOrPath = ''; + +export const modelArchs: ModelArch[] = [ + { + name: 'flux', + label: 'FLUX.1', + group: 'image', + defaults: { + // default updates when [selected, unselected] in the UI + 'config.process[0].model.name_or_path': ['black-forest-labs/FLUX.1-dev', defaultNameOrPath], + 'config.process[0].model.quantize': [true, false], + 'config.process[0].model.quantize_te': [true, false], + 'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'], + 'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'], + }, + disableSections: ['network.conv'], + }, + { + name: 'flux_kontext', + label: 'FLUX.1-Kontext-dev', + group: 'instruction', + defaults: { + // default updates when [selected, unselected] in the UI + 'config.process[0].model.name_or_path': ['black-forest-labs/FLUX.1-Kontext-dev', defaultNameOrPath], + 'config.process[0].model.quantize': [true, false], + 'config.process[0].model.quantize_te': [true, false], + 'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'], + 'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'], + 'config.process[0].train.timestep_type': ['weighted', 'sigmoid'], + }, + disableSections: ['network.conv'], + additionalSections: ['datasets.control_path', 'sample.ctrl_img'], + }, + { + name: 'flex1', + label: 'Flex.1', + group: 'image', + defaults: { + // default updates when [selected, unselected] in the UI + 'config.process[0].model.name_or_path': ['ostris/Flex.1-alpha', defaultNameOrPath], + 'config.process[0].model.quantize': [true, false], + 'config.process[0].model.quantize_te': [true, false], + 'config.process[0].train.bypass_guidance_embedding': [true, false], + 'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'], + 'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'], + }, + disableSections: ['network.conv'], + }, + { + name: 'flex2', + label: 'Flex.2', + group: 'image', + controls: ['depth', 'line', 'pose', 'inpaint'], + defaults: { + // default updates when [selected, unselected] in the UI + 'config.process[0].model.name_or_path': ['ostris/Flex.2-preview', defaultNameOrPath], + 'config.process[0].model.quantize': [true, false], + 'config.process[0].model.quantize_te': [true, false], + 'config.process[0].model.model_kwargs': [ + { + invert_inpaint_mask_chance: 0.2, + inpaint_dropout: 0.5, + control_dropout: 0.5, + inpaint_random_chance: 0.2, + do_random_inpainting: true, + random_blur_mask: true, + random_dialate_mask: true, + }, + {}, + ], + 'config.process[0].train.bypass_guidance_embedding': [true, false], + 'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'], + 'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'], + }, + disableSections: ['network.conv'], + }, + { + name: 'chroma', + label: 'Chroma', + group: 'image', + defaults: { + // default updates when [selected, unselected] in the UI + 'config.process[0].model.name_or_path': ['lodestones/Chroma1-Base', defaultNameOrPath], + 'config.process[0].model.quantize': [true, false], + 'config.process[0].model.quantize_te': [true, false], + 'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'], + 'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'], + }, + disableSections: ['network.conv'], + }, + { + name: 'wan21:1b', + label: 'Wan 2.1 (1.3B)', + group: 'video', + isVideoModel: true, + defaults: { + // default updates when [selected, unselected] in the UI + 'config.process[0].model.name_or_path': ['Wan-AI/Wan2.1-T2V-1.3B-Diffusers', defaultNameOrPath], + 'config.process[0].model.quantize': [false, false], + 'config.process[0].model.quantize_te': [true, false], + 'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'], + 'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'], + 'config.process[0].sample.num_frames': [41, 1], + 'config.process[0].sample.fps': [16, 1], + }, + disableSections: ['network.conv'], + additionalSections: ['datasets.num_frames', 'model.low_vram'], + }, + { + name: 'wan21_i2v:14b480p', + label: 'Wan 2.1 I2V (14B-480P)', + group: 'video', + isVideoModel: true, + defaults: { + // default updates when [selected, unselected] in the UI + 'config.process[0].model.name_or_path': ['Wan-AI/Wan2.1-I2V-14B-480P-Diffusers', defaultNameOrPath], + 'config.process[0].model.quantize': [true, false], + 'config.process[0].model.quantize_te': [true, false], + 'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'], + 'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'], + 'config.process[0].sample.num_frames': [41, 1], + 'config.process[0].sample.fps': [16, 1], + 'config.process[0].train.timestep_type': ['weighted', 'sigmoid'], + }, + disableSections: ['network.conv'], + additionalSections: ['sample.ctrl_img', 'datasets.num_frames', 'model.low_vram'], + }, + { + name: 'wan21_i2v:14b', + label: 'Wan 2.1 I2V (14B-720P)', + group: 'video', + isVideoModel: true, + defaults: { + // default updates when [selected, unselected] in the UI + 'config.process[0].model.name_or_path': ['Wan-AI/Wan2.1-I2V-14B-720P-Diffusers', defaultNameOrPath], + 'config.process[0].model.quantize': [true, false], + 'config.process[0].model.quantize_te': [true, false], + 'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'], + 'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'], + 'config.process[0].sample.num_frames': [41, 1], + 'config.process[0].sample.fps': [16, 1], + 'config.process[0].train.timestep_type': ['weighted', 'sigmoid'], + }, + disableSections: ['network.conv'], + additionalSections: ['sample.ctrl_img', 'datasets.num_frames', 'model.low_vram'], + }, + { + name: 'wan21:14b', + label: 'Wan 2.1 (14B)', + group: 'video', + isVideoModel: true, + defaults: { + // default updates when [selected, unselected] in the UI + 'config.process[0].model.name_or_path': ['Wan-AI/Wan2.1-T2V-14B-Diffusers', defaultNameOrPath], + 'config.process[0].model.quantize': [true, false], + 'config.process[0].model.quantize_te': [true, false], + 'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'], + 'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'], + 'config.process[0].sample.num_frames': [41, 1], + 'config.process[0].sample.fps': [16, 1], + }, + disableSections: ['network.conv'], + additionalSections: ['datasets.num_frames', 'model.low_vram'], + }, + { + name: 'wan22_14b:t2v', + label: 'Wan 2.2 (14B)', + group: 'video', + isVideoModel: true, + defaults: { + // default updates when [selected, unselected] in the UI + 'config.process[0].model.name_or_path': ['ai-toolkit/Wan2.2-T2V-A14B-Diffusers-bf16', defaultNameOrPath], + 'config.process[0].model.quantize': [true, false], + 'config.process[0].model.quantize_te': [true, false], + 'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'], + 'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'], + 'config.process[0].sample.num_frames': [41, 1], + 'config.process[0].sample.fps': [16, 1], + 'config.process[0].model.low_vram': [true, false], + 'config.process[0].train.timestep_type': ['linear', 'sigmoid'], + 'config.process[0].model.model_kwargs': [ + { + train_high_noise: true, + train_low_noise: true, + }, + {}, + ], + }, + disableSections: ['network.conv'], + additionalSections: ['datasets.num_frames', 'model.low_vram', 'model.multistage'], + accuracyRecoveryAdapters: { + // '3 bit with ARA': 'uint3|ostris/accuracy_recovery_adapters/wan22_14b_t2i_torchao_uint3.safetensors', + '4 bit with ARA': 'uint4|ostris/accuracy_recovery_adapters/wan22_14b_t2i_torchao_uint4.safetensors', + }, + }, + { + name: 'wan22_14b_i2v', + label: 'Wan 2.2 I2V (14B)', + group: 'video', + isVideoModel: true, + defaults: { + // default updates when [selected, unselected] in the UI + 'config.process[0].model.name_or_path': ['ai-toolkit/Wan2.2-I2V-A14B-Diffusers-bf16', defaultNameOrPath], + 'config.process[0].model.quantize': [true, false], + 'config.process[0].model.quantize_te': [true, false], + 'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'], + 'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'], + 'config.process[0].sample.num_frames': [41, 1], + 'config.process[0].sample.fps': [16, 1], + 'config.process[0].model.low_vram': [true, false], + 'config.process[0].train.timestep_type': ['linear', 'sigmoid'], + 'config.process[0].model.model_kwargs': [ + { + train_high_noise: true, + train_low_noise: true, + }, + {}, + ], + }, + disableSections: ['network.conv'], + additionalSections: ['sample.ctrl_img', 'datasets.num_frames', 'model.low_vram', 'model.multistage'], + accuracyRecoveryAdapters: { + '4 bit with ARA': 'uint4|ostris/accuracy_recovery_adapters/wan22_14b_i2v_torchao_uint4.safetensors', + }, + }, + { + name: 'wan22_5b', + label: 'Wan 2.2 TI2V (5B)', + group: 'video', + isVideoModel: true, + defaults: { + // default updates when [selected, unselected] in the UI + 'config.process[0].model.name_or_path': ['Wan-AI/Wan2.2-TI2V-5B-Diffusers', defaultNameOrPath], + 'config.process[0].model.quantize': [true, false], + 'config.process[0].model.quantize_te': [true, false], + 'config.process[0].model.low_vram': [true, false], + 'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'], + 'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'], + 'config.process[0].sample.num_frames': [121, 1], + 'config.process[0].sample.fps': [24, 1], + 'config.process[0].sample.width': [768, 1024], + 'config.process[0].sample.height': [768, 1024], + 'config.process[0].train.timestep_type': ['weighted', 'sigmoid'], + }, + disableSections: ['network.conv'], + additionalSections: ['sample.ctrl_img', 'datasets.num_frames', 'model.low_vram', 'datasets.do_i2v'], + }, + { + name: 'lumina2', + label: 'Lumina2', + group: 'image', + defaults: { + // default updates when [selected, unselected] in the UI + 'config.process[0].model.name_or_path': ['Alpha-VLLM/Lumina-Image-2.0', defaultNameOrPath], + 'config.process[0].model.quantize': [false, false], + 'config.process[0].model.quantize_te': [true, false], + 'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'], + 'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'], + }, + disableSections: ['network.conv'], + }, + { + name: 'qwen_image', + label: 'Qwen-Image', + group: 'image', + defaults: { + // default updates when [selected, unselected] in the UI + 'config.process[0].model.name_or_path': ['Qwen/Qwen-Image', defaultNameOrPath], + 'config.process[0].model.quantize': [true, false], + 'config.process[0].model.quantize_te': [true, false], + 'config.process[0].model.low_vram': [true, false], + 'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'], + 'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'], + 'config.process[0].train.timestep_type': ['weighted', 'sigmoid'], + 'config.process[0].model.qtype': ['qfloat8', 'qfloat8'], + }, + disableSections: ['network.conv'], + additionalSections: ['model.low_vram'], + accuracyRecoveryAdapters: { + '3 bit with ARA': 'uint3|ostris/accuracy_recovery_adapters/qwen_image_torchao_uint3.safetensors', + }, + }, + { + name: 'qwen_image_edit', + label: 'Qwen-Image-Edit', + group: 'instruction', + defaults: { + // default updates when [selected, unselected] in the UI + 'config.process[0].model.name_or_path': ['Qwen/Qwen-Image-Edit', defaultNameOrPath], + 'config.process[0].model.quantize': [true, false], + 'config.process[0].model.quantize_te': [true, false], + 'config.process[0].model.low_vram': [true, false], + 'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'], + 'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'], + 'config.process[0].train.timestep_type': ['weighted', 'sigmoid'], + 'config.process[0].model.qtype': ['qfloat8', 'qfloat8'], + }, + disableSections: ['network.conv'], + additionalSections: ['datasets.control_path', 'sample.ctrl_img', 'model.low_vram'], + accuracyRecoveryAdapters: { + '3 bit with ARA': 'uint3|ostris/accuracy_recovery_adapters/qwen_image_edit_torchao_uint3.safetensors', + }, + }, + { + name: 'hidream', + label: 'HiDream', + group: 'image', + defaults: { + // default updates when [selected, unselected] in the UI + 'config.process[0].model.name_or_path': ['HiDream-ai/HiDream-I1-Full', defaultNameOrPath], + 'config.process[0].model.quantize': [true, false], + 'config.process[0].model.quantize_te': [true, false], + 'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'], + 'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'], + 'config.process[0].train.lr': [0.0002, 0.0001], + 'config.process[0].train.timestep_type': ['shift', 'sigmoid'], + 'config.process[0].network.network_kwargs.ignore_if_contains': [['ff_i.experts', 'ff_i.gate'], []], + }, + disableSections: ['network.conv'], + additionalSections: ['model.low_vram'], + }, + { + name: 'hidream_e1', + label: 'HiDream E1', + group: 'instruction', + defaults: { + // default updates when [selected, unselected] in the UI + 'config.process[0].model.name_or_path': ['HiDream-ai/HiDream-E1-1', defaultNameOrPath], + 'config.process[0].model.quantize': [true, false], + 'config.process[0].model.quantize_te': [true, false], + 'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'], + 'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'], + 'config.process[0].train.lr': [0.0001, 0.0001], + 'config.process[0].train.timestep_type': ['weighted', 'sigmoid'], + 'config.process[0].network.network_kwargs.ignore_if_contains': [['ff_i.experts', 'ff_i.gate'], []], + }, + disableSections: ['network.conv'], + additionalSections: ['datasets.control_path', 'sample.ctrl_img', 'model.low_vram'], + }, + { + name: 'sdxl', + label: 'SDXL', + group: 'image', + defaults: { + // default updates when [selected, unselected] in the UI + 'config.process[0].model.name_or_path': ['stabilityai/stable-diffusion-xl-base-1.0', defaultNameOrPath], + 'config.process[0].model.quantize': [false, false], + 'config.process[0].model.quantize_te': [false, false], + 'config.process[0].sample.sampler': ['ddpm', 'flowmatch'], + 'config.process[0].train.noise_scheduler': ['ddpm', 'flowmatch'], + 'config.process[0].sample.guidance_scale': [6, 4], + }, + disableSections: ['model.quantize', 'train.timestep_type'], + }, + { + name: 'sd15', + label: 'SD 1.5', + group: 'image', + defaults: { + // default updates when [selected, unselected] in the UI + 'config.process[0].model.name_or_path': ['stable-diffusion-v1-5/stable-diffusion-v1-5', defaultNameOrPath], + 'config.process[0].sample.sampler': ['ddpm', 'flowmatch'], + 'config.process[0].train.noise_scheduler': ['ddpm', 'flowmatch'], + 'config.process[0].sample.width': [512, 1024], + 'config.process[0].sample.height': [512, 1024], + 'config.process[0].sample.guidance_scale': [6, 4], + }, + disableSections: ['model.quantize', 'train.timestep_type'], + }, + { + name: 'omnigen2', + label: 'OmniGen2', + group: 'image', + defaults: { + // default updates when [selected, unselected] in the UI + 'config.process[0].model.name_or_path': ['OmniGen2/OmniGen2', defaultNameOrPath], + 'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'], + 'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'], + 'config.process[0].model.quantize': [false, false], + 'config.process[0].model.quantize_te': [true, false], + }, + disableSections: ['network.conv'], + additionalSections: ['datasets.control_path', 'sample.ctrl_img'], + }, +].sort((a, b) => { + // Sort by label, case-insensitive + return a.label.localeCompare(b.label, undefined, { sensitivity: 'base' }); +}) as any; + +export const groupedModelOptions: GroupedSelectOption[] = modelArchs.reduce((acc, arch) => { + const group = acc.find(g => g.label === arch.group); + if (group) { + group.options.push({ value: arch.name, label: arch.label }); + } else { + acc.push({ + label: arch.group, + options: [{ value: arch.name, label: arch.label }], + }); + } + return acc; +}, [] as GroupedSelectOption[]); + +export const quantizationOptions: SelectOption[] = [ + { value: '', label: '- NONE -' }, + { value: 'qfloat8', label: 'float8 (default)' }, + { value: 'uint8', label: '8 bit' }, + { value: 'uint7', label: '7 bit' }, + { value: 'uint6', label: '6 bit' }, + { value: 'uint5', label: '5 bit' }, + { value: 'uint4', label: '4 bit' }, + { value: 'uint3', label: '3 bit' }, + { value: 'uint2', label: '2 bit' }, +]; + +export const defaultQtype = 'qfloat8'; diff --git a/src/app/jobs/new/page.tsx b/src/app/jobs/new/page.tsx new file mode 100644 index 0000000000000000000000000000000000000000..1da413490f5703ceb577dd5bb29502a9e3970045 --- /dev/null +++ b/src/app/jobs/new/page.tsx @@ -0,0 +1,306 @@ +'use client'; + +import { useEffect, useState } from 'react'; +import { useSearchParams, useRouter } from 'next/navigation'; +import Link from 'next/link'; +import { defaultJobConfig, defaultDatasetConfig, migrateJobConfig } from './jobConfig'; +import { JobConfig } from '@/types'; +import { objectCopy } from '@/utils/basic'; +import { useNestedState } from '@/utils/hooks'; +import { SelectInput } from '@/components/formInputs'; +import useSettings from '@/hooks/useSettings'; +import useGPUInfo from '@/hooks/useGPUInfo'; +import useDatasetList from '@/hooks/useDatasetList'; +import path from 'path'; +import { TopBar, MainContent } from '@/components/layout'; +import { Button } from '@headlessui/react'; +import { FaChevronLeft } from 'react-icons/fa'; +import SimpleJob from './SimpleJob'; +import AdvancedJob from './AdvancedJob'; +import ErrorBoundary from '@/components/ErrorBoundary'; +import { getJob, upsertJob } from '@/utils/storage/jobStorage'; +import { usingBrowserDb } from '@/utils/env'; +import { getUserDatasetPath, updateUserDatasetPath } from '@/utils/storage/datasetStorage'; +import { apiClient } from '@/utils/api'; +import { useAuth } from '@/contexts/AuthContext'; +import HFLoginButton from '@/components/HFLoginButton'; + +const isDev = process.env.NODE_ENV === 'development'; + +export default function TrainingForm() { + const router = useRouter(); + const searchParams = useSearchParams(); + const runId = searchParams.get('id'); + const { status: authStatus } = useAuth(); + const isAuthenticated = authStatus === 'authenticated'; + const [gpuIDs, setGpuIDs] = useState(null); + const { settings, isSettingsLoaded } = useSettings(); + const { gpuList, isGPUInfoLoaded } = useGPUInfo(); + const { datasets, status: datasetFetchStatus } = useDatasetList(); + const [datasetOptions, setDatasetOptions] = useState<{ value: string; label: string }[]>([]); + const [showAdvancedView, setShowAdvancedView] = useState(false); + + const [jobConfig, setJobConfig] = useNestedState(objectCopy(defaultJobConfig)); + const [status, setStatus] = useState<'idle' | 'saving' | 'success' | 'error'>('idle'); + + // Track HF Jobs backend state + const [trainingBackend, setTrainingBackend] = useState<'local' | 'hf-jobs'>( + usingBrowserDb ? 'hf-jobs' : 'local', + ); + const [hfJobSubmitted, setHfJobSubmitted] = useState(false); + + useEffect(() => { + if (!isSettingsLoaded || !isAuthenticated) return; + if (datasetFetchStatus !== 'success') return; + + let isMounted = true; + + const buildDatasetOptions = async () => { + const options = await Promise.all( + datasets.map(async name => { + let datasetPath = settings.DATASETS_FOLDER ? path.join(settings.DATASETS_FOLDER, name) : ''; + + if (usingBrowserDb) { + const storedPath = getUserDatasetPath(name); + if (storedPath) { + datasetPath = storedPath; + } else { + try { + const response = await apiClient + .post('/api/datasets/create', { name }) + .then(res => res.data); + if (response?.path) { + datasetPath = response.path; + updateUserDatasetPath(name, datasetPath); + } + } catch (err) { + console.error('Error resolving dataset path:', err); + } + } + } + + if (!datasetPath) { + datasetPath = name; + } + + return { value: datasetPath, label: name }; + }), + ); + + if (!isMounted) { + return; + } + + setDatasetOptions(options); + const defaultDatasetPath = defaultDatasetConfig.folder_path; + + for (let i = 0; i < jobConfig.config.process[0].datasets.length; i++) { + const dataset = jobConfig.config.process[0].datasets[i]; + if (dataset.folder_path === defaultDatasetPath) { + if (options.length > 0) { + setJobConfig(options[0].value, `config.process[0].datasets[${i}].folder_path`); + } + } + } + }; + + buildDatasetOptions(); + + return () => { + isMounted = false; + }; + }, [datasets, settings, isSettingsLoaded, datasetFetchStatus]); + + useEffect(() => { + if (runId) { + getJob(runId) + .then(data => { + if (!data) { + throw new Error('Job not found'); + } + setGpuIDs(data.gpu_ids); + const parsedJobConfig = migrateJobConfig(JSON.parse(data.job_config)); + setJobConfig(parsedJobConfig); + + if (parsedJobConfig.is_hf_job) { + setTrainingBackend('hf-jobs'); + setHfJobSubmitted(true); + } + }) + .catch(error => console.error('Error fetching training:', error)); + } + }, [runId]); + + useEffect(() => { + if (isGPUInfoLoaded) { + if (gpuIDs === null && gpuList.length > 0) { + setGpuIDs(`${gpuList[0].index}`); + } + } + }, [gpuList, isGPUInfoLoaded]); + + useEffect(() => { + if (isSettingsLoaded) { + setJobConfig(settings.TRAINING_FOLDER, 'config.process[0].training_folder'); + } + }, [settings, isSettingsLoaded]); + + const saveJob = async () => { + if (!isAuthenticated) return; + if (status === 'saving') return; + setStatus('saving'); + + try { + const savedJob = await upsertJob({ + id: runId || undefined, + name: jobConfig.config.name, + gpu_ids: gpuIDs, + job_config: { + ...jobConfig, + is_hf_job: trainingBackend === 'hf-jobs', + hf_job_submitted: hfJobSubmitted, + training_backend: trainingBackend, + }, + status: trainingBackend === 'hf-jobs' ? (hfJobSubmitted ? 'submitted' : 'stopped') : undefined, + }); + + setStatus('success'); + router.push(`/jobs/${savedJob.id}`); + } catch (error: any) { + console.log('Error saving training:', error); + if (error?.code === 'P2002') { + alert('Training name already exists. Please choose a different name.'); + } else { + alert('Failed to save job. Please try again.'); + } + } finally { + setTimeout(() => { + setStatus('idle'); + }, 2000); + } + }; + + const handleSubmit = async (e: React.FormEvent) => { + e.preventDefault(); + saveJob(); + }; + + return ( + <> + +
+ +
+
+

{runId ? 'Edit Training Job' : 'New Training Job'}

+
+
+ {showAdvancedView && isAuthenticated && ( + <> +
+ setGpuIDs(value)} + options={gpuList.map((gpu: any) => ({ value: `${gpu.index}`, label: `GPU #${gpu.index}` }))} + /> +
+
+ + )} + +
+ +
+
+ +
+
+ + {!isAuthenticated ? ( + +
+

You need to sign in with Hugging Face or provide a valid access token before creating or editing jobs.

+
+ + + Manage authentication in Settings + +
+
+
+ ) : showAdvancedView ? ( +
+ +
+ ) : ( + + + Advanced job detected. Please switch to advanced view to continue. +
+ } + > + { + setHfJobSubmitted(true); + // Redirect to the job detail page + if (localJobId) { + router.push(`/jobs/${localJobId}`); + } + }} + forceHFBackend={usingBrowserDb} + /> + + +
+ + )} + + ); +} + useEffect(() => { + if (!isAuthenticated) { + setDatasetOptions([]); + } + }, [isAuthenticated]); diff --git a/src/app/jobs/page.tsx b/src/app/jobs/page.tsx new file mode 100644 index 0000000000000000000000000000000000000000..a29dd77c3a5463069f683f20f903add3b343fe40 --- /dev/null +++ b/src/app/jobs/page.tsx @@ -0,0 +1,49 @@ +'use client'; + +import JobsTable from '@/components/JobsTable'; +import { TopBar, MainContent } from '@/components/layout'; +import Link from 'next/link'; +import { useAuth } from '@/contexts/AuthContext'; +import HFLoginButton from '@/components/HFLoginButton'; + +export default function Dashboard() { + const { status: authStatus } = useAuth(); + const isAuthenticated = authStatus === 'authenticated'; + + return ( + <> + +
+

Training Jobs

+
+
+
+ {isAuthenticated ? ( + + New Training Job + + ) : ( + + Sign in to create jobs + + )} +
+
+ + {isAuthenticated ? ( + + ) : ( +
+

Sign in with Hugging Face or add a personal access token to view and manage training jobs.

+
+ + + Manage tokens in Settings + +
+
+ )} +
+ + ); +} diff --git a/src/app/layout.tsx b/src/app/layout.tsx new file mode 100644 index 0000000000000000000000000000000000000000..b3ce381e88faf2bbd71b8cb67f61662d6bead943 --- /dev/null +++ b/src/app/layout.tsx @@ -0,0 +1,50 @@ +import type { Metadata } from 'next'; +import { Inter } from 'next/font/google'; +import './globals.css'; +import Sidebar from '@/components/Sidebar'; +import { ThemeProvider } from '@/components/ThemeProvider'; +import ConfirmModal from '@/components/ConfirmModal'; +import SampleImageModal from '@/components/SampleImageModal'; +import { Suspense } from 'react'; +import AuthWrapper from '@/components/AuthWrapper'; +import DocModal from '@/components/DocModal'; +import { AuthProvider } from '@/contexts/AuthContext'; + +export const dynamic = 'force-dynamic'; + +const inter = Inter({ subsets: ['latin'] }); + +export const metadata: Metadata = { + title: 'Ostris - AI Toolkit', + description: 'A toolkit for building AI things.', +}; + +export default function RootLayout({ children }: { children: React.ReactNode }) { + // Check if the AI_TOOLKIT_AUTH environment variable is set + const authRequired = process.env.AI_TOOLKIT_AUTH ? true : false; + + return ( + + + + + + + + +
+ +
+ {children} +
+
+
+
+
+ + + + + + ); +} diff --git a/src/app/manifest.json b/src/app/manifest.json new file mode 100644 index 0000000000000000000000000000000000000000..ced3ca5d79e5ec230be33c6f0e0907fb419c5588 --- /dev/null +++ b/src/app/manifest.json @@ -0,0 +1,21 @@ +{ + "name": "AI Toolkit", + "short_name": "AIToolkit", + "icons": [ + { + "src": "/web-app-manifest-192x192.png", + "sizes": "192x192", + "type": "image/png", + "purpose": "maskable" + }, + { + "src": "/web-app-manifest-512x512.png", + "sizes": "512x512", + "type": "image/png", + "purpose": "maskable" + } + ], + "theme_color": "#000000", + "background_color": "#000000", + "display": "standalone" +} \ No newline at end of file diff --git a/src/app/page.tsx b/src/app/page.tsx new file mode 100644 index 0000000000000000000000000000000000000000..f889cb6122cb25d33bccef074ef51a6b26b692c9 --- /dev/null +++ b/src/app/page.tsx @@ -0,0 +1,5 @@ +import { redirect } from 'next/navigation'; + +export default function Home() { + redirect('/dashboard'); +} diff --git a/src/app/settings/page.tsx b/src/app/settings/page.tsx new file mode 100644 index 0000000000000000000000000000000000000000..25fc6bd922360cda4452a119810529a9c132b695 --- /dev/null +++ b/src/app/settings/page.tsx @@ -0,0 +1,264 @@ +'use client'; + +import { useEffect, useState } from 'react'; +import useSettings from '@/hooks/useSettings'; +import { TopBar, MainContent } from '@/components/layout'; +import { persistSettings } from '@/utils/storage/settingsStorage'; +import { useAuth } from '@/contexts/AuthContext'; +import HFLoginButton from '@/components/HFLoginButton'; +import { useMemo } from 'react'; +import Link from 'next/link'; + +export default function Settings() { + const { settings, setSettings } = useSettings(); + const { status: authStatus, namespace, oauthAvailable, loginWithOAuth, logout, setManualToken, error: authError, token: authToken } = useAuth(); + const [status, setStatus] = useState<'idle' | 'saving' | 'success' | 'error'>('idle'); + const [manualToken, setManualTokenInput] = useState(settings.HF_TOKEN || ''); + const isAuthenticated = authStatus === 'authenticated'; + + useEffect(() => { + setManualTokenInput(settings.HF_TOKEN || ''); + }, [settings.HF_TOKEN]); + + const handleSubmit = async (e: React.FormEvent) => { + e.preventDefault(); + setStatus('saving'); + + persistSettings(settings) + .then(() => { + setStatus('success'); + }) + .catch(error => { + console.error('Error saving settings:', error); + setStatus('error'); + }) + .finally(() => { + setTimeout(() => setStatus('idle'), 2000); + }); + }; + + const handleChange = (e: React.ChangeEvent) => { + const { name, value } = e.target; + setSettings(prev => ({ ...prev, [name]: value })); + }; + + const handleManualSubmit = async (e: React.FormEvent) => { + e.preventDefault(); + await setManualToken(manualToken); + }; + + const authDescription = useMemo(() => { + if (authStatus === 'checking') { + return 'Checking your Hugging Face session…'; + } + if (isAuthenticated) { + return `Connected as ${namespace}`; + } + return 'Sign in to use Hugging Face Jobs or submit your own access token.'; + }, [authStatus, isAuthenticated, namespace]); + + return ( + <> + +
+

Settings

+
+
+
+ {isAuthenticated ? ( + Welcome, {namespace || 'user'} + ) : ( + Authenticate to unlock training features + )} +
+
+ +
+
+
+
+

Sign in with Hugging Face

+

{authDescription}

+
+ {isAuthenticated && ( + Authenticated + )} +
+
+ {isAuthenticated ? ( + + ) : ( + <> + + {!oauthAvailable && ( + + OAuth is unavailable. Set HF_OAUTH_CLIENT_ID/SECRET on the server. + + )} + + )} +
+ {!isAuthenticated && authError && ( +

{authError}

+ )} +
+ +
+

Manual Token

+

+ Paste an access token created at{' '} + + huggingface.co/settings/tokens + + . +

+
+ setManualTokenInput(event.target.value)} + className="w-full px-4 py-2 bg-gray-800 border border-gray-700 rounded-lg focus:ring-2 focus:ring-gray-600 focus:border-transparent" + placeholder="Enter Hugging Face token" + /> +
+
+ + {isAuthenticated && authToken === manualToken && ( + Active token + )} +
+ {authError && ( +

{authError}

+ )} +
+
+ +
+
+
+
+
+ + +
+ +
+ + +
+
+
+
+
+

Hugging Face Jobs (Cloud Training)

+ +
+ + +
+ +
+ + +
+
+
+
+ + + + {status === 'success' &&

Settings saved successfully!

} + {status === 'error' &&

Error saving settings. Please try again.

} +
+
+ + ); +} diff --git a/src/components/AddImagesModal.tsx b/src/components/AddImagesModal.tsx new file mode 100644 index 0000000000000000000000000000000000000000..ff91a8836dcfe7dc67a9ee237d7d5a1b16941cf2 --- /dev/null +++ b/src/components/AddImagesModal.tsx @@ -0,0 +1,152 @@ +'use client'; +import { createGlobalState } from 'react-global-hooks'; +import { Dialog, DialogBackdrop, DialogPanel, DialogTitle } from '@headlessui/react'; +import { FaUpload } from 'react-icons/fa'; +import { useCallback, useState } from 'react'; +import { useDropzone } from 'react-dropzone'; +import { apiClient } from '@/utils/api'; + +export interface AddImagesModalState { + datasetName: string; + onComplete?: () => void; +} + +export const addImagesModalState = createGlobalState(null); + +export const openImagesModal = (datasetName: string, onComplete: () => void) => { + addImagesModalState.set({ datasetName, onComplete }); +}; + +export default function AddImagesModal() { + const [addImagesModalInfo, setAddImagesModalInfo] = addImagesModalState.use(); + const [uploadProgress, setUploadProgress] = useState(0); + const [isUploading, setIsUploading] = useState(false); + const open = addImagesModalInfo !== null; + + const onCancel = () => { + if (!isUploading) { + setAddImagesModalInfo(null); + } + }; + + const onDone = () => { + if (addImagesModalInfo?.onComplete && !isUploading) { + addImagesModalInfo.onComplete(); + setAddImagesModalInfo(null); + } + }; + + const onDrop = useCallback( + async (acceptedFiles: File[]) => { + if (acceptedFiles.length === 0) return; + + setIsUploading(true); + setUploadProgress(0); + + const formData = new FormData(); + acceptedFiles.forEach(file => { + formData.append('files', file); + }); + formData.append('datasetName', addImagesModalInfo?.datasetName || ''); + + try { + await apiClient.post(`/api/datasets/upload`, formData, { + headers: { + 'Content-Type': 'multipart/form-data', + }, + onUploadProgress: progressEvent => { + const percentCompleted = Math.round((progressEvent.loaded * 100) / (progressEvent.total || 100)); + setUploadProgress(percentCompleted); + }, + timeout: 0, // Disable timeout + }); + + onDone(); + } catch (error) { + console.error('Upload failed:', error); + } finally { + setIsUploading(false); + setUploadProgress(0); + } + }, + [addImagesModalInfo], + ); + + const { getRootProps, getInputProps, isDragActive } = useDropzone({ + onDrop, + accept: { + 'image/*': ['.png', '.jpg', '.jpeg', '.gif', '.bmp', '.webp'], + 'video/*': ['.mp4', '.avi', '.mov', '.mkv', '.wmv', '.m4v', '.flv'], + 'text/*': ['.txt'], + }, + multiple: true, + }); + + return ( + + + +
+
+ +
+
+ + Add Images to: {addImagesModalInfo?.datasetName} + +
+
+ + +

+ {isDragActive ? 'Drop the files here...' : 'Drag & drop files here, or click to select files'} +

+
+ {isUploading && ( +
+
+
+
+

Uploading... {uploadProgress}%

+
+ )} +
+
+
+
+ + +
+
+
+
+
+ ); +} diff --git a/src/components/AddSingleImageModal.tsx b/src/components/AddSingleImageModal.tsx new file mode 100644 index 0000000000000000000000000000000000000000..ba32ef9dff916b5f6e605909f5d328dfce49783a --- /dev/null +++ b/src/components/AddSingleImageModal.tsx @@ -0,0 +1,141 @@ +'use client'; +import { createGlobalState } from 'react-global-hooks'; +import { Dialog, DialogBackdrop, DialogPanel, DialogTitle } from '@headlessui/react'; +import { FaUpload } from 'react-icons/fa'; +import { useCallback, useState } from 'react'; +import { useDropzone } from 'react-dropzone'; +import { apiClient } from '@/utils/api'; + +export interface AddSingleImageModalState { + + onComplete?: (imagePath: string|null) => void; +} + +export const addSingleImageModalState = createGlobalState(null); + +export const openAddImageModal = (onComplete: (imagePath: string|null) => void) => { + addSingleImageModalState.set({onComplete }); +}; + +export default function AddSingleImageModal() { + const [addSingleImageModalInfo, setAddSingleImageModalInfo] = addSingleImageModalState.use(); + const [uploadProgress, setUploadProgress] = useState(0); + const [isUploading, setIsUploading] = useState(false); + const open = addSingleImageModalInfo !== null; + + const onCancel = () => { + if (!isUploading) { + setAddSingleImageModalInfo(null); + } + }; + + const onDone = (imagePath: string|null) => { + if (addSingleImageModalInfo?.onComplete && !isUploading) { + addSingleImageModalInfo.onComplete(imagePath); + setAddSingleImageModalInfo(null); + } + }; + + const onDrop = useCallback( + async (acceptedFiles: File[]) => { + if (acceptedFiles.length === 0) return; + + setIsUploading(true); + setUploadProgress(0); + + const formData = new FormData(); + acceptedFiles.forEach(file => { + formData.append('files', file); + }); + + try { + const resp = await apiClient.post(`/api/img/upload`, formData, { + headers: { + 'Content-Type': 'multipart/form-data', + }, + onUploadProgress: progressEvent => { + const percentCompleted = Math.round((progressEvent.loaded * 100) / (progressEvent.total || 100)); + setUploadProgress(percentCompleted); + }, + timeout: 0, // Disable timeout + }); + console.log('Upload successful:', resp.data); + + onDone(resp.data.files[0] || null); + } catch (error) { + console.error('Upload failed:', error); + } finally { + setIsUploading(false); + setUploadProgress(0); + } + }, + [addSingleImageModalInfo], + ); + + const { getRootProps, getInputProps, isDragActive } = useDropzone({ + onDrop, + accept: { + 'image/*': ['.png', '.jpg', '.jpeg', '.gif', '.bmp', '.webp'], + }, + multiple: false, + }); + + return ( + + + +
+
+ +
+
+ + Add Control Image + +
+
+ + +

+ {isDragActive ? 'Drop the image here...' : 'Drag & drop an image here, or click to select one'} +

+
+ {isUploading && ( +
+
+
+
+

Uploading... {uploadProgress}%

+
+ )} +
+
+
+
+ +
+
+
+
+
+ ); +} diff --git a/src/components/AuthWrapper.tsx b/src/components/AuthWrapper.tsx new file mode 100644 index 0000000000000000000000000000000000000000..bdf287a8dca4aa022b852680a13c8c3b0bb33926 --- /dev/null +++ b/src/components/AuthWrapper.tsx @@ -0,0 +1,166 @@ +'use client'; + +import { useState, useEffect, useRef } from 'react'; +import { apiClient, isAuthorizedState } from '@/utils/api'; +import { createGlobalState } from 'react-global-hooks'; + +interface AuthWrapperProps { + authRequired: boolean; + children: React.ReactNode | React.ReactNode[]; +} + +export default function AuthWrapper({ authRequired, children }: AuthWrapperProps) { + const [token, setToken] = useState(''); + // start with true, and deauth if needed + const [isAuthorizedGlobal, setIsAuthorized] = isAuthorizedState.use(); + const [isLoading, setIsLoading] = useState(false); + const [error, setError] = useState(''); + const [isBrowser, setIsBrowser] = useState(false); + const inputRef = useRef(null); + + const isAuthorized = authRequired ? isAuthorizedGlobal : true; + + // Set isBrowser to true when component mounts + useEffect(() => { + setIsBrowser(true); + // Get token from localStorage only after component has mounted + const storedToken = localStorage.getItem('AI_TOOLKIT_AUTH') || ''; + setToken(storedToken); + checkAuth(); + }, []); + + // auto focus on input when not authorized + useEffect(() => { + if (isAuthorized) { + return; + } + setTimeout(() => { + if (inputRef.current) { + inputRef.current.focus(); + } + }, 100); + }, [isAuthorized]); + + const checkAuth = async () => { + // always get current stored token here to avoid state race conditions + const currentToken = localStorage.getItem('AI_TOOLKIT_AUTH') || ''; + if (!authRequired || isLoading || currentToken === '') { + return; + } + setIsLoading(true); + setError(''); + try { + const response = await apiClient.get('/api/auth'); + if (response.data.isAuthenticated) { + setIsAuthorized(true); + } else { + setIsAuthorized(false); + setError('Invalid token. Please try again.'); + } + } catch (err) { + setIsAuthorized(false); + console.log(err); + setError('Invalid token. Please try again.'); + } + setIsLoading(false); + }; + + const handleSubmit = async (e: React.FormEvent) => { + e.preventDefault(); + setError(''); + + if (!token.trim()) { + setError('Please enter your token'); + return; + } + + if (isBrowser) { + localStorage.setItem('AI_TOOLKIT_AUTH', token); + checkAuth(); + } + }; + + if (isAuthorized) { + return <>{children}; + } + + return ( +
+ {/* Left side - decorative or brand area */} +
+
+ {/* Replace with your own logo */} +
+ Ostris AI Toolkit +
+
+

AI Toolkit

+
+ + {/* Right side - login form */} +
+
+
+ {/* Mobile logo */} +
+ Ostris AI Toolkit +
+
+ +

AI Toolkit

+ +
+
+ + setToken(e.target.value)} + className="w-full px-4 py-3 rounded-lg bg-gray-800 border border-gray-700 focus:border-blue-500 focus:ring-2 focus:ring-blue-500 focus:ring-opacity-50 text-gray-100 transition duration-200" + placeholder="Enter your password" + /> +
+ The password is set with the environment variable AI_TOOLKIT_AUTH, the default is the super secure secret word "password" +
+
+ + {error && ( +
{error}
+ )} + + +
+
+
+
+ ); +} diff --git a/src/components/Card.tsx b/src/components/Card.tsx new file mode 100644 index 0000000000000000000000000000000000000000..13c7409b8be089a104eb6613664a188cb35d78d7 --- /dev/null +++ b/src/components/Card.tsx @@ -0,0 +1,15 @@ +interface CardProps { + title?: string; + children?: React.ReactNode; +} + +const Card: React.FC = ({ title, children }) => { + return ( +
+ {title &&

{title}

} + {children ? children : null} +
+ ); +}; + +export default Card; diff --git a/src/components/ConfirmModal.tsx b/src/components/ConfirmModal.tsx new file mode 100644 index 0000000000000000000000000000000000000000..6ecea8136accffeb9f312afb0130d2988ef485d3 --- /dev/null +++ b/src/components/ConfirmModal.tsx @@ -0,0 +1,201 @@ +'use client'; +import { useRef } from 'react'; +import { useState, useEffect } from 'react'; +import { createGlobalState } from 'react-global-hooks'; +import { Dialog, DialogBackdrop, DialogPanel, DialogTitle } from '@headlessui/react'; +import { FaExclamationTriangle, FaInfo } from 'react-icons/fa'; +import { TextInput } from './formInputs'; +import React from 'react'; +import { useFromNull } from '@/hooks/useFromNull'; +import classNames from 'classnames'; + +export interface ConfirmState { + title: string; + message?: string; + confirmText?: string; + type?: 'danger' | 'warning' | 'info'; + inputTitle?: string; + onConfirm?: (value?: string) => void | Promise; + onCancel?: () => void; +} + +export const confirmstate = createGlobalState(null); + +export const openConfirm = (confirmProps: ConfirmState) => { + confirmstate.set(confirmProps); +}; + +export default function ConfirmModal() { + const [confirm, setConfirm] = confirmstate.use(); + const [isOpen, setIsOpen] = useState(false); + const [inputValue, setInputValue] = useState(''); + const inputRef = useRef(null); + + useFromNull(() => { + setTimeout(() => { + if (inputRef.current) { + inputRef.current.focus(); + } + }, 100); + }, [confirm]); + + useEffect(() => { + if (confirm) { + setIsOpen(true); + setInputValue(''); + } + }, [confirm]); + + useEffect(() => { + if (!isOpen) { + // use timeout to allow the dialog to close before resetting the state + setTimeout(() => { + setConfirm(null); + }, 500); + } + }, [isOpen]); + + const onCancel = () => { + if (confirm?.onCancel) { + confirm.onCancel(); + } + setIsOpen(false); + }; + + const onConfirm = () => { + if (confirm?.onConfirm) { + confirm.onConfirm(inputValue); + } + setIsOpen(false); + }; + + let Icon = FaExclamationTriangle; + let color = confirm?.type || 'danger'; + + // Use conditional rendering for icon + if (color === 'info') { + Icon = FaInfo; + } + + // Color mapping for background colors + const getBgColor = () => { + switch (color) { + case 'danger': + return 'bg-red-500'; + case 'warning': + return 'bg-yellow-500'; + case 'info': + return 'bg-blue-500'; + default: + return 'bg-red-500'; + } + }; + + // Color mapping for text colors + const getTextColor = () => { + switch (color) { + case 'danger': + return 'text-red-950'; + case 'warning': + return 'text-yellow-950'; + case 'info': + return 'text-blue-950'; + default: + return 'text-red-950'; + } + }; + + // Color mapping for titles + const getTitleColor = () => { + switch (color) { + case 'danger': + return 'text-red-500'; + case 'warning': + return 'text-yellow-500'; + case 'info': + return 'text-blue-500'; + default: + return 'text-red-500'; + } + }; + + // Button background color mapping + const getButtonBgColor = () => { + switch (color) { + case 'danger': + return 'bg-red-700 hover:bg-red-500'; + case 'warning': + return 'bg-yellow-700 hover:bg-yellow-500'; + case 'info': + return 'bg-blue-700 hover:bg-blue-500'; + default: + return 'bg-red-700 hover:bg-red-500'; + } + }; + + return ( + + + +
+
+ +
+
+
+
+
+ + {confirm?.title} + +
+

{confirm?.message}

+
+
{ + e.preventDefault() + onConfirm() + }}> + + +
+
+
+
+
+
+ + +
+
+
+
+
+ ); +} diff --git a/src/components/DatasetImageCard.tsx b/src/components/DatasetImageCard.tsx new file mode 100644 index 0000000000000000000000000000000000000000..7eb562b5cd6edb7906f7e9e55507223ac5141878 --- /dev/null +++ b/src/components/DatasetImageCard.tsx @@ -0,0 +1,231 @@ +import React, { useRef, useEffect, useState, ReactNode, KeyboardEvent } from 'react'; +import { FaTrashAlt, FaEye, FaEyeSlash } from 'react-icons/fa'; +import { openConfirm } from './ConfirmModal'; +import classNames from 'classnames'; +import { apiClient } from '@/utils/api'; +import { isVideo } from '@/utils/basic'; + +interface DatasetImageCardProps { + imageUrl: string; + alt: string; + children?: ReactNode; + className?: string; + onDelete?: () => void; +} + +const DatasetImageCard: React.FC = ({ + imageUrl, + alt, + children, + className = '', + onDelete = () => {}, +}) => { + const cardRef = useRef(null); + const [isVisible, setIsVisible] = useState(false); + const [inViewport, setInViewport] = useState(false); + const [loaded, setLoaded] = useState(false); + const [isCaptionLoaded, setIsCaptionLoaded] = useState(false); + const [caption, setCaption] = useState(''); + const [savedCaption, setSavedCaption] = useState(''); + const isGettingCaption = useRef(false); + + const fetchCaption = async () => { + if (isGettingCaption.current || isCaptionLoaded) return; + isGettingCaption.current = true; + apiClient + .post(`/api/caption/get`, { imgPath: imageUrl }) + .then(res => res.data) + .then(data => { + console.log('Caption fetched:', data); + + setCaption(data || ''); + setSavedCaption(data || ''); + setIsCaptionLoaded(true); + }) + .catch(error => { + console.error('Error fetching caption:', error); + }) + .finally(() => { + isGettingCaption.current = false; + }); + }; + + const saveCaption = () => { + const trimmedCaption = caption.trim(); + if (trimmedCaption === savedCaption) return; + apiClient + .post('/api/img/caption', { imgPath: imageUrl, caption: trimmedCaption }) + .then(res => res.data) + .then(data => { + console.log('Caption saved:', data); + setSavedCaption(trimmedCaption); + }) + .catch(error => { + console.error('Error saving caption:', error); + }); + }; + + // Only fetch caption when the component is both in viewport and visible + useEffect(() => { + if (inViewport && isVisible) { + fetchCaption(); + } + }, [inViewport, isVisible]); + + useEffect(() => { + // Create intersection observer to check viewport visibility + const observer = new IntersectionObserver( + entries => { + if (entries[0].isIntersecting) { + setInViewport(true); + // Initialize isVisible to true when first coming into view + if (!isVisible) { + setIsVisible(true); + } + } else { + setInViewport(false); + } + }, + { threshold: 0.1 }, + ); + + if (cardRef.current) { + observer.observe(cardRef.current); + } + + return () => { + observer.disconnect(); + }; + }, []); + + const toggleVisibility = (): void => { + setIsVisible(prev => !prev); + if (!isVisible && !isCaptionLoaded) { + fetchCaption(); + } + }; + + const handleLoad = (): void => { + setLoaded(true); + }; + + const handleKeyDown = (e: KeyboardEvent): void => { + // If Enter is pressed without Shift, prevent default behavior and save + if (e.key === 'Enter' && !e.shiftKey) { + e.preventDefault(); + saveCaption(); + } + }; + + const isCaptionCurrent = caption.trim() === savedCaption; + + const isItAVideo = isVideo(imageUrl); + + return ( +
+ {/* Square image container */} +
+
+ {inViewport && isVisible && ( + <> + {isItAVideo ? ( +
+ {inViewport && isVisible && ( +
+ {imageUrl} +
+ )} +
+
+ {inViewport && isVisible && isCaptionLoaded && ( +
{ + e.preventDefault(); + saveCaption(); + }} + onBlur={saveCaption} + > +