Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
Upload 99 files
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- src/app/api/auth/hf/callback/route.ts +112 -0
- src/app/api/auth/hf/login/route.ts +36 -0
- src/app/api/auth/hf/validate/route.ts +22 -0
- src/app/api/auth/route.ts +6 -0
- src/app/api/caption/get/route.ts +46 -0
- src/app/api/datasets/create/route.tsx +25 -0
- src/app/api/datasets/delete/route.tsx +24 -0
- src/app/api/datasets/list/route.ts +25 -0
- src/app/api/datasets/listImages/route.ts +61 -0
- src/app/api/datasets/upload/route.ts +57 -0
- src/app/api/files/[...filePath]/route.ts +116 -0
- src/app/api/gpu/route.ts +121 -0
- src/app/api/hf-hub/route.ts +165 -0
- src/app/api/hf-jobs/route.ts +761 -0
- src/app/api/img/[...imagePath]/route.ts +78 -0
- src/app/api/img/caption/route.ts +29 -0
- src/app/api/img/delete/route.ts +34 -0
- src/app/api/img/upload/route.ts +58 -0
- src/app/api/jobs/[jobID]/delete/route.ts +32 -0
- src/app/api/jobs/[jobID]/files/route.ts +48 -0
- src/app/api/jobs/[jobID]/log/route.ts +35 -0
- src/app/api/jobs/[jobID]/samples/route.ts +40 -0
- src/app/api/jobs/[jobID]/start/route.ts +215 -0
- src/app/api/jobs/[jobID]/stop/route.ts +23 -0
- src/app/api/jobs/route.ts +67 -0
- src/app/api/settings/route.ts +59 -0
- src/app/api/zip/route.ts +78 -0
- src/app/apple-icon.png +0 -0
- src/app/dashboard/page.tsx +85 -0
- src/app/datasets/[datasetName]/page.tsx +190 -0
- src/app/datasets/page.tsx +217 -0
- src/app/favicon.ico +0 -0
- src/app/globals.css +72 -0
- src/app/icon.png +0 -0
- src/app/icon.svg +0 -0
- src/app/jobs/[jobID]/page.tsx +147 -0
- src/app/jobs/new/AdvancedJob.tsx +146 -0
- src/app/jobs/new/SimpleJob.tsx +973 -0
- src/app/jobs/new/jobConfig.ts +167 -0
- src/app/jobs/new/options.ts +441 -0
- src/app/jobs/new/page.tsx +306 -0
- src/app/jobs/page.tsx +49 -0
- src/app/layout.tsx +50 -0
- src/app/manifest.json +21 -0
- src/app/page.tsx +5 -0
- src/app/settings/page.tsx +264 -0
- src/components/AddImagesModal.tsx +152 -0
- src/components/AddSingleImageModal.tsx +141 -0
- src/components/AuthWrapper.tsx +166 -0
- src/components/Card.tsx +15 -0
src/app/api/auth/hf/callback/route.ts
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import { NextRequest, NextResponse } from 'next/server';
|
| 2 |
+
import { cookies } from 'next/headers';
|
| 3 |
+
|
| 4 |
+
const TOKEN_ENDPOINT = 'https://huggingface.co/oauth/token';
|
| 5 |
+
const USERINFO_ENDPOINT = 'https://huggingface.co/oauth/userinfo';
|
| 6 |
+
const STATE_COOKIE = 'hf_oauth_state';
|
| 7 |
+
|
| 8 |
+
function htmlResponse(script: string) {
|
| 9 |
+
return new NextResponse(
|
| 10 |
+
`<!DOCTYPE html><html><body><script>${script}</script></body></html>`,
|
| 11 |
+
{
|
| 12 |
+
headers: { 'Content-Type': 'text/html; charset=utf-8' },
|
| 13 |
+
},
|
| 14 |
+
);
|
| 15 |
+
}
|
| 16 |
+
|
| 17 |
+
export async function GET(request: NextRequest) {
|
| 18 |
+
const clientId = process.env.HF_OAUTH_CLIENT_ID || process.env.NEXT_PUBLIC_HF_OAUTH_CLIENT_ID;
|
| 19 |
+
const clientSecret = process.env.HF_OAUTH_CLIENT_SECRET;
|
| 20 |
+
|
| 21 |
+
if (!clientId || !clientSecret) {
|
| 22 |
+
return NextResponse.json({ error: 'OAuth application is not configured' }, { status: 500 });
|
| 23 |
+
}
|
| 24 |
+
|
| 25 |
+
const { searchParams } = new URL(request.url);
|
| 26 |
+
const code = searchParams.get('code');
|
| 27 |
+
const incomingState = searchParams.get('state');
|
| 28 |
+
|
| 29 |
+
const cookieStore = cookies();
|
| 30 |
+
const storedState = cookieStore.get(STATE_COOKIE)?.value;
|
| 31 |
+
|
| 32 |
+
cookieStore.delete(STATE_COOKIE);
|
| 33 |
+
|
| 34 |
+
const origin = request.nextUrl.origin;
|
| 35 |
+
|
| 36 |
+
if (!code || !incomingState || !storedState || incomingState !== storedState) {
|
| 37 |
+
const script = `
|
| 38 |
+
window.opener && window.opener.postMessage({
|
| 39 |
+
type: 'HF_OAUTH_ERROR',
|
| 40 |
+
payload: { message: 'Invalid or expired OAuth state.' }
|
| 41 |
+
}, '${origin}');
|
| 42 |
+
window.close();
|
| 43 |
+
`;
|
| 44 |
+
return htmlResponse(script.trim());
|
| 45 |
+
}
|
| 46 |
+
|
| 47 |
+
const redirectUri = process.env.HF_OAUTH_REDIRECT_URI || process.env.NEXT_PUBLIC_HF_OAUTH_REDIRECT_URI || `${origin}/api/auth/hf/callback`;
|
| 48 |
+
|
| 49 |
+
try {
|
| 50 |
+
const tokenResponse = await fetch(TOKEN_ENDPOINT, {
|
| 51 |
+
method: 'POST',
|
| 52 |
+
headers: {
|
| 53 |
+
'Content-Type': 'application/x-www-form-urlencoded',
|
| 54 |
+
},
|
| 55 |
+
body: new URLSearchParams({
|
| 56 |
+
grant_type: 'authorization_code',
|
| 57 |
+
code,
|
| 58 |
+
redirect_uri: redirectUri,
|
| 59 |
+
client_id: clientId,
|
| 60 |
+
client_secret: clientSecret,
|
| 61 |
+
}),
|
| 62 |
+
});
|
| 63 |
+
|
| 64 |
+
if (!tokenResponse.ok) {
|
| 65 |
+
const errorPayload = await tokenResponse.json().catch(() => ({}));
|
| 66 |
+
throw new Error(errorPayload?.error_description || 'Failed to exchange code for token');
|
| 67 |
+
}
|
| 68 |
+
|
| 69 |
+
const tokenData = await tokenResponse.json();
|
| 70 |
+
const accessToken = tokenData?.access_token;
|
| 71 |
+
if (!accessToken) {
|
| 72 |
+
throw new Error('Access token missing in response');
|
| 73 |
+
}
|
| 74 |
+
|
| 75 |
+
const userResponse = await fetch(USERINFO_ENDPOINT, {
|
| 76 |
+
headers: {
|
| 77 |
+
Authorization: `Bearer ${accessToken}`,
|
| 78 |
+
},
|
| 79 |
+
});
|
| 80 |
+
|
| 81 |
+
if (!userResponse.ok) {
|
| 82 |
+
throw new Error('Failed to fetch user info');
|
| 83 |
+
}
|
| 84 |
+
|
| 85 |
+
const profile = await userResponse.json();
|
| 86 |
+
const namespace = profile?.preferred_username || profile?.name || 'user';
|
| 87 |
+
|
| 88 |
+
const script = `
|
| 89 |
+
window.opener && window.opener.postMessage({
|
| 90 |
+
type: 'HF_OAUTH_SUCCESS',
|
| 91 |
+
payload: {
|
| 92 |
+
token: ${JSON.stringify(accessToken)},
|
| 93 |
+
namespace: ${JSON.stringify(namespace)},
|
| 94 |
+
}
|
| 95 |
+
}, '${origin}');
|
| 96 |
+
window.close();
|
| 97 |
+
`;
|
| 98 |
+
|
| 99 |
+
return htmlResponse(script.trim());
|
| 100 |
+
} catch (error: any) {
|
| 101 |
+
const message = error?.message || 'OAuth flow failed';
|
| 102 |
+
const script = `
|
| 103 |
+
window.opener && window.opener.postMessage({
|
| 104 |
+
type: 'HF_OAUTH_ERROR',
|
| 105 |
+
payload: { message: ${JSON.stringify(message)} }
|
| 106 |
+
}, '${origin}');
|
| 107 |
+
window.close();
|
| 108 |
+
`;
|
| 109 |
+
|
| 110 |
+
return htmlResponse(script.trim());
|
| 111 |
+
}
|
| 112 |
+
}
|
src/app/api/auth/hf/login/route.ts
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import { randomUUID } from 'crypto';
|
| 2 |
+
import { NextRequest, NextResponse } from 'next/server';
|
| 3 |
+
|
| 4 |
+
const HF_AUTHORIZE_URL = 'https://huggingface.co/oauth/authorize';
|
| 5 |
+
const STATE_COOKIE = 'hf_oauth_state';
|
| 6 |
+
|
| 7 |
+
export async function GET(request: NextRequest) {
|
| 8 |
+
const clientId = process.env.HF_OAUTH_CLIENT_ID || process.env.NEXT_PUBLIC_HF_OAUTH_CLIENT_ID;
|
| 9 |
+
if (!clientId) {
|
| 10 |
+
return NextResponse.json({ error: 'OAuth client ID not configured' }, { status: 500 });
|
| 11 |
+
}
|
| 12 |
+
|
| 13 |
+
const state = randomUUID();
|
| 14 |
+
const origin = request.nextUrl.origin;
|
| 15 |
+
const redirectUri = process.env.HF_OAUTH_REDIRECT_URI || process.env.NEXT_PUBLIC_HF_OAUTH_REDIRECT_URI || `${origin}/api/auth/hf/callback`;
|
| 16 |
+
|
| 17 |
+
const authorizeUrl = new URL(HF_AUTHORIZE_URL);
|
| 18 |
+
authorizeUrl.searchParams.set('response_type', 'code');
|
| 19 |
+
authorizeUrl.searchParams.set('client_id', clientId);
|
| 20 |
+
authorizeUrl.searchParams.set('redirect_uri', redirectUri);
|
| 21 |
+
authorizeUrl.searchParams.set('scope', 'openid profile read-repos');
|
| 22 |
+
authorizeUrl.searchParams.set('state', state);
|
| 23 |
+
|
| 24 |
+
const response = NextResponse.redirect(authorizeUrl.toString(), { status: 302 });
|
| 25 |
+
response.cookies.set({
|
| 26 |
+
name: STATE_COOKIE,
|
| 27 |
+
value: state,
|
| 28 |
+
httpOnly: true,
|
| 29 |
+
sameSite: 'lax',
|
| 30 |
+
secure: process.env.NODE_ENV === 'production',
|
| 31 |
+
maxAge: 60 * 5,
|
| 32 |
+
path: '/',
|
| 33 |
+
});
|
| 34 |
+
|
| 35 |
+
return response;
|
| 36 |
+
}
|
src/app/api/auth/hf/validate/route.ts
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import { NextRequest, NextResponse } from 'next/server';
|
| 2 |
+
import { whoAmI } from '@huggingface/hub';
|
| 3 |
+
|
| 4 |
+
export async function POST(request: NextRequest) {
|
| 5 |
+
try {
|
| 6 |
+
const body = await request.json().catch(() => ({}));
|
| 7 |
+
const token = (body?.token || '').trim();
|
| 8 |
+
|
| 9 |
+
if (!token) {
|
| 10 |
+
return NextResponse.json({ error: 'Token is required' }, { status: 400 });
|
| 11 |
+
}
|
| 12 |
+
|
| 13 |
+
const info = await whoAmI({ accessToken: token });
|
| 14 |
+
return NextResponse.json({
|
| 15 |
+
name: info?.name || info?.username || 'user',
|
| 16 |
+
email: info?.email || null,
|
| 17 |
+
orgs: info?.orgs || [],
|
| 18 |
+
});
|
| 19 |
+
} catch (error: any) {
|
| 20 |
+
return NextResponse.json({ error: error?.message || 'Invalid token' }, { status: 401 });
|
| 21 |
+
}
|
| 22 |
+
}
|
src/app/api/auth/route.ts
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import { NextResponse } from 'next/server';
|
| 2 |
+
|
| 3 |
+
export async function GET() {
|
| 4 |
+
// if this gets hit, auth has already been verified
|
| 5 |
+
return NextResponse.json({ isAuthenticated: true });
|
| 6 |
+
}
|
src/app/api/caption/get/route.ts
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/* eslint-disable */
|
| 2 |
+
import { NextRequest, NextResponse } from 'next/server';
|
| 3 |
+
import fs from 'fs';
|
| 4 |
+
import path from 'path';
|
| 5 |
+
import { getDatasetsRoot } from '@/server/settings';
|
| 6 |
+
|
| 7 |
+
export async function POST(request: NextRequest) {
|
| 8 |
+
|
| 9 |
+
const body = await request.json();
|
| 10 |
+
const { imgPath } = body;
|
| 11 |
+
console.log('Received POST request for caption:', imgPath);
|
| 12 |
+
try {
|
| 13 |
+
// Decode the path
|
| 14 |
+
const filepath = imgPath;
|
| 15 |
+
console.log('Decoded image path:', filepath);
|
| 16 |
+
|
| 17 |
+
// caption name is the filepath without extension but with .txt
|
| 18 |
+
const captionPath = filepath.replace(/\.[^/.]+$/, '') + '.txt';
|
| 19 |
+
|
| 20 |
+
// Get allowed directories
|
| 21 |
+
const allowedDir = await getDatasetsRoot();
|
| 22 |
+
|
| 23 |
+
// Security check: Ensure path is in allowed directory
|
| 24 |
+
const isAllowed = filepath.startsWith(allowedDir) && !filepath.includes('..');
|
| 25 |
+
|
| 26 |
+
if (!isAllowed) {
|
| 27 |
+
console.warn(`Access denied: ${filepath} not in ${allowedDir}`);
|
| 28 |
+
return new NextResponse('Access denied', { status: 403 });
|
| 29 |
+
}
|
| 30 |
+
|
| 31 |
+
// Check if file exists
|
| 32 |
+
if (!fs.existsSync(captionPath)) {
|
| 33 |
+
// send back blank string if caption file does not exist
|
| 34 |
+
return new NextResponse('');
|
| 35 |
+
}
|
| 36 |
+
|
| 37 |
+
// Read caption file
|
| 38 |
+
const caption = fs.readFileSync(captionPath, 'utf-8');
|
| 39 |
+
|
| 40 |
+
// Return caption
|
| 41 |
+
return new NextResponse(caption);
|
| 42 |
+
} catch (error) {
|
| 43 |
+
console.error('Error getting caption:', error);
|
| 44 |
+
return new NextResponse('Error getting caption', { status: 500 });
|
| 45 |
+
}
|
| 46 |
+
}
|
src/app/api/datasets/create/route.tsx
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import { NextResponse } from 'next/server';
|
| 2 |
+
import fs from 'fs';
|
| 3 |
+
import path from 'path';
|
| 4 |
+
import { getDatasetsRoot } from '@/server/settings';
|
| 5 |
+
|
| 6 |
+
export async function POST(request: Request) {
|
| 7 |
+
try {
|
| 8 |
+
const body = await request.json();
|
| 9 |
+
let { name } = body;
|
| 10 |
+
// clean name by making lower case, removing special characters, and replacing spaces with underscores
|
| 11 |
+
name = name.toLowerCase().replace(/[^a-z0-9]+/g, '_');
|
| 12 |
+
|
| 13 |
+
let datasetsPath = await getDatasetsRoot();
|
| 14 |
+
let datasetPath = path.join(datasetsPath, name);
|
| 15 |
+
|
| 16 |
+
// if folder doesnt exist, create it
|
| 17 |
+
if (!fs.existsSync(datasetPath)) {
|
| 18 |
+
fs.mkdirSync(datasetPath);
|
| 19 |
+
}
|
| 20 |
+
|
| 21 |
+
return NextResponse.json({ success: true, name: name, path: datasetPath });
|
| 22 |
+
} catch (error) {
|
| 23 |
+
return NextResponse.json({ error: 'Failed to create dataset' }, { status: 500 });
|
| 24 |
+
}
|
| 25 |
+
}
|
src/app/api/datasets/delete/route.tsx
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import { NextResponse } from 'next/server';
|
| 2 |
+
import fs from 'fs';
|
| 3 |
+
import path from 'path';
|
| 4 |
+
import { getDatasetsRoot } from '@/server/settings';
|
| 5 |
+
|
| 6 |
+
export async function POST(request: Request) {
|
| 7 |
+
try {
|
| 8 |
+
const body = await request.json();
|
| 9 |
+
const { name } = body;
|
| 10 |
+
let datasetsPath = await getDatasetsRoot();
|
| 11 |
+
let datasetPath = path.join(datasetsPath, name);
|
| 12 |
+
|
| 13 |
+
// if folder doesnt exist, ignore
|
| 14 |
+
if (!fs.existsSync(datasetPath)) {
|
| 15 |
+
return NextResponse.json({ success: true });
|
| 16 |
+
}
|
| 17 |
+
|
| 18 |
+
// delete it and return success
|
| 19 |
+
fs.rmdirSync(datasetPath, { recursive: true });
|
| 20 |
+
return NextResponse.json({ success: true });
|
| 21 |
+
} catch (error) {
|
| 22 |
+
return NextResponse.json({ error: 'Failed to create dataset' }, { status: 500 });
|
| 23 |
+
}
|
| 24 |
+
}
|
src/app/api/datasets/list/route.ts
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import { NextResponse } from 'next/server';
|
| 2 |
+
import fs from 'fs';
|
| 3 |
+
import { getDatasetsRoot } from '@/server/settings';
|
| 4 |
+
|
| 5 |
+
export async function GET() {
|
| 6 |
+
try {
|
| 7 |
+
let datasetsPath = await getDatasetsRoot();
|
| 8 |
+
|
| 9 |
+
// if folder doesnt exist, create it
|
| 10 |
+
if (!fs.existsSync(datasetsPath)) {
|
| 11 |
+
fs.mkdirSync(datasetsPath);
|
| 12 |
+
}
|
| 13 |
+
|
| 14 |
+
// find all the folders in the datasets folder
|
| 15 |
+
let folders = fs
|
| 16 |
+
.readdirSync(datasetsPath, { withFileTypes: true })
|
| 17 |
+
.filter(dirent => dirent.isDirectory())
|
| 18 |
+
.filter(dirent => !dirent.name.startsWith('.'))
|
| 19 |
+
.map(dirent => dirent.name);
|
| 20 |
+
|
| 21 |
+
return NextResponse.json(folders);
|
| 22 |
+
} catch (error) {
|
| 23 |
+
return NextResponse.json({ error: 'Failed to fetch datasets' }, { status: 500 });
|
| 24 |
+
}
|
| 25 |
+
}
|
src/app/api/datasets/listImages/route.ts
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import { NextResponse } from 'next/server';
|
| 2 |
+
import fs from 'fs';
|
| 3 |
+
import path from 'path';
|
| 4 |
+
import { getDatasetsRoot } from '@/server/settings';
|
| 5 |
+
|
| 6 |
+
export async function POST(request: Request) {
|
| 7 |
+
const datasetsPath = await getDatasetsRoot();
|
| 8 |
+
const body = await request.json();
|
| 9 |
+
const { datasetName } = body;
|
| 10 |
+
const datasetFolder = path.join(datasetsPath, datasetName);
|
| 11 |
+
|
| 12 |
+
try {
|
| 13 |
+
// Check if folder exists
|
| 14 |
+
if (!fs.existsSync(datasetFolder)) {
|
| 15 |
+
return NextResponse.json({ error: `Folder '${datasetName}' not found` }, { status: 404 });
|
| 16 |
+
}
|
| 17 |
+
|
| 18 |
+
// Find all images recursively
|
| 19 |
+
const imageFiles = findImagesRecursively(datasetFolder);
|
| 20 |
+
|
| 21 |
+
// Format response
|
| 22 |
+
const result = imageFiles.map(imgPath => ({
|
| 23 |
+
img_path: imgPath,
|
| 24 |
+
}));
|
| 25 |
+
|
| 26 |
+
return NextResponse.json({ images: result });
|
| 27 |
+
} catch (error) {
|
| 28 |
+
console.error('Error finding images:', error);
|
| 29 |
+
return NextResponse.json({ error: 'Failed to process request' }, { status: 500 });
|
| 30 |
+
}
|
| 31 |
+
}
|
| 32 |
+
|
| 33 |
+
/**
|
| 34 |
+
* Recursively finds all image files in a directory and its subdirectories
|
| 35 |
+
* @param dir Directory to search
|
| 36 |
+
* @returns Array of absolute paths to image files
|
| 37 |
+
*/
|
| 38 |
+
function findImagesRecursively(dir: string): string[] {
|
| 39 |
+
const imageExtensions = ['.png', '.jpg', '.jpeg', '.webp', '.mp4', '.avi', '.mov', '.mkv', '.wmv', '.m4v', '.flv'];
|
| 40 |
+
let results: string[] = [];
|
| 41 |
+
|
| 42 |
+
const items = fs.readdirSync(dir);
|
| 43 |
+
|
| 44 |
+
for (const item of items) {
|
| 45 |
+
const itemPath = path.join(dir, item);
|
| 46 |
+
const stat = fs.statSync(itemPath);
|
| 47 |
+
|
| 48 |
+
if (stat.isDirectory() && item !== '_controls' && !item.startsWith('.')) {
|
| 49 |
+
// If it's a directory, recursively search it
|
| 50 |
+
results = results.concat(findImagesRecursively(itemPath));
|
| 51 |
+
} else {
|
| 52 |
+
// If it's a file, check if it's an image
|
| 53 |
+
const ext = path.extname(itemPath).toLowerCase();
|
| 54 |
+
if (imageExtensions.includes(ext)) {
|
| 55 |
+
results.push(itemPath);
|
| 56 |
+
}
|
| 57 |
+
}
|
| 58 |
+
}
|
| 59 |
+
|
| 60 |
+
return results;
|
| 61 |
+
}
|
src/app/api/datasets/upload/route.ts
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// src/app/api/datasets/upload/route.ts
|
| 2 |
+
import { NextRequest, NextResponse } from 'next/server';
|
| 3 |
+
import { writeFile, mkdir } from 'fs/promises';
|
| 4 |
+
import { join } from 'path';
|
| 5 |
+
import { getDatasetsRoot } from '@/server/settings';
|
| 6 |
+
|
| 7 |
+
export async function POST(request: NextRequest) {
|
| 8 |
+
try {
|
| 9 |
+
const datasetsPath = await getDatasetsRoot();
|
| 10 |
+
if (!datasetsPath) {
|
| 11 |
+
return NextResponse.json({ error: 'Datasets path not found' }, { status: 500 });
|
| 12 |
+
}
|
| 13 |
+
const formData = await request.formData();
|
| 14 |
+
const files = formData.getAll('files');
|
| 15 |
+
const datasetName = formData.get('datasetName') as string;
|
| 16 |
+
|
| 17 |
+
if (!files || files.length === 0) {
|
| 18 |
+
return NextResponse.json({ error: 'No files provided' }, { status: 400 });
|
| 19 |
+
}
|
| 20 |
+
|
| 21 |
+
// Create upload directory if it doesn't exist
|
| 22 |
+
const uploadDir = join(datasetsPath, datasetName);
|
| 23 |
+
await mkdir(uploadDir, { recursive: true });
|
| 24 |
+
|
| 25 |
+
const savedFiles: string[] = [];
|
| 26 |
+
|
| 27 |
+
// Process files sequentially to avoid overwhelming the system
|
| 28 |
+
for (let i = 0; i < files.length; i++) {
|
| 29 |
+
const file = files[i] as any;
|
| 30 |
+
const bytes = await file.arrayBuffer();
|
| 31 |
+
const buffer = Buffer.from(bytes);
|
| 32 |
+
|
| 33 |
+
// Clean filename and ensure it's unique
|
| 34 |
+
const fileName = file.name.replace(/[^a-zA-Z0-9.-]/g, '_');
|
| 35 |
+
const filePath = join(uploadDir, fileName);
|
| 36 |
+
|
| 37 |
+
await writeFile(filePath, buffer);
|
| 38 |
+
savedFiles.push(fileName);
|
| 39 |
+
}
|
| 40 |
+
|
| 41 |
+
return NextResponse.json({
|
| 42 |
+
message: 'Files uploaded successfully',
|
| 43 |
+
files: savedFiles,
|
| 44 |
+
});
|
| 45 |
+
} catch (error) {
|
| 46 |
+
console.error('Upload error:', error);
|
| 47 |
+
return NextResponse.json({ error: 'Error uploading files' }, { status: 500 });
|
| 48 |
+
}
|
| 49 |
+
}
|
| 50 |
+
|
| 51 |
+
// Increase payload size limit (default is 4mb)
|
| 52 |
+
export const config = {
|
| 53 |
+
api: {
|
| 54 |
+
bodyParser: false,
|
| 55 |
+
responseLimit: '50mb',
|
| 56 |
+
},
|
| 57 |
+
};
|
src/app/api/files/[...filePath]/route.ts
ADDED
|
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/* eslint-disable */
|
| 2 |
+
import { NextRequest, NextResponse } from 'next/server';
|
| 3 |
+
import fs from 'fs';
|
| 4 |
+
import path from 'path';
|
| 5 |
+
import { getDatasetsRoot, getTrainingFolder } from '@/server/settings';
|
| 6 |
+
|
| 7 |
+
export async function GET(request: NextRequest, { params }: { params: { filePath: string } }) {
|
| 8 |
+
const { filePath } = await params;
|
| 9 |
+
try {
|
| 10 |
+
// Decode the path
|
| 11 |
+
const decodedFilePath = decodeURIComponent(filePath);
|
| 12 |
+
|
| 13 |
+
// Get allowed directories
|
| 14 |
+
const datasetRoot = await getDatasetsRoot();
|
| 15 |
+
const trainingRoot = await getTrainingFolder();
|
| 16 |
+
const allowedDirs = [datasetRoot, trainingRoot];
|
| 17 |
+
|
| 18 |
+
// Security check: Ensure path is in allowed directory
|
| 19 |
+
const isAllowed =
|
| 20 |
+
allowedDirs.some(allowedDir => decodedFilePath.startsWith(allowedDir)) && !decodedFilePath.includes('..');
|
| 21 |
+
|
| 22 |
+
if (!isAllowed) {
|
| 23 |
+
console.warn(`Access denied: ${decodedFilePath} not in ${allowedDirs.join(', ')}`);
|
| 24 |
+
return new NextResponse('Access denied', { status: 403 });
|
| 25 |
+
}
|
| 26 |
+
|
| 27 |
+
// Check if file exists
|
| 28 |
+
if (!fs.existsSync(decodedFilePath)) {
|
| 29 |
+
console.warn(`File not found: ${decodedFilePath}`);
|
| 30 |
+
return new NextResponse('File not found', { status: 404 });
|
| 31 |
+
}
|
| 32 |
+
|
| 33 |
+
// Get file info
|
| 34 |
+
const stat = fs.statSync(decodedFilePath);
|
| 35 |
+
if (!stat.isFile()) {
|
| 36 |
+
return new NextResponse('Not a file', { status: 400 });
|
| 37 |
+
}
|
| 38 |
+
|
| 39 |
+
// Get filename for Content-Disposition
|
| 40 |
+
const filename = path.basename(decodedFilePath);
|
| 41 |
+
|
| 42 |
+
// Determine content type
|
| 43 |
+
const ext = path.extname(decodedFilePath).toLowerCase();
|
| 44 |
+
const contentTypeMap: { [key: string]: string } = {
|
| 45 |
+
'.jpg': 'image/jpeg',
|
| 46 |
+
'.jpeg': 'image/jpeg',
|
| 47 |
+
'.png': 'image/png',
|
| 48 |
+
'.gif': 'image/gif',
|
| 49 |
+
'.webp': 'image/webp',
|
| 50 |
+
'.svg': 'image/svg+xml',
|
| 51 |
+
'.bmp': 'image/bmp',
|
| 52 |
+
'.safetensors': 'application/octet-stream',
|
| 53 |
+
'.zip': 'application/zip',
|
| 54 |
+
// Videos
|
| 55 |
+
'.mp4': 'video/mp4',
|
| 56 |
+
'.avi': 'video/x-msvideo',
|
| 57 |
+
'.mov': 'video/quicktime',
|
| 58 |
+
'.mkv': 'video/x-matroska',
|
| 59 |
+
'.wmv': 'video/x-ms-wmv',
|
| 60 |
+
'.m4v': 'video/x-m4v',
|
| 61 |
+
'.flv': 'video/x-flv'
|
| 62 |
+
};
|
| 63 |
+
|
| 64 |
+
const contentType = contentTypeMap[ext] || 'application/octet-stream';
|
| 65 |
+
|
| 66 |
+
// Get range header for partial content support
|
| 67 |
+
const range = request.headers.get('range');
|
| 68 |
+
|
| 69 |
+
// Common headers for better download handling
|
| 70 |
+
const commonHeaders = {
|
| 71 |
+
'Content-Type': contentType,
|
| 72 |
+
'Accept-Ranges': 'bytes',
|
| 73 |
+
'Cache-Control': 'public, max-age=86400',
|
| 74 |
+
'Content-Disposition': `attachment; filename="${encodeURIComponent(filename)}"`,
|
| 75 |
+
'X-Content-Type-Options': 'nosniff',
|
| 76 |
+
};
|
| 77 |
+
|
| 78 |
+
if (range) {
|
| 79 |
+
// Parse range header
|
| 80 |
+
const parts = range.replace(/bytes=/, '').split('-');
|
| 81 |
+
const start = parseInt(parts[0], 10);
|
| 82 |
+
const end = parts[1] ? parseInt(parts[1], 10) : Math.min(start + 10 * 1024 * 1024, stat.size - 1); // 10MB chunks
|
| 83 |
+
const chunkSize = end - start + 1;
|
| 84 |
+
|
| 85 |
+
const fileStream = fs.createReadStream(decodedFilePath, {
|
| 86 |
+
start,
|
| 87 |
+
end,
|
| 88 |
+
highWaterMark: 64 * 1024, // 64KB buffer
|
| 89 |
+
});
|
| 90 |
+
|
| 91 |
+
return new NextResponse(fileStream as any, {
|
| 92 |
+
status: 206,
|
| 93 |
+
headers: {
|
| 94 |
+
...commonHeaders,
|
| 95 |
+
'Content-Range': `bytes ${start}-${end}/${stat.size}`,
|
| 96 |
+
'Content-Length': String(chunkSize),
|
| 97 |
+
},
|
| 98 |
+
});
|
| 99 |
+
} else {
|
| 100 |
+
// For full file download, read directly without streaming wrapper
|
| 101 |
+
const fileStream = fs.createReadStream(decodedFilePath, {
|
| 102 |
+
highWaterMark: 64 * 1024, // 64KB buffer
|
| 103 |
+
});
|
| 104 |
+
|
| 105 |
+
return new NextResponse(fileStream as any, {
|
| 106 |
+
headers: {
|
| 107 |
+
...commonHeaders,
|
| 108 |
+
'Content-Length': String(stat.size),
|
| 109 |
+
},
|
| 110 |
+
});
|
| 111 |
+
}
|
| 112 |
+
} catch (error) {
|
| 113 |
+
console.error('Error serving file:', error);
|
| 114 |
+
return new NextResponse('Internal Server Error', { status: 500 });
|
| 115 |
+
}
|
| 116 |
+
}
|
src/app/api/gpu/route.ts
ADDED
|
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import { NextResponse } from 'next/server';
|
| 2 |
+
import { exec } from 'child_process';
|
| 3 |
+
import { promisify } from 'util';
|
| 4 |
+
import os from 'os';
|
| 5 |
+
|
| 6 |
+
const execAsync = promisify(exec);
|
| 7 |
+
|
| 8 |
+
export async function GET() {
|
| 9 |
+
try {
|
| 10 |
+
// Get platform
|
| 11 |
+
const platform = os.platform();
|
| 12 |
+
const isWindows = platform === 'win32';
|
| 13 |
+
|
| 14 |
+
// Check if nvidia-smi is available
|
| 15 |
+
const hasNvidiaSmi = await checkNvidiaSmi(isWindows);
|
| 16 |
+
|
| 17 |
+
if (!hasNvidiaSmi) {
|
| 18 |
+
return NextResponse.json({
|
| 19 |
+
hasNvidiaSmi: false,
|
| 20 |
+
gpus: [],
|
| 21 |
+
error: 'nvidia-smi not found or not accessible',
|
| 22 |
+
});
|
| 23 |
+
}
|
| 24 |
+
|
| 25 |
+
// Get GPU stats
|
| 26 |
+
const gpuStats = await getGpuStats(isWindows);
|
| 27 |
+
|
| 28 |
+
return NextResponse.json({
|
| 29 |
+
hasNvidiaSmi: true,
|
| 30 |
+
gpus: gpuStats,
|
| 31 |
+
});
|
| 32 |
+
} catch (error) {
|
| 33 |
+
console.error('Error fetching NVIDIA GPU stats:', error);
|
| 34 |
+
return NextResponse.json(
|
| 35 |
+
{
|
| 36 |
+
hasNvidiaSmi: false,
|
| 37 |
+
gpus: [],
|
| 38 |
+
error: `Failed to fetch GPU stats: ${error instanceof Error ? error.message : String(error)}`,
|
| 39 |
+
},
|
| 40 |
+
{ status: 500 },
|
| 41 |
+
);
|
| 42 |
+
}
|
| 43 |
+
}
|
| 44 |
+
|
| 45 |
+
async function checkNvidiaSmi(isWindows: boolean): Promise<boolean> {
|
| 46 |
+
try {
|
| 47 |
+
if (isWindows) {
|
| 48 |
+
// Check if nvidia-smi is available on Windows
|
| 49 |
+
// It's typically located in C:\Program Files\NVIDIA Corporation\NVSMI\nvidia-smi.exe
|
| 50 |
+
// but we'll just try to run it directly as it may be in PATH
|
| 51 |
+
await execAsync('nvidia-smi -L');
|
| 52 |
+
} else {
|
| 53 |
+
// Linux/macOS check
|
| 54 |
+
await execAsync('which nvidia-smi');
|
| 55 |
+
}
|
| 56 |
+
return true;
|
| 57 |
+
} catch (error) {
|
| 58 |
+
return false;
|
| 59 |
+
}
|
| 60 |
+
}
|
| 61 |
+
|
| 62 |
+
async function getGpuStats(isWindows: boolean) {
|
| 63 |
+
// Command is the same for both platforms, but the path might be different
|
| 64 |
+
const command =
|
| 65 |
+
'nvidia-smi --query-gpu=index,name,driver_version,temperature.gpu,utilization.gpu,utilization.memory,memory.total,memory.free,memory.used,power.draw,power.limit,clocks.current.graphics,clocks.current.memory,fan.speed --format=csv,noheader,nounits';
|
| 66 |
+
|
| 67 |
+
// Execute command
|
| 68 |
+
const { stdout } = await execAsync(command);
|
| 69 |
+
|
| 70 |
+
// Parse CSV output
|
| 71 |
+
const gpus = stdout
|
| 72 |
+
.trim()
|
| 73 |
+
.split('\n')
|
| 74 |
+
.map(line => {
|
| 75 |
+
const [
|
| 76 |
+
index,
|
| 77 |
+
name,
|
| 78 |
+
driverVersion,
|
| 79 |
+
temperature,
|
| 80 |
+
gpuUtil,
|
| 81 |
+
memoryUtil,
|
| 82 |
+
memoryTotal,
|
| 83 |
+
memoryFree,
|
| 84 |
+
memoryUsed,
|
| 85 |
+
powerDraw,
|
| 86 |
+
powerLimit,
|
| 87 |
+
clockGraphics,
|
| 88 |
+
clockMemory,
|
| 89 |
+
fanSpeed,
|
| 90 |
+
] = line.split(', ').map(item => item.trim());
|
| 91 |
+
|
| 92 |
+
return {
|
| 93 |
+
index: parseInt(index),
|
| 94 |
+
name,
|
| 95 |
+
driverVersion,
|
| 96 |
+
temperature: parseInt(temperature),
|
| 97 |
+
utilization: {
|
| 98 |
+
gpu: parseInt(gpuUtil),
|
| 99 |
+
memory: parseInt(memoryUtil),
|
| 100 |
+
},
|
| 101 |
+
memory: {
|
| 102 |
+
total: parseInt(memoryTotal),
|
| 103 |
+
free: parseInt(memoryFree),
|
| 104 |
+
used: parseInt(memoryUsed),
|
| 105 |
+
},
|
| 106 |
+
power: {
|
| 107 |
+
draw: parseFloat(powerDraw),
|
| 108 |
+
limit: parseFloat(powerLimit),
|
| 109 |
+
},
|
| 110 |
+
clocks: {
|
| 111 |
+
graphics: parseInt(clockGraphics),
|
| 112 |
+
memory: parseInt(clockMemory),
|
| 113 |
+
},
|
| 114 |
+
fan: {
|
| 115 |
+
speed: parseInt(fanSpeed) || 0, // Some GPUs might not report fan speed, default to 0
|
| 116 |
+
},
|
| 117 |
+
};
|
| 118 |
+
});
|
| 119 |
+
|
| 120 |
+
return gpus;
|
| 121 |
+
}
|
src/app/api/hf-hub/route.ts
ADDED
|
@@ -0,0 +1,165 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import { NextRequest, NextResponse } from 'next/server';
|
| 2 |
+
import { whoAmI, createRepo, uploadFiles, datasetInfo } from '@huggingface/hub';
|
| 3 |
+
import { readdir, stat } from 'fs/promises';
|
| 4 |
+
import path from 'path';
|
| 5 |
+
|
| 6 |
+
export async function POST(request: NextRequest) {
|
| 7 |
+
try {
|
| 8 |
+
const body = await request.json();
|
| 9 |
+
const { action, token, namespace, datasetName, datasetPath, datasetId } = body;
|
| 10 |
+
|
| 11 |
+
if (!token) {
|
| 12 |
+
return NextResponse.json({ error: 'HF token is required' }, { status: 400 });
|
| 13 |
+
}
|
| 14 |
+
|
| 15 |
+
switch (action) {
|
| 16 |
+
case 'whoami':
|
| 17 |
+
try {
|
| 18 |
+
const user = await whoAmI({ accessToken: token });
|
| 19 |
+
return NextResponse.json({ user });
|
| 20 |
+
} catch (error) {
|
| 21 |
+
return NextResponse.json({ error: 'Invalid token or network error' }, { status: 401 });
|
| 22 |
+
}
|
| 23 |
+
|
| 24 |
+
case 'createDataset':
|
| 25 |
+
try {
|
| 26 |
+
if (!namespace || !datasetName) {
|
| 27 |
+
return NextResponse.json({ error: 'Namespace and dataset name required' }, { status: 400 });
|
| 28 |
+
}
|
| 29 |
+
|
| 30 |
+
const repoId = `datasets/${namespace}/${datasetName}`;
|
| 31 |
+
|
| 32 |
+
// Create repository
|
| 33 |
+
await createRepo({
|
| 34 |
+
repo: repoId,
|
| 35 |
+
accessToken: token,
|
| 36 |
+
private: false,
|
| 37 |
+
});
|
| 38 |
+
|
| 39 |
+
return NextResponse.json({ success: true, repoId });
|
| 40 |
+
} catch (error: any) {
|
| 41 |
+
if (error.message?.includes('already exists')) {
|
| 42 |
+
return NextResponse.json({ success: true, repoId: `${namespace}/${datasetName}`, exists: true });
|
| 43 |
+
}
|
| 44 |
+
return NextResponse.json({ error: error.message || 'Failed to create dataset' }, { status: 500 });
|
| 45 |
+
}
|
| 46 |
+
|
| 47 |
+
case 'uploadDataset':
|
| 48 |
+
try {
|
| 49 |
+
if (!namespace || !datasetName || !datasetPath) {
|
| 50 |
+
return NextResponse.json({ error: 'Missing required parameters' }, { status: 400 });
|
| 51 |
+
}
|
| 52 |
+
|
| 53 |
+
const repoId = `datasets/${namespace}/${datasetName}`;
|
| 54 |
+
|
| 55 |
+
// Check if directory exists
|
| 56 |
+
try {
|
| 57 |
+
await stat(datasetPath);
|
| 58 |
+
} catch {
|
| 59 |
+
return NextResponse.json({ error: 'Dataset path does not exist' }, { status: 400 });
|
| 60 |
+
}
|
| 61 |
+
|
| 62 |
+
// Read files from directory and upload them
|
| 63 |
+
const files = await readdir(datasetPath);
|
| 64 |
+
const filesToUpload = [];
|
| 65 |
+
|
| 66 |
+
for (const fileName of files) {
|
| 67 |
+
const filePath = path.join(datasetPath, fileName);
|
| 68 |
+
const fileStats = await stat(filePath);
|
| 69 |
+
|
| 70 |
+
if (fileStats.isFile()) {
|
| 71 |
+
filesToUpload.push({
|
| 72 |
+
path: fileName,
|
| 73 |
+
content: new URL(`file://${filePath}`)
|
| 74 |
+
});
|
| 75 |
+
}
|
| 76 |
+
}
|
| 77 |
+
|
| 78 |
+
if (filesToUpload.length > 0) {
|
| 79 |
+
await uploadFiles({
|
| 80 |
+
repo: repoId,
|
| 81 |
+
accessToken: token,
|
| 82 |
+
files: filesToUpload,
|
| 83 |
+
});
|
| 84 |
+
}
|
| 85 |
+
|
| 86 |
+
return NextResponse.json({ success: true, repoId });
|
| 87 |
+
} catch (error: any) {
|
| 88 |
+
console.error('Upload error:', error);
|
| 89 |
+
return NextResponse.json({ error: error.message || 'Failed to upload dataset' }, { status: 500 });
|
| 90 |
+
}
|
| 91 |
+
|
| 92 |
+
case 'listFiles':
|
| 93 |
+
try {
|
| 94 |
+
if (!datasetPath) {
|
| 95 |
+
return NextResponse.json({ error: 'Dataset path required' }, { status: 400 });
|
| 96 |
+
}
|
| 97 |
+
|
| 98 |
+
const files = await readdir(datasetPath, { withFileTypes: true });
|
| 99 |
+
const imageExtensions = ['.jpg', '.jpeg', '.png', '.webp', '.bmp'];
|
| 100 |
+
|
| 101 |
+
const imageFiles = files
|
| 102 |
+
.filter(file => file.isFile())
|
| 103 |
+
.filter(file => imageExtensions.some(ext => file.name.toLowerCase().endsWith(ext)))
|
| 104 |
+
.map(file => ({
|
| 105 |
+
name: file.name,
|
| 106 |
+
path: path.join(datasetPath, file.name),
|
| 107 |
+
}));
|
| 108 |
+
|
| 109 |
+
const captionFiles = files
|
| 110 |
+
.filter(file => file.isFile())
|
| 111 |
+
.filter(file => file.name.endsWith('.txt'))
|
| 112 |
+
.map(file => ({
|
| 113 |
+
name: file.name,
|
| 114 |
+
path: path.join(datasetPath, file.name),
|
| 115 |
+
}));
|
| 116 |
+
|
| 117 |
+
return NextResponse.json({
|
| 118 |
+
images: imageFiles,
|
| 119 |
+
captions: captionFiles,
|
| 120 |
+
total: imageFiles.length
|
| 121 |
+
});
|
| 122 |
+
} catch (error: any) {
|
| 123 |
+
return NextResponse.json({ error: error.message || 'Failed to list files' }, { status: 500 });
|
| 124 |
+
}
|
| 125 |
+
|
| 126 |
+
case 'validateDataset':
|
| 127 |
+
try {
|
| 128 |
+
if (!datasetId) {
|
| 129 |
+
return NextResponse.json({ error: 'Dataset ID required' }, { status: 400 });
|
| 130 |
+
}
|
| 131 |
+
|
| 132 |
+
// Try to get dataset info to validate it exists and is accessible
|
| 133 |
+
const dataset = await datasetInfo({
|
| 134 |
+
name: datasetId,
|
| 135 |
+
accessToken: token,
|
| 136 |
+
});
|
| 137 |
+
|
| 138 |
+
return NextResponse.json({
|
| 139 |
+
exists: true,
|
| 140 |
+
dataset: {
|
| 141 |
+
id: dataset.id,
|
| 142 |
+
author: dataset.author,
|
| 143 |
+
downloads: dataset.downloads,
|
| 144 |
+
likes: dataset.likes,
|
| 145 |
+
private: dataset.private,
|
| 146 |
+
}
|
| 147 |
+
});
|
| 148 |
+
} catch (error: any) {
|
| 149 |
+
if (error.message?.includes('404') || error.message?.includes('not found')) {
|
| 150 |
+
return NextResponse.json({ exists: false }, { status: 200 });
|
| 151 |
+
}
|
| 152 |
+
if (error.message?.includes('401') || error.message?.includes('403')) {
|
| 153 |
+
return NextResponse.json({ error: 'Dataset not accessible with current token' }, { status: 403 });
|
| 154 |
+
}
|
| 155 |
+
return NextResponse.json({ error: error.message || 'Failed to validate dataset' }, { status: 500 });
|
| 156 |
+
}
|
| 157 |
+
|
| 158 |
+
default:
|
| 159 |
+
return NextResponse.json({ error: 'Invalid action' }, { status: 400 });
|
| 160 |
+
}
|
| 161 |
+
} catch (error: any) {
|
| 162 |
+
console.error('HF Hub API error:', error);
|
| 163 |
+
return NextResponse.json({ error: error.message || 'Internal server error' }, { status: 500 });
|
| 164 |
+
}
|
| 165 |
+
}
|
src/app/api/hf-jobs/route.ts
ADDED
|
@@ -0,0 +1,761 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import { NextRequest, NextResponse } from 'next/server';
|
| 2 |
+
import { spawn } from 'child_process';
|
| 3 |
+
import { writeFile } from 'fs/promises';
|
| 4 |
+
import path from 'path';
|
| 5 |
+
import { tmpdir } from 'os';
|
| 6 |
+
|
| 7 |
+
export async function POST(request: NextRequest) {
|
| 8 |
+
try {
|
| 9 |
+
const body = await request.json();
|
| 10 |
+
const { action, token, hardware, namespace, jobConfig, datasetRepo } = body;
|
| 11 |
+
|
| 12 |
+
switch (action) {
|
| 13 |
+
case 'checkStatus':
|
| 14 |
+
try {
|
| 15 |
+
if (!token || !jobConfig?.hf_job_id) {
|
| 16 |
+
return NextResponse.json({ error: 'Token and job ID required' }, { status: 400 });
|
| 17 |
+
}
|
| 18 |
+
|
| 19 |
+
const jobStatus = await checkHFJobStatus(token, jobConfig.hf_job_id);
|
| 20 |
+
return NextResponse.json({ status: jobStatus });
|
| 21 |
+
} catch (error: any) {
|
| 22 |
+
console.error('Job status check error:', error);
|
| 23 |
+
return NextResponse.json({ error: error.message }, { status: 500 });
|
| 24 |
+
}
|
| 25 |
+
|
| 26 |
+
case 'generateScript':
|
| 27 |
+
try {
|
| 28 |
+
const uvScript = generateUVScript({
|
| 29 |
+
jobConfig,
|
| 30 |
+
datasetRepo,
|
| 31 |
+
namespace,
|
| 32 |
+
token: token || 'YOUR_HF_TOKEN',
|
| 33 |
+
});
|
| 34 |
+
|
| 35 |
+
return NextResponse.json({
|
| 36 |
+
script: uvScript,
|
| 37 |
+
filename: `train_${jobConfig.config.name.replace(/[^a-zA-Z0-9]/g, '_')}.py`
|
| 38 |
+
});
|
| 39 |
+
} catch (error: any) {
|
| 40 |
+
return NextResponse.json({ error: error.message }, { status: 500 });
|
| 41 |
+
}
|
| 42 |
+
|
| 43 |
+
case 'submitJob':
|
| 44 |
+
try {
|
| 45 |
+
if (!token || !hardware) {
|
| 46 |
+
return NextResponse.json({ error: 'Token and hardware required' }, { status: 400 });
|
| 47 |
+
}
|
| 48 |
+
|
| 49 |
+
// Generate UV script
|
| 50 |
+
const uvScript = generateUVScript({
|
| 51 |
+
jobConfig,
|
| 52 |
+
datasetRepo,
|
| 53 |
+
namespace,
|
| 54 |
+
token,
|
| 55 |
+
});
|
| 56 |
+
|
| 57 |
+
// Write script to temporary file
|
| 58 |
+
const scriptPath = path.join(tmpdir(), `train_${Date.now()}.py`);
|
| 59 |
+
await writeFile(scriptPath, uvScript);
|
| 60 |
+
|
| 61 |
+
// Submit HF job using uv run
|
| 62 |
+
const jobId = await submitHFJobUV(token, hardware, scriptPath);
|
| 63 |
+
|
| 64 |
+
return NextResponse.json({
|
| 65 |
+
success: true,
|
| 66 |
+
jobId,
|
| 67 |
+
message: `Job submitted successfully with ID: ${jobId}`
|
| 68 |
+
});
|
| 69 |
+
} catch (error: any) {
|
| 70 |
+
console.error('Job submission error:', error);
|
| 71 |
+
return NextResponse.json({ error: error.message }, { status: 500 });
|
| 72 |
+
}
|
| 73 |
+
|
| 74 |
+
default:
|
| 75 |
+
return NextResponse.json({ error: 'Invalid action' }, { status: 400 });
|
| 76 |
+
}
|
| 77 |
+
} catch (error: any) {
|
| 78 |
+
console.error('HF Jobs API error:', error);
|
| 79 |
+
return NextResponse.json({ error: error.message }, { status: 500 });
|
| 80 |
+
}
|
| 81 |
+
}
|
| 82 |
+
|
| 83 |
+
function generateUVScript({ jobConfig, datasetRepo, namespace, token }: {
|
| 84 |
+
jobConfig: any;
|
| 85 |
+
datasetRepo: string;
|
| 86 |
+
namespace: string;
|
| 87 |
+
token: string;
|
| 88 |
+
}) {
|
| 89 |
+
const config = jobConfig.config;
|
| 90 |
+
const process = config.process[0];
|
| 91 |
+
|
| 92 |
+
return `# /// script
|
| 93 |
+
# dependencies = [
|
| 94 |
+
# "torch>=2.0.0",
|
| 95 |
+
# "torchvision",
|
| 96 |
+
# "torchao==0.10.0",
|
| 97 |
+
# "safetensors",
|
| 98 |
+
# "diffusers @ git+https://github.com/huggingface/diffusers@7a2b78bf0f788d311cc96b61e660a8e13e3b1e63",
|
| 99 |
+
# "transformers==4.52.4",
|
| 100 |
+
# "lycoris-lora==1.8.3",
|
| 101 |
+
# "flatten_json",
|
| 102 |
+
# "pyyaml",
|
| 103 |
+
# "oyaml",
|
| 104 |
+
# "tensorboard",
|
| 105 |
+
# "kornia",
|
| 106 |
+
# "invisible-watermark",
|
| 107 |
+
# "einops",
|
| 108 |
+
# "accelerate",
|
| 109 |
+
# "toml",
|
| 110 |
+
# "albumentations==1.4.15",
|
| 111 |
+
# "albucore==0.0.16",
|
| 112 |
+
# "pydantic",
|
| 113 |
+
# "omegaconf",
|
| 114 |
+
# "k-diffusion",
|
| 115 |
+
# "open_clip_torch",
|
| 116 |
+
# "timm",
|
| 117 |
+
# "prodigyopt",
|
| 118 |
+
# "controlnet_aux==0.0.10",
|
| 119 |
+
# "python-dotenv",
|
| 120 |
+
# "bitsandbytes",
|
| 121 |
+
# "hf_transfer",
|
| 122 |
+
# "lpips",
|
| 123 |
+
# "pytorch_fid",
|
| 124 |
+
# "optimum-quanto==0.2.4",
|
| 125 |
+
# "sentencepiece",
|
| 126 |
+
# "huggingface_hub",
|
| 127 |
+
# "peft",
|
| 128 |
+
# "python-slugify",
|
| 129 |
+
# "opencv-python-headless",
|
| 130 |
+
# "pytorch-wavelets==1.3.0",
|
| 131 |
+
# "matplotlib==3.10.1",
|
| 132 |
+
# "setuptools==69.5.1",
|
| 133 |
+
# "datasets==4.0.0",
|
| 134 |
+
# "pyarrow==20.0.0",
|
| 135 |
+
# "pillow",
|
| 136 |
+
# "ftfy",
|
| 137 |
+
# ]
|
| 138 |
+
# ///
|
| 139 |
+
|
| 140 |
+
import os
|
| 141 |
+
import sys
|
| 142 |
+
import subprocess
|
| 143 |
+
import argparse
|
| 144 |
+
import oyaml as yaml
|
| 145 |
+
from datasets import load_dataset
|
| 146 |
+
from huggingface_hub import HfApi, create_repo, upload_folder, snapshot_download
|
| 147 |
+
import tempfile
|
| 148 |
+
import shutil
|
| 149 |
+
import glob
|
| 150 |
+
from PIL import Image
|
| 151 |
+
|
| 152 |
+
def setup_ai_toolkit():
|
| 153 |
+
"""Clone and setup ai-toolkit repository"""
|
| 154 |
+
repo_dir = "ai-toolkit"
|
| 155 |
+
if not os.path.exists(repo_dir):
|
| 156 |
+
print("Cloning ai-toolkit repository...")
|
| 157 |
+
subprocess.run(
|
| 158 |
+
["git", "clone", "https://github.com/ostris/ai-toolkit.git", repo_dir],
|
| 159 |
+
check=True
|
| 160 |
+
)
|
| 161 |
+
sys.path.insert(0, os.path.abspath(repo_dir))
|
| 162 |
+
return repo_dir
|
| 163 |
+
|
| 164 |
+
def download_dataset(dataset_repo: str, local_path: str):
|
| 165 |
+
"""Download dataset from HF Hub as files"""
|
| 166 |
+
print(f"Downloading dataset from {dataset_repo}...")
|
| 167 |
+
|
| 168 |
+
# Create local dataset directory
|
| 169 |
+
os.makedirs(local_path, exist_ok=True)
|
| 170 |
+
|
| 171 |
+
# Use snapshot_download to get the dataset files directly
|
| 172 |
+
from huggingface_hub import snapshot_download
|
| 173 |
+
|
| 174 |
+
try:
|
| 175 |
+
# First try to download as a structured dataset
|
| 176 |
+
dataset = load_dataset(dataset_repo, split="train")
|
| 177 |
+
|
| 178 |
+
# Download images and captions from structured dataset
|
| 179 |
+
for i, item in enumerate(dataset):
|
| 180 |
+
# Save image
|
| 181 |
+
if "image" in item:
|
| 182 |
+
image_path = os.path.join(local_path, f"image_{i:06d}.jpg")
|
| 183 |
+
image = item["image"]
|
| 184 |
+
|
| 185 |
+
# Convert RGBA to RGB if necessary (for JPEG compatibility)
|
| 186 |
+
if image.mode == 'RGBA':
|
| 187 |
+
# Create a white background and paste the RGBA image on it
|
| 188 |
+
background = Image.new('RGB', image.size, (255, 255, 255))
|
| 189 |
+
background.paste(image, mask=image.split()[-1]) # Use alpha channel as mask
|
| 190 |
+
image = background
|
| 191 |
+
elif image.mode not in ['RGB', 'L']:
|
| 192 |
+
# Convert any other mode to RGB
|
| 193 |
+
image = image.convert('RGB')
|
| 194 |
+
|
| 195 |
+
image.save(image_path, 'JPEG')
|
| 196 |
+
|
| 197 |
+
# Save caption
|
| 198 |
+
if "text" in item:
|
| 199 |
+
caption_path = os.path.join(local_path, f"image_{i:06d}.txt")
|
| 200 |
+
with open(caption_path, "w", encoding="utf-8") as f:
|
| 201 |
+
f.write(item["text"])
|
| 202 |
+
|
| 203 |
+
print(f"Downloaded {len(dataset)} items to {local_path}")
|
| 204 |
+
|
| 205 |
+
except Exception as e:
|
| 206 |
+
print(f"Failed to load as structured dataset: {e}")
|
| 207 |
+
print("Attempting to download raw files...")
|
| 208 |
+
|
| 209 |
+
# Download the dataset repository as files
|
| 210 |
+
temp_repo_path = snapshot_download(repo_id=dataset_repo, repo_type="dataset")
|
| 211 |
+
|
| 212 |
+
# Copy all image and text files to the local path
|
| 213 |
+
import glob
|
| 214 |
+
import shutil
|
| 215 |
+
|
| 216 |
+
print(f"Downloaded repo to: {temp_repo_path}")
|
| 217 |
+
print(f"Contents: {os.listdir(temp_repo_path)}")
|
| 218 |
+
|
| 219 |
+
# Find all image files
|
| 220 |
+
image_extensions = ['*.jpg', '*.jpeg', '*.png', '*.webp', '*.bmp', '*.JPG', '*.JPEG', '*.PNG']
|
| 221 |
+
image_files = []
|
| 222 |
+
for ext in image_extensions:
|
| 223 |
+
pattern = os.path.join(temp_repo_path, "**", ext)
|
| 224 |
+
found_files = glob.glob(pattern, recursive=True)
|
| 225 |
+
image_files.extend(found_files)
|
| 226 |
+
print(f"Pattern {pattern} found {len(found_files)} files")
|
| 227 |
+
|
| 228 |
+
# Find all text files
|
| 229 |
+
text_files = glob.glob(os.path.join(temp_repo_path, "**", "*.txt"), recursive=True)
|
| 230 |
+
|
| 231 |
+
print(f"Found {len(image_files)} image files and {len(text_files)} text files")
|
| 232 |
+
|
| 233 |
+
# Copy image files
|
| 234 |
+
for i, img_file in enumerate(image_files):
|
| 235 |
+
dest_path = os.path.join(local_path, f"image_{i:06d}.jpg")
|
| 236 |
+
|
| 237 |
+
# Load and convert image if needed
|
| 238 |
+
try:
|
| 239 |
+
with Image.open(img_file) as image:
|
| 240 |
+
if image.mode == 'RGBA':
|
| 241 |
+
background = Image.new('RGB', image.size, (255, 255, 255))
|
| 242 |
+
background.paste(image, mask=image.split()[-1])
|
| 243 |
+
image = background
|
| 244 |
+
elif image.mode not in ['RGB', 'L']:
|
| 245 |
+
image = image.convert('RGB')
|
| 246 |
+
|
| 247 |
+
image.save(dest_path, 'JPEG')
|
| 248 |
+
except Exception as img_error:
|
| 249 |
+
print(f"Error processing image {img_file}: {img_error}")
|
| 250 |
+
continue
|
| 251 |
+
|
| 252 |
+
# Copy text files (captions)
|
| 253 |
+
for i, txt_file in enumerate(text_files[:len(image_files)]): # Match number of images
|
| 254 |
+
dest_path = os.path.join(local_path, f"image_{i:06d}.txt")
|
| 255 |
+
try:
|
| 256 |
+
shutil.copy2(txt_file, dest_path)
|
| 257 |
+
except Exception as txt_error:
|
| 258 |
+
print(f"Error copying text file {txt_file}: {txt_error}")
|
| 259 |
+
continue
|
| 260 |
+
|
| 261 |
+
print(f"Downloaded {len(image_files)} images and {len(text_files)} captions to {local_path}")
|
| 262 |
+
|
| 263 |
+
def create_config(dataset_path: str, output_path: str):
|
| 264 |
+
"""Create training configuration"""
|
| 265 |
+
import json
|
| 266 |
+
|
| 267 |
+
# Load config from JSON string and fix boolean/null values for Python
|
| 268 |
+
config_str = """${JSON.stringify(jobConfig, null, 2)}"""
|
| 269 |
+
config_str = config_str.replace('true', 'True').replace('false', 'False').replace('null', 'None')
|
| 270 |
+
config = eval(config_str)
|
| 271 |
+
|
| 272 |
+
# Update paths for cloud environment
|
| 273 |
+
config["config"]["process"][0]["datasets"][0]["folder_path"] = dataset_path
|
| 274 |
+
config["config"]["process"][0]["training_folder"] = output_path
|
| 275 |
+
|
| 276 |
+
# Remove sqlite_db_path as it's not needed for cloud training
|
| 277 |
+
if "sqlite_db_path" in config["config"]["process"][0]:
|
| 278 |
+
del config["config"]["process"][0]["sqlite_db_path"]
|
| 279 |
+
|
| 280 |
+
# Also change trainer type from ui_trainer to standard trainer to avoid UI dependencies
|
| 281 |
+
if config["config"]["process"][0]["type"] == "ui_trainer":
|
| 282 |
+
config["config"]["process"][0]["type"] = "sd_trainer"
|
| 283 |
+
|
| 284 |
+
return config
|
| 285 |
+
|
| 286 |
+
def upload_results(output_path: str, model_name: str, namespace: str, token: str, config: dict):
|
| 287 |
+
"""Upload trained model to HF Hub with README generation and proper file organization"""
|
| 288 |
+
import tempfile
|
| 289 |
+
import shutil
|
| 290 |
+
import glob
|
| 291 |
+
import re
|
| 292 |
+
import yaml
|
| 293 |
+
from datetime import datetime
|
| 294 |
+
from huggingface_hub import create_repo, upload_file, HfApi
|
| 295 |
+
|
| 296 |
+
try:
|
| 297 |
+
repo_id = f"{namespace}/{model_name}"
|
| 298 |
+
|
| 299 |
+
# Create repository
|
| 300 |
+
create_repo(repo_id=repo_id, token=token, exist_ok=True)
|
| 301 |
+
|
| 302 |
+
print(f"Uploading model to {repo_id}...")
|
| 303 |
+
|
| 304 |
+
# Create temporary directory for organized upload
|
| 305 |
+
with tempfile.TemporaryDirectory() as temp_upload_dir:
|
| 306 |
+
api = HfApi()
|
| 307 |
+
|
| 308 |
+
# 1. Find and upload model files to root directory
|
| 309 |
+
safetensors_files = glob.glob(os.path.join(output_path, "**", "*.safetensors"), recursive=True)
|
| 310 |
+
json_files = glob.glob(os.path.join(output_path, "**", "*.json"), recursive=True)
|
| 311 |
+
txt_files = glob.glob(os.path.join(output_path, "**", "*.txt"), recursive=True)
|
| 312 |
+
|
| 313 |
+
uploaded_files = []
|
| 314 |
+
|
| 315 |
+
# Upload .safetensors files to root
|
| 316 |
+
for file_path in safetensors_files:
|
| 317 |
+
filename = os.path.basename(file_path)
|
| 318 |
+
print(f"Uploading {filename} to repository root...")
|
| 319 |
+
api.upload_file(
|
| 320 |
+
path_or_fileobj=file_path,
|
| 321 |
+
path_in_repo=filename,
|
| 322 |
+
repo_id=repo_id,
|
| 323 |
+
token=token
|
| 324 |
+
)
|
| 325 |
+
uploaded_files.append(filename)
|
| 326 |
+
|
| 327 |
+
# Upload relevant JSON config files to root (skip metadata.json and other internal files)
|
| 328 |
+
config_files_uploaded = []
|
| 329 |
+
for file_path in json_files:
|
| 330 |
+
filename = os.path.basename(file_path)
|
| 331 |
+
# Only upload important config files, skip internal metadata
|
| 332 |
+
if any(keyword in filename.lower() for keyword in ['config', 'adapter', 'lora', 'model']):
|
| 333 |
+
print(f"Uploading {filename} to repository root...")
|
| 334 |
+
api.upload_file(
|
| 335 |
+
path_or_fileobj=file_path,
|
| 336 |
+
path_in_repo=filename,
|
| 337 |
+
repo_id=repo_id,
|
| 338 |
+
token=token
|
| 339 |
+
)
|
| 340 |
+
uploaded_files.append(filename)
|
| 341 |
+
config_files_uploaded.append(filename)
|
| 342 |
+
|
| 343 |
+
# 2. Handle sample images
|
| 344 |
+
samples_uploaded = []
|
| 345 |
+
samples_dir = os.path.join(output_path, "samples")
|
| 346 |
+
if os.path.isdir(samples_dir):
|
| 347 |
+
print("Uploading sample images...")
|
| 348 |
+
# Create samples directory in repo
|
| 349 |
+
for filename in os.listdir(samples_dir):
|
| 350 |
+
if filename.lower().endswith(('.jpg', '.jpeg', '.png', '.webp')):
|
| 351 |
+
file_path = os.path.join(samples_dir, filename)
|
| 352 |
+
repo_path = f"samples/{filename}"
|
| 353 |
+
api.upload_file(
|
| 354 |
+
path_or_fileobj=file_path,
|
| 355 |
+
path_in_repo=repo_path,
|
| 356 |
+
repo_id=repo_id,
|
| 357 |
+
token=token
|
| 358 |
+
)
|
| 359 |
+
samples_uploaded.append(repo_path)
|
| 360 |
+
|
| 361 |
+
# 3. Generate and upload README.md
|
| 362 |
+
readme_content = generate_model_card_readme(
|
| 363 |
+
repo_id=repo_id,
|
| 364 |
+
config=config,
|
| 365 |
+
model_name=model_name,
|
| 366 |
+
samples_dir=samples_dir if os.path.isdir(samples_dir) else None,
|
| 367 |
+
uploaded_files=uploaded_files
|
| 368 |
+
)
|
| 369 |
+
|
| 370 |
+
# Create README.md file and upload to root
|
| 371 |
+
readme_path = os.path.join(temp_upload_dir, "README.md")
|
| 372 |
+
with open(readme_path, "w", encoding="utf-8") as f:
|
| 373 |
+
f.write(readme_content)
|
| 374 |
+
|
| 375 |
+
print("Uploading README.md to repository root...")
|
| 376 |
+
api.upload_file(
|
| 377 |
+
path_or_fileobj=readme_path,
|
| 378 |
+
path_in_repo="README.md",
|
| 379 |
+
repo_id=repo_id,
|
| 380 |
+
token=token
|
| 381 |
+
)
|
| 382 |
+
|
| 383 |
+
print(f"Model uploaded successfully to https://huggingface.co/{repo_id}")
|
| 384 |
+
print(f"Files uploaded: {len(uploaded_files)} model files, {len(samples_uploaded)} samples, README.md")
|
| 385 |
+
|
| 386 |
+
except Exception as e:
|
| 387 |
+
print(f"Failed to upload model: {e}")
|
| 388 |
+
raise e
|
| 389 |
+
|
| 390 |
+
def generate_model_card_readme(repo_id: str, config: dict, model_name: str, samples_dir: str = None, uploaded_files: list = None) -> str:
|
| 391 |
+
"""Generate README.md content for the model card based on AI Toolkit's implementation"""
|
| 392 |
+
import re
|
| 393 |
+
import yaml
|
| 394 |
+
import os
|
| 395 |
+
|
| 396 |
+
try:
|
| 397 |
+
# Extract configuration details
|
| 398 |
+
process_config = config.get("config", {}).get("process", [{}])[0]
|
| 399 |
+
model_config = process_config.get("model", {})
|
| 400 |
+
train_config = process_config.get("train", {})
|
| 401 |
+
sample_config = process_config.get("sample", {})
|
| 402 |
+
|
| 403 |
+
# Gather model info
|
| 404 |
+
base_model = model_config.get("name_or_path", "unknown")
|
| 405 |
+
trigger_word = process_config.get("trigger_word")
|
| 406 |
+
arch = model_config.get("arch", "")
|
| 407 |
+
|
| 408 |
+
# Determine license based on base model
|
| 409 |
+
if "FLUX.1-schnell" in base_model:
|
| 410 |
+
license_info = {"license": "apache-2.0"}
|
| 411 |
+
elif "FLUX.1-dev" in base_model:
|
| 412 |
+
license_info = {
|
| 413 |
+
"license": "other",
|
| 414 |
+
"license_name": "flux-1-dev-non-commercial-license",
|
| 415 |
+
"license_link": "https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md"
|
| 416 |
+
}
|
| 417 |
+
else:
|
| 418 |
+
license_info = {"license": "creativeml-openrail-m"}
|
| 419 |
+
|
| 420 |
+
# Generate tags based on model architecture
|
| 421 |
+
tags = ["text-to-image"]
|
| 422 |
+
|
| 423 |
+
if "xl" in arch.lower():
|
| 424 |
+
tags.append("stable-diffusion-xl")
|
| 425 |
+
if "flux" in arch.lower():
|
| 426 |
+
tags.append("flux")
|
| 427 |
+
if "lumina" in arch.lower():
|
| 428 |
+
tags.append("lumina2")
|
| 429 |
+
if "sd3" in arch.lower() or "v3" in arch.lower():
|
| 430 |
+
tags.append("sd3")
|
| 431 |
+
|
| 432 |
+
# Add LoRA-specific tags
|
| 433 |
+
tags.extend(["lora", "diffusers", "template:sd-lora", "ai-toolkit"])
|
| 434 |
+
|
| 435 |
+
# Generate widgets from sample images and prompts
|
| 436 |
+
widgets = []
|
| 437 |
+
if samples_dir and os.path.isdir(samples_dir):
|
| 438 |
+
sample_prompts = sample_config.get("samples", [])
|
| 439 |
+
if not sample_prompts:
|
| 440 |
+
# Fallback to old format
|
| 441 |
+
sample_prompts = [{"prompt": p} for p in sample_config.get("prompts", [])]
|
| 442 |
+
|
| 443 |
+
# Get sample image files
|
| 444 |
+
sample_files = []
|
| 445 |
+
if os.path.isdir(samples_dir):
|
| 446 |
+
for filename in os.listdir(samples_dir):
|
| 447 |
+
if filename.lower().endswith(('.jpg', '.jpeg', '.png', '.webp')):
|
| 448 |
+
# Parse filename pattern: timestamp__steps_index.jpg
|
| 449 |
+
match = re.search(r"__(\d+)_(\d+)\.jpg$", filename)
|
| 450 |
+
if match:
|
| 451 |
+
steps, index = int(match.group(1)), int(match.group(2))
|
| 452 |
+
# Only use samples from final training step
|
| 453 |
+
final_steps = train_config.get("steps", 1000)
|
| 454 |
+
if steps == final_steps:
|
| 455 |
+
sample_files.append((index, f"samples/{filename}"))
|
| 456 |
+
|
| 457 |
+
# Sort by index and create widgets
|
| 458 |
+
sample_files.sort(key=lambda x: x[0])
|
| 459 |
+
|
| 460 |
+
for i, prompt_obj in enumerate(sample_prompts):
|
| 461 |
+
prompt = prompt_obj.get("prompt", "") if isinstance(prompt_obj, dict) else str(prompt_obj)
|
| 462 |
+
if i < len(sample_files):
|
| 463 |
+
_, image_path = sample_files[i]
|
| 464 |
+
widgets.append({
|
| 465 |
+
"text": prompt,
|
| 466 |
+
"output": {"url": image_path}
|
| 467 |
+
})
|
| 468 |
+
|
| 469 |
+
# Determine torch dtype based on model
|
| 470 |
+
dtype = "torch.bfloat16" if "flux" in arch.lower() else "torch.float16"
|
| 471 |
+
|
| 472 |
+
# Find the main safetensors file for usage example
|
| 473 |
+
main_safetensors = f"{model_name}.safetensors"
|
| 474 |
+
if uploaded_files:
|
| 475 |
+
safetensors_files = [f for f in uploaded_files if f.endswith('.safetensors')]
|
| 476 |
+
if safetensors_files:
|
| 477 |
+
main_safetensors = safetensors_files[0]
|
| 478 |
+
|
| 479 |
+
# Construct YAML frontmatter
|
| 480 |
+
frontmatter = {
|
| 481 |
+
"tags": tags,
|
| 482 |
+
"base_model": base_model,
|
| 483 |
+
**license_info
|
| 484 |
+
}
|
| 485 |
+
|
| 486 |
+
if widgets:
|
| 487 |
+
frontmatter["widget"] = widgets
|
| 488 |
+
|
| 489 |
+
if trigger_word:
|
| 490 |
+
frontmatter["instance_prompt"] = trigger_word
|
| 491 |
+
|
| 492 |
+
# Get first prompt for usage example
|
| 493 |
+
usage_prompt = trigger_word or "a beautiful landscape"
|
| 494 |
+
if widgets:
|
| 495 |
+
usage_prompt = widgets[0]["text"]
|
| 496 |
+
elif trigger_word:
|
| 497 |
+
usage_prompt = trigger_word
|
| 498 |
+
|
| 499 |
+
# Construct README content
|
| 500 |
+
trigger_section = f"You should use \`{trigger_word}\` to trigger the image generation." if trigger_word else "No trigger words defined."
|
| 501 |
+
|
| 502 |
+
# Build YAML frontmatter string
|
| 503 |
+
frontmatter_yaml = yaml.dump(frontmatter, default_flow_style=False, allow_unicode=True, sort_keys=False).strip()
|
| 504 |
+
|
| 505 |
+
readme_content = f"""---
|
| 506 |
+
{frontmatter_yaml}
|
| 507 |
+
---
|
| 508 |
+
|
| 509 |
+
# {model_name}
|
| 510 |
+
|
| 511 |
+
Model trained with [AI Toolkit by Ostris](https://github.com/ostris/ai-toolkit)
|
| 512 |
+
|
| 513 |
+
<Gallery />
|
| 514 |
+
|
| 515 |
+
## Trigger words
|
| 516 |
+
|
| 517 |
+
{trigger_section}
|
| 518 |
+
|
| 519 |
+
## Download model and use it with ComfyUI, AUTOMATIC1111, SD.Next, Invoke AI, etc.
|
| 520 |
+
|
| 521 |
+
Weights for this model are available in Safetensors format.
|
| 522 |
+
|
| 523 |
+
[Download]({repo_id}/tree/main) them in the Files & versions tab.
|
| 524 |
+
|
| 525 |
+
## Use it with the [🧨 diffusers library](https://github.com/huggingface/diffusers)
|
| 526 |
+
|
| 527 |
+
\`\`\`py
|
| 528 |
+
from diffusers import AutoPipelineForText2Image
|
| 529 |
+
import torch
|
| 530 |
+
|
| 531 |
+
pipeline = AutoPipelineForText2Image.from_pretrained('{base_model}', torch_dtype={dtype}).to('cuda')
|
| 532 |
+
pipeline.load_lora_weights('{repo_id}', weight_name='{main_safetensors}')
|
| 533 |
+
image = pipeline('{usage_prompt}').images[0]
|
| 534 |
+
image.save("my_image.png")
|
| 535 |
+
\`\`\`
|
| 536 |
+
|
| 537 |
+
For more details, including weighting, merging and fusing LoRAs, check the [documentation on loading LoRAs in diffusers](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters)
|
| 538 |
+
|
| 539 |
+
"""
|
| 540 |
+
return readme_content
|
| 541 |
+
|
| 542 |
+
except Exception as e:
|
| 543 |
+
print(f"Error generating README: {e}")
|
| 544 |
+
# Fallback simple README
|
| 545 |
+
return f"""# {model_name}
|
| 546 |
+
|
| 547 |
+
Model trained with [AI Toolkit by Ostris](https://github.com/ostris/ai-toolkit)
|
| 548 |
+
|
| 549 |
+
## Download model
|
| 550 |
+
|
| 551 |
+
Weights for this model are available in Safetensors format.
|
| 552 |
+
|
| 553 |
+
[Download]({repo_id}/tree/main) them in the Files & versions tab.
|
| 554 |
+
"""
|
| 555 |
+
|
| 556 |
+
def main():
|
| 557 |
+
# Setup environment - token comes from HF Jobs secrets
|
| 558 |
+
if "HF_TOKEN" not in os.environ:
|
| 559 |
+
raise ValueError("HF_TOKEN environment variable not set")
|
| 560 |
+
|
| 561 |
+
# Install system dependencies for headless operation
|
| 562 |
+
print("Installing system dependencies...")
|
| 563 |
+
try:
|
| 564 |
+
subprocess.run(["apt-get", "update"], check=True, capture_output=True)
|
| 565 |
+
subprocess.run([
|
| 566 |
+
"apt-get", "install", "-y",
|
| 567 |
+
"libgl1-mesa-glx",
|
| 568 |
+
"libglib2.0-0",
|
| 569 |
+
"libsm6",
|
| 570 |
+
"libxext6",
|
| 571 |
+
"libxrender-dev",
|
| 572 |
+
"libgomp1",
|
| 573 |
+
"ffmpeg"
|
| 574 |
+
], check=True, capture_output=True)
|
| 575 |
+
print("System dependencies installed successfully")
|
| 576 |
+
except subprocess.CalledProcessError as e:
|
| 577 |
+
print(f"Failed to install system dependencies: {e}")
|
| 578 |
+
print("Continuing without system dependencies...")
|
| 579 |
+
|
| 580 |
+
# Setup ai-toolkit
|
| 581 |
+
toolkit_dir = setup_ai_toolkit()
|
| 582 |
+
|
| 583 |
+
# Create temporary directories
|
| 584 |
+
with tempfile.TemporaryDirectory() as temp_dir:
|
| 585 |
+
dataset_path = os.path.join(temp_dir, "dataset")
|
| 586 |
+
output_path = os.path.join(temp_dir, "output")
|
| 587 |
+
|
| 588 |
+
# Download dataset
|
| 589 |
+
download_dataset("${datasetRepo}", dataset_path)
|
| 590 |
+
|
| 591 |
+
# Create config
|
| 592 |
+
config = create_config(dataset_path, output_path)
|
| 593 |
+
config_path = os.path.join(temp_dir, "config.yaml")
|
| 594 |
+
|
| 595 |
+
with open(config_path, "w") as f:
|
| 596 |
+
yaml.dump(config, f, default_flow_style=False)
|
| 597 |
+
|
| 598 |
+
# Run training
|
| 599 |
+
print("Starting training...")
|
| 600 |
+
os.chdir(toolkit_dir)
|
| 601 |
+
|
| 602 |
+
subprocess.run([
|
| 603 |
+
sys.executable, "run.py",
|
| 604 |
+
config_path
|
| 605 |
+
], check=True)
|
| 606 |
+
|
| 607 |
+
print("Training completed!")
|
| 608 |
+
|
| 609 |
+
# Upload results
|
| 610 |
+
model_name = f"${jobConfig.config.name}-lora"
|
| 611 |
+
upload_results(output_path, model_name, "${namespace}", os.environ["HF_TOKEN"], config)
|
| 612 |
+
|
| 613 |
+
if __name__ == "__main__":
|
| 614 |
+
main()
|
| 615 |
+
`;
|
| 616 |
+
}
|
| 617 |
+
|
| 618 |
+
async function submitHFJobUV(token: string, hardware: string, scriptPath: string): Promise<string> {
|
| 619 |
+
return new Promise((resolve, reject) => {
|
| 620 |
+
// Ensure token is available
|
| 621 |
+
if (!token) {
|
| 622 |
+
reject(new Error('HF_TOKEN is required'));
|
| 623 |
+
return;
|
| 624 |
+
}
|
| 625 |
+
|
| 626 |
+
console.log('Setting up environment with HF_TOKEN for job submission');
|
| 627 |
+
console.log(`Command: hf jobs uv run --flavor ${hardware} --timeout 5h --secrets HF_TOKEN --detach ${scriptPath}`);
|
| 628 |
+
|
| 629 |
+
// Use hf jobs uv run command with timeout and detach to get job ID
|
| 630 |
+
const childProcess = spawn('hf', [
|
| 631 |
+
'jobs', 'uv', 'run',
|
| 632 |
+
'--flavor', hardware,
|
| 633 |
+
'--timeout', '5h',
|
| 634 |
+
'--secrets', 'HF_TOKEN',
|
| 635 |
+
'--detach',
|
| 636 |
+
scriptPath
|
| 637 |
+
], {
|
| 638 |
+
env: {
|
| 639 |
+
...process.env,
|
| 640 |
+
HF_TOKEN: token
|
| 641 |
+
}
|
| 642 |
+
});
|
| 643 |
+
|
| 644 |
+
let output = '';
|
| 645 |
+
let error = '';
|
| 646 |
+
|
| 647 |
+
childProcess.stdout.on('data', (data) => {
|
| 648 |
+
const text = data.toString();
|
| 649 |
+
output += text;
|
| 650 |
+
console.log('HF Jobs stdout:', text);
|
| 651 |
+
});
|
| 652 |
+
|
| 653 |
+
childProcess.stderr.on('data', (data) => {
|
| 654 |
+
const text = data.toString();
|
| 655 |
+
error += text;
|
| 656 |
+
console.log('HF Jobs stderr:', text);
|
| 657 |
+
});
|
| 658 |
+
|
| 659 |
+
childProcess.on('close', (code) => {
|
| 660 |
+
console.log('HF Jobs process closed with code:', code);
|
| 661 |
+
console.log('Full output:', output);
|
| 662 |
+
console.log('Full error:', error);
|
| 663 |
+
|
| 664 |
+
if (code === 0) {
|
| 665 |
+
// With --detach flag, the output should be just the job ID
|
| 666 |
+
const fullText = (output + ' ' + error).trim();
|
| 667 |
+
|
| 668 |
+
// Updated patterns to handle variable-length hex job IDs (16-24+ characters)
|
| 669 |
+
const jobIdPatterns = [
|
| 670 |
+
/Job started with ID:\s*([a-f0-9]{16,})/i, // "Job started with ID: 68b26b73767540db9fc726ac"
|
| 671 |
+
/job\s+([a-f0-9]{16,})/i, // "job 68b26b73767540db9fc726ac"
|
| 672 |
+
/Job ID:\s*([a-f0-9]{16,})/i, // "Job ID: 68b26b73767540db9fc726ac"
|
| 673 |
+
/created\s+job\s+([a-f0-9]{16,})/i, // "created job 68b26b73767540db9fc726ac"
|
| 674 |
+
/submitted.*?job\s+([a-f0-9]{16,})/i, // "submitted ... job 68b26b73767540db9fc726ac"
|
| 675 |
+
/https:\/\/huggingface\.co\/jobs\/[^\/]+\/([a-f0-9]{16,})/i, // URL pattern
|
| 676 |
+
/([a-f0-9]{20,})/i, // Fallback: any 20+ char hex string
|
| 677 |
+
];
|
| 678 |
+
|
| 679 |
+
let jobId = 'unknown';
|
| 680 |
+
|
| 681 |
+
for (const pattern of jobIdPatterns) {
|
| 682 |
+
const match = fullText.match(pattern);
|
| 683 |
+
if (match && match[1] && match[1] !== 'started') {
|
| 684 |
+
jobId = match[1];
|
| 685 |
+
console.log(`Extracted job ID using pattern: ${pattern.toString()} -> ${jobId}`);
|
| 686 |
+
break;
|
| 687 |
+
}
|
| 688 |
+
}
|
| 689 |
+
|
| 690 |
+
resolve(jobId);
|
| 691 |
+
} else {
|
| 692 |
+
reject(new Error(error || output || 'Failed to submit job'));
|
| 693 |
+
}
|
| 694 |
+
});
|
| 695 |
+
|
| 696 |
+
childProcess.on('error', (err) => {
|
| 697 |
+
console.error('HF Jobs process error:', err);
|
| 698 |
+
reject(new Error(`Process error: ${err.message}`));
|
| 699 |
+
});
|
| 700 |
+
});
|
| 701 |
+
}
|
| 702 |
+
|
| 703 |
+
async function checkHFJobStatus(token: string, jobId: string): Promise<any> {
|
| 704 |
+
return new Promise((resolve, reject) => {
|
| 705 |
+
console.log(`Checking HF Job status for: ${jobId}`);
|
| 706 |
+
|
| 707 |
+
const childProcess = spawn('hf', [
|
| 708 |
+
'jobs', 'inspect', jobId
|
| 709 |
+
], {
|
| 710 |
+
env: {
|
| 711 |
+
...process.env,
|
| 712 |
+
HF_TOKEN: token
|
| 713 |
+
}
|
| 714 |
+
});
|
| 715 |
+
|
| 716 |
+
let output = '';
|
| 717 |
+
let error = '';
|
| 718 |
+
|
| 719 |
+
childProcess.stdout.on('data', (data) => {
|
| 720 |
+
const text = data.toString();
|
| 721 |
+
output += text;
|
| 722 |
+
});
|
| 723 |
+
|
| 724 |
+
childProcess.stderr.on('data', (data) => {
|
| 725 |
+
const text = data.toString();
|
| 726 |
+
error += text;
|
| 727 |
+
});
|
| 728 |
+
|
| 729 |
+
childProcess.on('close', (code) => {
|
| 730 |
+
if (code === 0) {
|
| 731 |
+
try {
|
| 732 |
+
// Parse the JSON output from hf jobs inspect
|
| 733 |
+
const jobInfo = JSON.parse(output);
|
| 734 |
+
if (Array.isArray(jobInfo) && jobInfo.length > 0) {
|
| 735 |
+
const job = jobInfo[0];
|
| 736 |
+
resolve({
|
| 737 |
+
id: job.id,
|
| 738 |
+
status: job.status?.stage || 'UNKNOWN',
|
| 739 |
+
message: job.status?.message,
|
| 740 |
+
created_at: job.created_at,
|
| 741 |
+
flavor: job.flavor,
|
| 742 |
+
url: job.url,
|
| 743 |
+
});
|
| 744 |
+
} else {
|
| 745 |
+
reject(new Error('Invalid job info response'));
|
| 746 |
+
}
|
| 747 |
+
} catch (parseError: any) {
|
| 748 |
+
console.error('Failed to parse job status:', parseError, output);
|
| 749 |
+
reject(new Error('Failed to parse job status'));
|
| 750 |
+
}
|
| 751 |
+
} else {
|
| 752 |
+
reject(new Error(error || output || 'Failed to check job status'));
|
| 753 |
+
}
|
| 754 |
+
});
|
| 755 |
+
|
| 756 |
+
childProcess.on('error', (err) => {
|
| 757 |
+
console.error('HF Jobs inspect process error:', err);
|
| 758 |
+
reject(new Error(`Process error: ${err.message}`));
|
| 759 |
+
});
|
| 760 |
+
});
|
| 761 |
+
}
|
src/app/api/img/[...imagePath]/route.ts
ADDED
|
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/* eslint-disable */
|
| 2 |
+
import { NextRequest, NextResponse } from 'next/server';
|
| 3 |
+
import fs from 'fs';
|
| 4 |
+
import path from 'path';
|
| 5 |
+
import { getDatasetsRoot, getTrainingFolder, getDataRoot } from '@/server/settings';
|
| 6 |
+
|
| 7 |
+
export async function GET(request: NextRequest, { params }: { params: { imagePath: string } }) {
|
| 8 |
+
const { imagePath } = await params;
|
| 9 |
+
try {
|
| 10 |
+
// Decode the path
|
| 11 |
+
const filepath = decodeURIComponent(imagePath);
|
| 12 |
+
|
| 13 |
+
// Get allowed directories
|
| 14 |
+
const datasetRoot = await getDatasetsRoot();
|
| 15 |
+
const trainingRoot = await getTrainingFolder();
|
| 16 |
+
const dataRoot = await getDataRoot();
|
| 17 |
+
|
| 18 |
+
const allowedDirs = [datasetRoot, trainingRoot, dataRoot];
|
| 19 |
+
|
| 20 |
+
// Security check: Ensure path is in allowed directory
|
| 21 |
+
const isAllowed = allowedDirs.some(allowedDir => filepath.startsWith(allowedDir)) && !filepath.includes('..');
|
| 22 |
+
|
| 23 |
+
if (!isAllowed) {
|
| 24 |
+
console.warn(`Access denied: ${filepath} not in ${allowedDirs.join(', ')}`);
|
| 25 |
+
return new NextResponse('Access denied', { status: 403 });
|
| 26 |
+
}
|
| 27 |
+
|
| 28 |
+
// Check if file exists
|
| 29 |
+
if (!fs.existsSync(filepath)) {
|
| 30 |
+
console.warn(`File not found: ${filepath}`);
|
| 31 |
+
return new NextResponse('File not found', { status: 404 });
|
| 32 |
+
}
|
| 33 |
+
|
| 34 |
+
// Get file info
|
| 35 |
+
const stat = fs.statSync(filepath);
|
| 36 |
+
if (!stat.isFile()) {
|
| 37 |
+
return new NextResponse('Not a file', { status: 400 });
|
| 38 |
+
}
|
| 39 |
+
|
| 40 |
+
// Determine content type
|
| 41 |
+
const ext = path.extname(filepath).toLowerCase();
|
| 42 |
+
const contentTypeMap: { [key: string]: string } = {
|
| 43 |
+
// Images
|
| 44 |
+
'.jpg': 'image/jpeg',
|
| 45 |
+
'.jpeg': 'image/jpeg',
|
| 46 |
+
'.png': 'image/png',
|
| 47 |
+
'.gif': 'image/gif',
|
| 48 |
+
'.webp': 'image/webp',
|
| 49 |
+
'.svg': 'image/svg+xml',
|
| 50 |
+
'.bmp': 'image/bmp',
|
| 51 |
+
// Videos
|
| 52 |
+
'.mp4': 'video/mp4',
|
| 53 |
+
'.avi': 'video/x-msvideo',
|
| 54 |
+
'.mov': 'video/quicktime',
|
| 55 |
+
'.mkv': 'video/x-matroska',
|
| 56 |
+
'.wmv': 'video/x-ms-wmv',
|
| 57 |
+
'.m4v': 'video/x-m4v',
|
| 58 |
+
'.flv': 'video/x-flv'
|
| 59 |
+
};
|
| 60 |
+
|
| 61 |
+
const contentType = contentTypeMap[ext] || 'application/octet-stream';
|
| 62 |
+
|
| 63 |
+
// Read file as buffer
|
| 64 |
+
const fileBuffer = fs.readFileSync(filepath);
|
| 65 |
+
|
| 66 |
+
// Return file with appropriate headers
|
| 67 |
+
return new NextResponse(fileBuffer, {
|
| 68 |
+
headers: {
|
| 69 |
+
'Content-Type': contentType,
|
| 70 |
+
'Content-Length': String(stat.size),
|
| 71 |
+
'Cache-Control': 'public, max-age=86400',
|
| 72 |
+
},
|
| 73 |
+
});
|
| 74 |
+
} catch (error) {
|
| 75 |
+
console.error('Error serving image:', error);
|
| 76 |
+
return new NextResponse('Internal Server Error', { status: 500 });
|
| 77 |
+
}
|
| 78 |
+
}
|
src/app/api/img/caption/route.ts
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import { NextResponse } from 'next/server';
|
| 2 |
+
import fs from 'fs';
|
| 3 |
+
import { getDatasetsRoot } from '@/server/settings';
|
| 4 |
+
|
| 5 |
+
export async function POST(request: Request) {
|
| 6 |
+
try {
|
| 7 |
+
const body = await request.json();
|
| 8 |
+
const { imgPath, caption } = body;
|
| 9 |
+
let datasetsPath = await getDatasetsRoot();
|
| 10 |
+
// make sure the dataset path is in the image path
|
| 11 |
+
if (!imgPath.startsWith(datasetsPath)) {
|
| 12 |
+
return NextResponse.json({ error: 'Invalid image path' }, { status: 400 });
|
| 13 |
+
}
|
| 14 |
+
|
| 15 |
+
// if img doesnt exist, ignore
|
| 16 |
+
if (!fs.existsSync(imgPath)) {
|
| 17 |
+
return NextResponse.json({ error: 'Image does not exist' }, { status: 404 });
|
| 18 |
+
}
|
| 19 |
+
|
| 20 |
+
// check for caption
|
| 21 |
+
const captionPath = imgPath.replace(/\.[^/.]+$/, '') + '.txt';
|
| 22 |
+
// save caption to file
|
| 23 |
+
fs.writeFileSync(captionPath, caption);
|
| 24 |
+
|
| 25 |
+
return NextResponse.json({ success: true });
|
| 26 |
+
} catch (error) {
|
| 27 |
+
return NextResponse.json({ error: 'Failed to create dataset' }, { status: 500 });
|
| 28 |
+
}
|
| 29 |
+
}
|
src/app/api/img/delete/route.ts
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import { NextResponse } from 'next/server';
|
| 2 |
+
import fs from 'fs';
|
| 3 |
+
import { getDatasetsRoot } from '@/server/settings';
|
| 4 |
+
|
| 5 |
+
export async function POST(request: Request) {
|
| 6 |
+
try {
|
| 7 |
+
const body = await request.json();
|
| 8 |
+
const { imgPath } = body;
|
| 9 |
+
let datasetsPath = await getDatasetsRoot();
|
| 10 |
+
// make sure the dataset path is in the image path
|
| 11 |
+
if (!imgPath.startsWith(datasetsPath)) {
|
| 12 |
+
return NextResponse.json({ error: 'Invalid image path' }, { status: 400 });
|
| 13 |
+
}
|
| 14 |
+
|
| 15 |
+
// if img doesnt exist, ignore
|
| 16 |
+
if (!fs.existsSync(imgPath)) {
|
| 17 |
+
return NextResponse.json({ success: true });
|
| 18 |
+
}
|
| 19 |
+
|
| 20 |
+
// delete it and return success
|
| 21 |
+
fs.unlinkSync(imgPath);
|
| 22 |
+
|
| 23 |
+
// check for caption
|
| 24 |
+
const captionPath = imgPath.replace(/\.[^/.]+$/, '') + '.txt';
|
| 25 |
+
if (fs.existsSync(captionPath)) {
|
| 26 |
+
// delete caption file
|
| 27 |
+
fs.unlinkSync(captionPath);
|
| 28 |
+
}
|
| 29 |
+
|
| 30 |
+
return NextResponse.json({ success: true });
|
| 31 |
+
} catch (error) {
|
| 32 |
+
return NextResponse.json({ error: 'Failed to create dataset' }, { status: 500 });
|
| 33 |
+
}
|
| 34 |
+
}
|
src/app/api/img/upload/route.ts
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// src/app/api/datasets/upload/route.ts
|
| 2 |
+
import { NextRequest, NextResponse } from 'next/server';
|
| 3 |
+
import { writeFile, mkdir } from 'fs/promises';
|
| 4 |
+
import { join } from 'path';
|
| 5 |
+
import { getDataRoot } from '@/server/settings';
|
| 6 |
+
import {v4 as uuidv4} from 'uuid';
|
| 7 |
+
|
| 8 |
+
export async function POST(request: NextRequest) {
|
| 9 |
+
try {
|
| 10 |
+
const dataRoot = await getDataRoot();
|
| 11 |
+
if (!dataRoot) {
|
| 12 |
+
return NextResponse.json({ error: 'Data root path not found' }, { status: 500 });
|
| 13 |
+
}
|
| 14 |
+
const imgRoot = join(dataRoot, 'images');
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
const formData = await request.formData();
|
| 18 |
+
const files = formData.getAll('files');
|
| 19 |
+
|
| 20 |
+
if (!files || files.length === 0) {
|
| 21 |
+
return NextResponse.json({ error: 'No files provided' }, { status: 400 });
|
| 22 |
+
}
|
| 23 |
+
|
| 24 |
+
// make it recursive if it doesn't exist
|
| 25 |
+
await mkdir(imgRoot, { recursive: true });
|
| 26 |
+
const savedFiles = await Promise.all(
|
| 27 |
+
files.map(async (file: any) => {
|
| 28 |
+
const bytes = await file.arrayBuffer();
|
| 29 |
+
const buffer = Buffer.from(bytes);
|
| 30 |
+
|
| 31 |
+
const extension = file.name.split('.').pop() || 'jpg';
|
| 32 |
+
|
| 33 |
+
// Clean filename and ensure it's unique
|
| 34 |
+
const fileName = `${uuidv4()}`; // Use UUID for unique file names
|
| 35 |
+
const filePath = join(imgRoot, `${fileName}.${extension}`);
|
| 36 |
+
|
| 37 |
+
await writeFile(filePath, buffer);
|
| 38 |
+
return filePath;
|
| 39 |
+
}),
|
| 40 |
+
);
|
| 41 |
+
|
| 42 |
+
return NextResponse.json({
|
| 43 |
+
message: 'Files uploaded successfully',
|
| 44 |
+
files: savedFiles,
|
| 45 |
+
});
|
| 46 |
+
} catch (error) {
|
| 47 |
+
console.error('Upload error:', error);
|
| 48 |
+
return NextResponse.json({ error: 'Error uploading files' }, { status: 500 });
|
| 49 |
+
}
|
| 50 |
+
}
|
| 51 |
+
|
| 52 |
+
// Increase payload size limit (default is 4mb)
|
| 53 |
+
export const config = {
|
| 54 |
+
api: {
|
| 55 |
+
bodyParser: false,
|
| 56 |
+
responseLimit: '50mb',
|
| 57 |
+
},
|
| 58 |
+
};
|
src/app/api/jobs/[jobID]/delete/route.ts
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import { NextRequest, NextResponse } from 'next/server';
|
| 2 |
+
import { PrismaClient } from '@prisma/client';
|
| 3 |
+
import { getTrainingFolder } from '@/server/settings';
|
| 4 |
+
import path from 'path';
|
| 5 |
+
import fs from 'fs';
|
| 6 |
+
|
| 7 |
+
const prisma = new PrismaClient();
|
| 8 |
+
|
| 9 |
+
export async function GET(request: NextRequest, { params }: { params: { jobID: string } }) {
|
| 10 |
+
const { jobID } = await params;
|
| 11 |
+
|
| 12 |
+
const job = await prisma.job.findUnique({
|
| 13 |
+
where: { id: jobID },
|
| 14 |
+
});
|
| 15 |
+
|
| 16 |
+
if (!job) {
|
| 17 |
+
return NextResponse.json({ error: 'Job not found' }, { status: 404 });
|
| 18 |
+
}
|
| 19 |
+
|
| 20 |
+
const trainingRoot = await getTrainingFolder();
|
| 21 |
+
const trainingFolder = path.join(trainingRoot, job.name);
|
| 22 |
+
|
| 23 |
+
if (fs.existsSync(trainingFolder)) {
|
| 24 |
+
fs.rmdirSync(trainingFolder, { recursive: true });
|
| 25 |
+
}
|
| 26 |
+
|
| 27 |
+
await prisma.job.delete({
|
| 28 |
+
where: { id: jobID },
|
| 29 |
+
});
|
| 30 |
+
|
| 31 |
+
return NextResponse.json(job);
|
| 32 |
+
}
|
src/app/api/jobs/[jobID]/files/route.ts
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import { NextRequest, NextResponse } from 'next/server';
|
| 2 |
+
import { PrismaClient } from '@prisma/client';
|
| 3 |
+
import path from 'path';
|
| 4 |
+
import fs from 'fs';
|
| 5 |
+
import { getTrainingFolder } from '@/server/settings';
|
| 6 |
+
|
| 7 |
+
const prisma = new PrismaClient();
|
| 8 |
+
|
| 9 |
+
export async function GET(request: NextRequest, { params }: { params: { jobID: string } }) {
|
| 10 |
+
const { jobID } = await params;
|
| 11 |
+
|
| 12 |
+
const job = await prisma.job.findUnique({
|
| 13 |
+
where: { id: jobID },
|
| 14 |
+
});
|
| 15 |
+
|
| 16 |
+
if (!job) {
|
| 17 |
+
return NextResponse.json({ error: 'Job not found' }, { status: 404 });
|
| 18 |
+
}
|
| 19 |
+
|
| 20 |
+
const trainingFolder = await getTrainingFolder();
|
| 21 |
+
const jobFolder = path.join(trainingFolder, job.name);
|
| 22 |
+
|
| 23 |
+
if (!fs.existsSync(jobFolder)) {
|
| 24 |
+
return NextResponse.json({ files: [] });
|
| 25 |
+
}
|
| 26 |
+
|
| 27 |
+
// find all safetensors files in the job folder
|
| 28 |
+
let files = fs
|
| 29 |
+
.readdirSync(jobFolder)
|
| 30 |
+
.filter(file => {
|
| 31 |
+
return file.endsWith('.safetensors');
|
| 32 |
+
})
|
| 33 |
+
.map(file => {
|
| 34 |
+
return path.join(jobFolder, file);
|
| 35 |
+
})
|
| 36 |
+
.sort();
|
| 37 |
+
|
| 38 |
+
// get the file size for each file
|
| 39 |
+
const fileObjects = files.map(file => {
|
| 40 |
+
const stats = fs.statSync(file);
|
| 41 |
+
return {
|
| 42 |
+
path: file,
|
| 43 |
+
size: stats.size,
|
| 44 |
+
};
|
| 45 |
+
});
|
| 46 |
+
|
| 47 |
+
return NextResponse.json({ files: fileObjects });
|
| 48 |
+
}
|
src/app/api/jobs/[jobID]/log/route.ts
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import { NextRequest, NextResponse } from 'next/server';
|
| 2 |
+
import { PrismaClient } from '@prisma/client';
|
| 3 |
+
import path from 'path';
|
| 4 |
+
import fs from 'fs';
|
| 5 |
+
import { getTrainingFolder } from '@/server/settings';
|
| 6 |
+
|
| 7 |
+
const prisma = new PrismaClient();
|
| 8 |
+
|
| 9 |
+
export async function GET(request: NextRequest, { params }: { params: { jobID: string } }) {
|
| 10 |
+
const { jobID } = await params;
|
| 11 |
+
|
| 12 |
+
const job = await prisma.job.findUnique({
|
| 13 |
+
where: { id: jobID },
|
| 14 |
+
});
|
| 15 |
+
|
| 16 |
+
if (!job) {
|
| 17 |
+
return NextResponse.json({ error: 'Job not found' }, { status: 404 });
|
| 18 |
+
}
|
| 19 |
+
|
| 20 |
+
const trainingFolder = await getTrainingFolder();
|
| 21 |
+
const jobFolder = path.join(trainingFolder, job.name);
|
| 22 |
+
const logPath = path.join(jobFolder, 'log.txt');
|
| 23 |
+
|
| 24 |
+
if (!fs.existsSync(logPath)) {
|
| 25 |
+
return NextResponse.json({ log: '' });
|
| 26 |
+
}
|
| 27 |
+
let log = '';
|
| 28 |
+
try {
|
| 29 |
+
log = fs.readFileSync(logPath, 'utf-8');
|
| 30 |
+
} catch (error) {
|
| 31 |
+
console.error('Error reading log file:', error);
|
| 32 |
+
log = 'Error reading log file';
|
| 33 |
+
}
|
| 34 |
+
return NextResponse.json({ log: log });
|
| 35 |
+
}
|
src/app/api/jobs/[jobID]/samples/route.ts
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import { NextRequest, NextResponse } from 'next/server';
|
| 2 |
+
import { PrismaClient } from '@prisma/client';
|
| 3 |
+
import path from 'path';
|
| 4 |
+
import fs from 'fs';
|
| 5 |
+
import { getTrainingFolder } from '@/server/settings';
|
| 6 |
+
|
| 7 |
+
const prisma = new PrismaClient();
|
| 8 |
+
|
| 9 |
+
export async function GET(request: NextRequest, { params }: { params: { jobID: string } }) {
|
| 10 |
+
const { jobID } = await params;
|
| 11 |
+
|
| 12 |
+
const job = await prisma.job.findUnique({
|
| 13 |
+
where: { id: jobID },
|
| 14 |
+
});
|
| 15 |
+
|
| 16 |
+
if (!job) {
|
| 17 |
+
return NextResponse.json({ error: 'Job not found' }, { status: 404 });
|
| 18 |
+
}
|
| 19 |
+
|
| 20 |
+
// setup the training
|
| 21 |
+
const trainingFolder = await getTrainingFolder();
|
| 22 |
+
|
| 23 |
+
const samplesFolder = path.join(trainingFolder, job.name, 'samples');
|
| 24 |
+
if (!fs.existsSync(samplesFolder)) {
|
| 25 |
+
return NextResponse.json({ samples: [] });
|
| 26 |
+
}
|
| 27 |
+
|
| 28 |
+
// find all img (png, jpg, jpeg) files in the samples folder
|
| 29 |
+
const samples = fs
|
| 30 |
+
.readdirSync(samplesFolder)
|
| 31 |
+
.filter(file => {
|
| 32 |
+
return file.endsWith('.png') || file.endsWith('.jpg') || file.endsWith('.jpeg') || file.endsWith('.webp');
|
| 33 |
+
})
|
| 34 |
+
.map(file => {
|
| 35 |
+
return path.join(samplesFolder, file);
|
| 36 |
+
})
|
| 37 |
+
.sort();
|
| 38 |
+
|
| 39 |
+
return NextResponse.json({ samples });
|
| 40 |
+
}
|
src/app/api/jobs/[jobID]/start/route.ts
ADDED
|
@@ -0,0 +1,215 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import { NextRequest, NextResponse } from 'next/server';
|
| 2 |
+
import { PrismaClient } from '@prisma/client';
|
| 3 |
+
import { TOOLKIT_ROOT } from '@/paths';
|
| 4 |
+
import { spawn } from 'child_process';
|
| 5 |
+
import path from 'path';
|
| 6 |
+
import fs from 'fs';
|
| 7 |
+
import os from 'os';
|
| 8 |
+
import { getTrainingFolder, getHFToken } from '@/server/settings';
|
| 9 |
+
const isWindows = process.platform === 'win32';
|
| 10 |
+
|
| 11 |
+
const prisma = new PrismaClient();
|
| 12 |
+
|
| 13 |
+
export async function GET(request: NextRequest, { params }: { params: { jobID: string } }) {
|
| 14 |
+
const { jobID } = await params;
|
| 15 |
+
|
| 16 |
+
const job = await prisma.job.findUnique({
|
| 17 |
+
where: { id: jobID },
|
| 18 |
+
});
|
| 19 |
+
|
| 20 |
+
if (!job) {
|
| 21 |
+
return NextResponse.json({ error: 'Job not found' }, { status: 404 });
|
| 22 |
+
}
|
| 23 |
+
|
| 24 |
+
// update job status to 'running'
|
| 25 |
+
await prisma.job.update({
|
| 26 |
+
where: { id: jobID },
|
| 27 |
+
data: {
|
| 28 |
+
status: 'running',
|
| 29 |
+
stop: false,
|
| 30 |
+
info: 'Starting job...',
|
| 31 |
+
},
|
| 32 |
+
});
|
| 33 |
+
|
| 34 |
+
// setup the training
|
| 35 |
+
const trainingRoot = await getTrainingFolder();
|
| 36 |
+
|
| 37 |
+
const trainingFolder = path.join(trainingRoot, job.name);
|
| 38 |
+
if (!fs.existsSync(trainingFolder)) {
|
| 39 |
+
fs.mkdirSync(trainingFolder, { recursive: true });
|
| 40 |
+
}
|
| 41 |
+
|
| 42 |
+
// make the config file
|
| 43 |
+
const configPath = path.join(trainingFolder, '.job_config.json');
|
| 44 |
+
|
| 45 |
+
//log to path
|
| 46 |
+
const logPath = path.join(trainingFolder, 'log.txt');
|
| 47 |
+
|
| 48 |
+
try {
|
| 49 |
+
// if the log path exists, move it to a folder called logs and rename it {num}_log.txt, looking for the highest num
|
| 50 |
+
// if the log path does not exist, create it
|
| 51 |
+
if (fs.existsSync(logPath)) {
|
| 52 |
+
const logsFolder = path.join(trainingFolder, 'logs');
|
| 53 |
+
if (!fs.existsSync(logsFolder)) {
|
| 54 |
+
fs.mkdirSync(logsFolder, { recursive: true });
|
| 55 |
+
}
|
| 56 |
+
|
| 57 |
+
let num = 0;
|
| 58 |
+
while (fs.existsSync(path.join(logsFolder, `${num}_log.txt`))) {
|
| 59 |
+
num++;
|
| 60 |
+
}
|
| 61 |
+
|
| 62 |
+
fs.renameSync(logPath, path.join(logsFolder, `${num}_log.txt`));
|
| 63 |
+
}
|
| 64 |
+
} catch (e) {
|
| 65 |
+
console.error('Error moving log file:', e);
|
| 66 |
+
}
|
| 67 |
+
|
| 68 |
+
// update the config dataset path
|
| 69 |
+
const jobConfig = JSON.parse(job.job_config);
|
| 70 |
+
jobConfig.config.process[0].sqlite_db_path = path.join(TOOLKIT_ROOT, 'aitk_db.db');
|
| 71 |
+
|
| 72 |
+
// write the config file
|
| 73 |
+
fs.writeFileSync(configPath, JSON.stringify(jobConfig, null, 2));
|
| 74 |
+
|
| 75 |
+
let pythonPath = 'python';
|
| 76 |
+
// use .venv or venv if it exists
|
| 77 |
+
if (fs.existsSync(path.join(TOOLKIT_ROOT, '.venv'))) {
|
| 78 |
+
if (isWindows) {
|
| 79 |
+
pythonPath = path.join(TOOLKIT_ROOT, '.venv', 'Scripts', 'python.exe');
|
| 80 |
+
} else {
|
| 81 |
+
pythonPath = path.join(TOOLKIT_ROOT, '.venv', 'bin', 'python');
|
| 82 |
+
}
|
| 83 |
+
} else if (fs.existsSync(path.join(TOOLKIT_ROOT, 'venv'))) {
|
| 84 |
+
if (isWindows) {
|
| 85 |
+
pythonPath = path.join(TOOLKIT_ROOT, 'venv', 'Scripts', 'python.exe');
|
| 86 |
+
} else {
|
| 87 |
+
pythonPath = path.join(TOOLKIT_ROOT, 'venv', 'bin', 'python');
|
| 88 |
+
}
|
| 89 |
+
}
|
| 90 |
+
|
| 91 |
+
const runFilePath = path.join(TOOLKIT_ROOT, 'run.py');
|
| 92 |
+
if (!fs.existsSync(runFilePath)) {
|
| 93 |
+
return NextResponse.json({ error: 'run.py not found' }, { status: 500 });
|
| 94 |
+
}
|
| 95 |
+
|
| 96 |
+
const additionalEnv: any = {
|
| 97 |
+
AITK_JOB_ID: jobID,
|
| 98 |
+
CUDA_VISIBLE_DEVICES: `${job.gpu_ids}`,
|
| 99 |
+
IS_AI_TOOLKIT_UI: '1'
|
| 100 |
+
};
|
| 101 |
+
|
| 102 |
+
// HF_TOKEN
|
| 103 |
+
const hfToken = await getHFToken();
|
| 104 |
+
if (hfToken && hfToken.trim() !== '') {
|
| 105 |
+
additionalEnv.HF_TOKEN = hfToken;
|
| 106 |
+
}
|
| 107 |
+
|
| 108 |
+
// Add the --log argument to the command
|
| 109 |
+
const args = [runFilePath, configPath, '--log', logPath];
|
| 110 |
+
|
| 111 |
+
try {
|
| 112 |
+
let subprocess;
|
| 113 |
+
|
| 114 |
+
if (isWindows) {
|
| 115 |
+
// For Windows, use 'cmd.exe' to open a new command window
|
| 116 |
+
subprocess = spawn('cmd.exe', ['/c', 'start', 'cmd.exe', '/k', pythonPath, ...args], {
|
| 117 |
+
env: {
|
| 118 |
+
...process.env,
|
| 119 |
+
...additionalEnv,
|
| 120 |
+
},
|
| 121 |
+
cwd: TOOLKIT_ROOT,
|
| 122 |
+
windowsHide: false,
|
| 123 |
+
});
|
| 124 |
+
} else {
|
| 125 |
+
// For non-Windows platforms
|
| 126 |
+
subprocess = spawn(pythonPath, args, {
|
| 127 |
+
detached: true,
|
| 128 |
+
stdio: ['ignore', 'pipe', 'pipe'], // Changed from 'ignore' to capture output
|
| 129 |
+
env: {
|
| 130 |
+
...process.env,
|
| 131 |
+
...additionalEnv,
|
| 132 |
+
},
|
| 133 |
+
cwd: TOOLKIT_ROOT,
|
| 134 |
+
});
|
| 135 |
+
}
|
| 136 |
+
|
| 137 |
+
// Start monitoring in the background without blocking the response
|
| 138 |
+
const monitorProcess = async () => {
|
| 139 |
+
const startTime = Date.now();
|
| 140 |
+
let errorOutput = '';
|
| 141 |
+
let stdoutput = '';
|
| 142 |
+
|
| 143 |
+
if (subprocess.stderr) {
|
| 144 |
+
subprocess.stderr.on('data', data => {
|
| 145 |
+
errorOutput += data.toString();
|
| 146 |
+
});
|
| 147 |
+
subprocess.stdout.on('data', data => {
|
| 148 |
+
stdoutput += data.toString();
|
| 149 |
+
// truncate to only get the last 500 characters
|
| 150 |
+
if (stdoutput.length > 500) {
|
| 151 |
+
stdoutput = stdoutput.substring(stdoutput.length - 500);
|
| 152 |
+
}
|
| 153 |
+
});
|
| 154 |
+
}
|
| 155 |
+
|
| 156 |
+
subprocess.on('exit', async code => {
|
| 157 |
+
const currentTime = Date.now();
|
| 158 |
+
const duration = (currentTime - startTime) / 1000;
|
| 159 |
+
console.log(`Job ${jobID} exited with code ${code} after ${duration} seconds.`);
|
| 160 |
+
// wait for 5 seconds to give it time to stop itself. It id still has a status of running in the db, update it to stopped
|
| 161 |
+
await new Promise(resolve => setTimeout(resolve, 5000));
|
| 162 |
+
const updatedJob = await prisma.job.findUnique({
|
| 163 |
+
where: { id: jobID },
|
| 164 |
+
});
|
| 165 |
+
if (updatedJob?.status === 'running') {
|
| 166 |
+
let errorString = errorOutput;
|
| 167 |
+
if (errorString.trim() === '') {
|
| 168 |
+
errorString = stdoutput;
|
| 169 |
+
}
|
| 170 |
+
await prisma.job.update({
|
| 171 |
+
where: { id: jobID },
|
| 172 |
+
data: {
|
| 173 |
+
status: 'error',
|
| 174 |
+
info: `Error launching job: ${errorString.substring(0, 500)}`,
|
| 175 |
+
},
|
| 176 |
+
});
|
| 177 |
+
}
|
| 178 |
+
});
|
| 179 |
+
|
| 180 |
+
// Wait 30 seconds before releasing the process
|
| 181 |
+
await new Promise(resolve => setTimeout(resolve, 30000));
|
| 182 |
+
// Detach the process for non-Windows systems
|
| 183 |
+
if (!isWindows && subprocess.unref) {
|
| 184 |
+
subprocess.unref();
|
| 185 |
+
}
|
| 186 |
+
};
|
| 187 |
+
|
| 188 |
+
// Start the monitoring without awaiting it
|
| 189 |
+
monitorProcess().catch(err => {
|
| 190 |
+
console.error(`Error in process monitoring for job ${jobID}:`, err);
|
| 191 |
+
});
|
| 192 |
+
|
| 193 |
+
// Return the response immediately
|
| 194 |
+
return NextResponse.json(job);
|
| 195 |
+
} catch (error: any) {
|
| 196 |
+
// Handle any exceptions during process launch
|
| 197 |
+
console.error('Error launching process:', error);
|
| 198 |
+
|
| 199 |
+
await prisma.job.update({
|
| 200 |
+
where: { id: jobID },
|
| 201 |
+
data: {
|
| 202 |
+
status: 'error',
|
| 203 |
+
info: `Error launching job: ${error?.message || 'Unknown error'}`,
|
| 204 |
+
},
|
| 205 |
+
});
|
| 206 |
+
|
| 207 |
+
return NextResponse.json(
|
| 208 |
+
{
|
| 209 |
+
error: 'Failed to launch job process',
|
| 210 |
+
details: error?.message || 'Unknown error',
|
| 211 |
+
},
|
| 212 |
+
{ status: 500 },
|
| 213 |
+
);
|
| 214 |
+
}
|
| 215 |
+
}
|
src/app/api/jobs/[jobID]/stop/route.ts
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import { NextRequest, NextResponse } from 'next/server';
|
| 2 |
+
import { PrismaClient } from '@prisma/client';
|
| 3 |
+
|
| 4 |
+
const prisma = new PrismaClient();
|
| 5 |
+
|
| 6 |
+
export async function GET(request: NextRequest, { params }: { params: { jobID: string } }) {
|
| 7 |
+
const { jobID } = await params;
|
| 8 |
+
|
| 9 |
+
const job = await prisma.job.findUnique({
|
| 10 |
+
where: { id: jobID },
|
| 11 |
+
});
|
| 12 |
+
|
| 13 |
+
// update job status to 'running'
|
| 14 |
+
await prisma.job.update({
|
| 15 |
+
where: { id: jobID },
|
| 16 |
+
data: {
|
| 17 |
+
stop: true,
|
| 18 |
+
info: 'Stopping job...',
|
| 19 |
+
},
|
| 20 |
+
});
|
| 21 |
+
|
| 22 |
+
return NextResponse.json(job);
|
| 23 |
+
}
|
src/app/api/jobs/route.ts
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import { NextResponse } from 'next/server';
|
| 2 |
+
import { PrismaClient } from '@prisma/client';
|
| 3 |
+
|
| 4 |
+
const prisma = new PrismaClient();
|
| 5 |
+
|
| 6 |
+
export async function GET(request: Request) {
|
| 7 |
+
const { searchParams } = new URL(request.url);
|
| 8 |
+
const id = searchParams.get('id');
|
| 9 |
+
|
| 10 |
+
try {
|
| 11 |
+
if (id) {
|
| 12 |
+
const job = await prisma.job.findUnique({
|
| 13 |
+
where: { id },
|
| 14 |
+
});
|
| 15 |
+
return NextResponse.json(job);
|
| 16 |
+
}
|
| 17 |
+
|
| 18 |
+
const jobs = await prisma.job.findMany({
|
| 19 |
+
orderBy: { created_at: 'desc' },
|
| 20 |
+
});
|
| 21 |
+
return NextResponse.json({ jobs: jobs });
|
| 22 |
+
} catch (error) {
|
| 23 |
+
console.error(error);
|
| 24 |
+
return NextResponse.json({ error: 'Failed to fetch training data' }, { status: 500 });
|
| 25 |
+
}
|
| 26 |
+
}
|
| 27 |
+
|
| 28 |
+
export async function POST(request: Request) {
|
| 29 |
+
try {
|
| 30 |
+
const body = await request.json();
|
| 31 |
+
const { id, name, job_config, gpu_ids } = body;
|
| 32 |
+
|
| 33 |
+
// Ensure gpu_ids is never null/undefined - provide default value
|
| 34 |
+
const safeGpuIds = gpu_ids || '0';
|
| 35 |
+
|
| 36 |
+
if (id) {
|
| 37 |
+
// Update existing training
|
| 38 |
+
const training = await prisma.job.update({
|
| 39 |
+
where: { id },
|
| 40 |
+
data: {
|
| 41 |
+
name,
|
| 42 |
+
gpu_ids: safeGpuIds,
|
| 43 |
+
job_config: JSON.stringify(job_config),
|
| 44 |
+
},
|
| 45 |
+
});
|
| 46 |
+
return NextResponse.json(training);
|
| 47 |
+
} else {
|
| 48 |
+
// Create new training
|
| 49 |
+
const training = await prisma.job.create({
|
| 50 |
+
data: {
|
| 51 |
+
name,
|
| 52 |
+
gpu_ids: safeGpuIds,
|
| 53 |
+
job_config: JSON.stringify(job_config),
|
| 54 |
+
},
|
| 55 |
+
});
|
| 56 |
+
return NextResponse.json(training);
|
| 57 |
+
}
|
| 58 |
+
} catch (error: any) {
|
| 59 |
+
if (error.code === 'P2002') {
|
| 60 |
+
// Handle unique constraint violation, 409=Conflict
|
| 61 |
+
return NextResponse.json({ error: 'Job name already exists' }, { status: 409 });
|
| 62 |
+
}
|
| 63 |
+
console.error(error);
|
| 64 |
+
// Handle other errors
|
| 65 |
+
return NextResponse.json({ error: 'Failed to save training data' }, { status: 500 });
|
| 66 |
+
}
|
| 67 |
+
}
|
src/app/api/settings/route.ts
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import { NextResponse } from 'next/server';
|
| 2 |
+
import { PrismaClient } from '@prisma/client';
|
| 3 |
+
import { defaultTrainFolder, defaultDatasetsFolder } from '@/paths';
|
| 4 |
+
import { flushCache } from '@/server/settings';
|
| 5 |
+
|
| 6 |
+
const prisma = new PrismaClient();
|
| 7 |
+
|
| 8 |
+
export async function GET() {
|
| 9 |
+
try {
|
| 10 |
+
const settings = await prisma.settings.findMany();
|
| 11 |
+
const settingsObject = settings.reduce((acc: any, setting) => {
|
| 12 |
+
acc[setting.key] = setting.value;
|
| 13 |
+
return acc;
|
| 14 |
+
}, {});
|
| 15 |
+
// if TRAINING_FOLDER is not set, use default
|
| 16 |
+
if (!settingsObject.TRAINING_FOLDER || settingsObject.TRAINING_FOLDER === '') {
|
| 17 |
+
settingsObject.TRAINING_FOLDER = defaultTrainFolder;
|
| 18 |
+
}
|
| 19 |
+
// if DATASETS_FOLDER is not set, use default
|
| 20 |
+
if (!settingsObject.DATASETS_FOLDER || settingsObject.DATASETS_FOLDER === '') {
|
| 21 |
+
settingsObject.DATASETS_FOLDER = defaultDatasetsFolder;
|
| 22 |
+
}
|
| 23 |
+
return NextResponse.json(settingsObject);
|
| 24 |
+
} catch (error) {
|
| 25 |
+
return NextResponse.json({ error: 'Failed to fetch settings' }, { status: 500 });
|
| 26 |
+
}
|
| 27 |
+
}
|
| 28 |
+
|
| 29 |
+
export async function POST(request: Request) {
|
| 30 |
+
try {
|
| 31 |
+
const body = await request.json();
|
| 32 |
+
const { HF_TOKEN, TRAINING_FOLDER, DATASETS_FOLDER } = body;
|
| 33 |
+
|
| 34 |
+
// Upsert both settings
|
| 35 |
+
await Promise.all([
|
| 36 |
+
prisma.settings.upsert({
|
| 37 |
+
where: { key: 'HF_TOKEN' },
|
| 38 |
+
update: { value: HF_TOKEN },
|
| 39 |
+
create: { key: 'HF_TOKEN', value: HF_TOKEN },
|
| 40 |
+
}),
|
| 41 |
+
prisma.settings.upsert({
|
| 42 |
+
where: { key: 'TRAINING_FOLDER' },
|
| 43 |
+
update: { value: TRAINING_FOLDER },
|
| 44 |
+
create: { key: 'TRAINING_FOLDER', value: TRAINING_FOLDER },
|
| 45 |
+
}),
|
| 46 |
+
prisma.settings.upsert({
|
| 47 |
+
where: { key: 'DATASETS_FOLDER' },
|
| 48 |
+
update: { value: DATASETS_FOLDER },
|
| 49 |
+
create: { key: 'DATASETS_FOLDER', value: DATASETS_FOLDER },
|
| 50 |
+
}),
|
| 51 |
+
]);
|
| 52 |
+
|
| 53 |
+
flushCache();
|
| 54 |
+
|
| 55 |
+
return NextResponse.json({ success: true });
|
| 56 |
+
} catch (error) {
|
| 57 |
+
return NextResponse.json({ error: 'Failed to update settings' }, { status: 500 });
|
| 58 |
+
}
|
| 59 |
+
}
|
src/app/api/zip/route.ts
ADDED
|
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/* eslint-disable */
|
| 2 |
+
import { NextRequest, NextResponse } from 'next/server';
|
| 3 |
+
import fs from 'fs';
|
| 4 |
+
import fsp from 'fs/promises';
|
| 5 |
+
import path from 'path';
|
| 6 |
+
import archiver from 'archiver';
|
| 7 |
+
import { getTrainingFolder } from '@/server/settings';
|
| 8 |
+
|
| 9 |
+
export const runtime = 'nodejs'; // ensure Node APIs are available
|
| 10 |
+
export const dynamic = 'force-dynamic'; // long-running, non-cached
|
| 11 |
+
|
| 12 |
+
type PostBody = {
|
| 13 |
+
zipTarget: 'samples'; //only samples for now
|
| 14 |
+
jobName: string;
|
| 15 |
+
};
|
| 16 |
+
|
| 17 |
+
async function resolveSafe(p: string) {
|
| 18 |
+
// resolve symlinks + normalize
|
| 19 |
+
return await fsp.realpath(p);
|
| 20 |
+
}
|
| 21 |
+
|
| 22 |
+
export async function POST(request: NextRequest) {
|
| 23 |
+
try {
|
| 24 |
+
const body = (await request.json()) as PostBody;
|
| 25 |
+
if (!body || !body.jobName) {
|
| 26 |
+
return NextResponse.json({ error: 'jobName is required' }, { status: 400 });
|
| 27 |
+
}
|
| 28 |
+
|
| 29 |
+
const trainingRoot = await resolveSafe(await getTrainingFolder());
|
| 30 |
+
const folderPath = await resolveSafe(path.join(trainingRoot, body.jobName, 'samples'));
|
| 31 |
+
const outputPath = path.resolve(trainingRoot, body.jobName, 'samples.zip');
|
| 32 |
+
|
| 33 |
+
// Must be a directory
|
| 34 |
+
let stat: fs.Stats;
|
| 35 |
+
try {
|
| 36 |
+
stat = await fsp.stat(folderPath);
|
| 37 |
+
} catch {
|
| 38 |
+
return new NextResponse('Folder not found', { status: 404 });
|
| 39 |
+
}
|
| 40 |
+
if (!stat.isDirectory()) {
|
| 41 |
+
return new NextResponse('Not a directory', { status: 400 });
|
| 42 |
+
}
|
| 43 |
+
|
| 44 |
+
// delete current one if it exists
|
| 45 |
+
if (fs.existsSync(outputPath)) {
|
| 46 |
+
await fsp.unlink(outputPath);
|
| 47 |
+
}
|
| 48 |
+
|
| 49 |
+
// Create write stream & archive
|
| 50 |
+
await new Promise<void>((resolve, reject) => {
|
| 51 |
+
const output = fs.createWriteStream(outputPath);
|
| 52 |
+
const archive = archiver('zip', { zlib: { level: 9 } });
|
| 53 |
+
|
| 54 |
+
output.on('close', () => resolve());
|
| 55 |
+
output.on('error', reject);
|
| 56 |
+
archive.on('error', reject);
|
| 57 |
+
|
| 58 |
+
archive.pipe(output);
|
| 59 |
+
|
| 60 |
+
// Add the directory contents (place them under the folder's base name in the zip)
|
| 61 |
+
const rootName = path.basename(folderPath);
|
| 62 |
+
archive.directory(folderPath, rootName);
|
| 63 |
+
|
| 64 |
+
archive.finalize().catch(reject);
|
| 65 |
+
});
|
| 66 |
+
|
| 67 |
+
// Return the absolute path so your existing /api/files/[...filePath] can serve it
|
| 68 |
+
// Example download URL (client-side): `/api/files/${encodeURIComponent(resolvedOutPath)}`
|
| 69 |
+
return NextResponse.json({
|
| 70 |
+
ok: true,
|
| 71 |
+
zipPath: outputPath,
|
| 72 |
+
fileName: path.basename(outputPath),
|
| 73 |
+
});
|
| 74 |
+
} catch (err) {
|
| 75 |
+
console.error('Zip error:', err);
|
| 76 |
+
return new NextResponse('Internal Server Error', { status: 500 });
|
| 77 |
+
}
|
| 78 |
+
}
|
src/app/apple-icon.png
ADDED
|
|
src/app/dashboard/page.tsx
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
'use client';
|
| 2 |
+
|
| 3 |
+
import JobsTable from '@/components/JobsTable';
|
| 4 |
+
import { TopBar, MainContent } from '@/components/layout';
|
| 5 |
+
import Link from 'next/link';
|
| 6 |
+
import { useAuth } from '@/contexts/AuthContext';
|
| 7 |
+
import HFLoginButton from '@/components/HFLoginButton';
|
| 8 |
+
|
| 9 |
+
export default function Dashboard() {
|
| 10 |
+
const { status: authStatus, namespace } = useAuth();
|
| 11 |
+
const isAuthenticated = authStatus === 'authenticated';
|
| 12 |
+
|
| 13 |
+
return (
|
| 14 |
+
<>
|
| 15 |
+
<TopBar>
|
| 16 |
+
<div>
|
| 17 |
+
<h1 className="text-lg">Dashboard</h1>
|
| 18 |
+
</div>
|
| 19 |
+
<div className="flex-1" />
|
| 20 |
+
</TopBar>
|
| 21 |
+
<MainContent>
|
| 22 |
+
<div className="border border-gray-800 rounded-xl bg-gray-900 p-6 flex flex-col gap-4">
|
| 23 |
+
<div>
|
| 24 |
+
<h2 className="text-xl font-semibold text-gray-100">
|
| 25 |
+
{isAuthenticated ? `Welcome back, ${namespace || 'creator'}!` : 'Welcome to Ostris AI Toolkit'}
|
| 26 |
+
</h2>
|
| 27 |
+
<p className="text-sm text-gray-400 mt-2">
|
| 28 |
+
{isAuthenticated
|
| 29 |
+
? 'You are signed in with Hugging Face and can manage jobs, datasets, and submissions.'
|
| 30 |
+
: 'Authenticate with Hugging Face or add a personal access token to create jobs, upload datasets, and launch training.'}
|
| 31 |
+
</p>
|
| 32 |
+
</div>
|
| 33 |
+
{isAuthenticated ? (
|
| 34 |
+
<div className="flex flex-wrap items-center gap-3 text-sm">
|
| 35 |
+
<Link
|
| 36 |
+
href="/jobs/new"
|
| 37 |
+
className="px-4 py-2 rounded-md bg-blue-600 hover:bg-blue-500 text-white transition-colors"
|
| 38 |
+
>
|
| 39 |
+
Create a Training Job
|
| 40 |
+
</Link>
|
| 41 |
+
<Link
|
| 42 |
+
href="/datasets"
|
| 43 |
+
className="px-4 py-2 rounded-md bg-gray-800 hover:bg-gray-700 text-gray-200 transition-colors"
|
| 44 |
+
>
|
| 45 |
+
Manage Datasets
|
| 46 |
+
</Link>
|
| 47 |
+
<Link
|
| 48 |
+
href="/settings"
|
| 49 |
+
className="px-4 py-2 rounded-md border border-gray-700 text-gray-300 hover:border-gray-600 transition-colors"
|
| 50 |
+
>
|
| 51 |
+
Settings
|
| 52 |
+
</Link>
|
| 53 |
+
</div>
|
| 54 |
+
) : (
|
| 55 |
+
<div className="flex flex-wrap items-center gap-3 text-sm">
|
| 56 |
+
<HFLoginButton size="md" />
|
| 57 |
+
<Link
|
| 58 |
+
href="/settings"
|
| 59 |
+
className="text-xs text-blue-400 hover:text-blue-300"
|
| 60 |
+
>
|
| 61 |
+
Or manage tokens in Settings
|
| 62 |
+
</Link>
|
| 63 |
+
</div>
|
| 64 |
+
)}
|
| 65 |
+
</div>
|
| 66 |
+
|
| 67 |
+
<div className="w-full mt-6">
|
| 68 |
+
<div className="flex justify-between items-center mb-2">
|
| 69 |
+
<h1 className="text-md">Active Jobs</h1>
|
| 70 |
+
<div className="text-xs text-gray-500">
|
| 71 |
+
<Link href="/jobs">View All</Link>
|
| 72 |
+
</div>
|
| 73 |
+
</div>
|
| 74 |
+
{isAuthenticated ? (
|
| 75 |
+
<JobsTable onlyActive />
|
| 76 |
+
) : (
|
| 77 |
+
<div className="border border-gray-800 rounded-lg p-6 bg-gray-900 text-gray-400 text-sm">
|
| 78 |
+
Sign in with Hugging Face or add an access token in Settings to view and manage jobs.
|
| 79 |
+
</div>
|
| 80 |
+
)}
|
| 81 |
+
</div>
|
| 82 |
+
</MainContent>
|
| 83 |
+
</>
|
| 84 |
+
);
|
| 85 |
+
}
|
src/app/datasets/[datasetName]/page.tsx
ADDED
|
@@ -0,0 +1,190 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
'use client';
|
| 2 |
+
|
| 3 |
+
import { useEffect, useState, use, useMemo } from 'react';
|
| 4 |
+
import { LuImageOff, LuLoader, LuBan } from 'react-icons/lu';
|
| 5 |
+
import { FaChevronLeft } from 'react-icons/fa';
|
| 6 |
+
import DatasetImageCard from '@/components/DatasetImageCard';
|
| 7 |
+
import { Button } from '@headlessui/react';
|
| 8 |
+
import AddImagesModal, { openImagesModal } from '@/components/AddImagesModal';
|
| 9 |
+
import { TopBar, MainContent } from '@/components/layout';
|
| 10 |
+
import { apiClient } from '@/utils/api';
|
| 11 |
+
import FullscreenDropOverlay from '@/components/FullscreenDropOverlay';
|
| 12 |
+
import { useRouter } from 'next/navigation';
|
| 13 |
+
import { usingBrowserDb } from '@/utils/env';
|
| 14 |
+
import { hasUserDataset } from '@/utils/storage/datasetStorage';
|
| 15 |
+
import { useAuth } from '@/contexts/AuthContext';
|
| 16 |
+
import HFLoginButton from '@/components/HFLoginButton';
|
| 17 |
+
import Link from 'next/link';
|
| 18 |
+
|
| 19 |
+
export default function DatasetPage({ params }: { params: { datasetName: string } }) {
|
| 20 |
+
const [imgList, setImgList] = useState<{ img_path: string }[]>([]);
|
| 21 |
+
const usableParams = use(params as any) as { datasetName: string };
|
| 22 |
+
const datasetName = usableParams.datasetName;
|
| 23 |
+
const [status, setStatus] = useState<'idle' | 'loading' | 'success' | 'error'>('idle');
|
| 24 |
+
const router = useRouter();
|
| 25 |
+
const { status: authStatus } = useAuth();
|
| 26 |
+
const isAuthenticated = authStatus === 'authenticated';
|
| 27 |
+
const hasDatasetEntry = !usingBrowserDb || hasUserDataset(datasetName);
|
| 28 |
+
const allowAccess = hasDatasetEntry && isAuthenticated;
|
| 29 |
+
|
| 30 |
+
const refreshImageList = (dbName: string) => {
|
| 31 |
+
setStatus('loading');
|
| 32 |
+
console.log('Fetching images for dataset:', dbName);
|
| 33 |
+
apiClient
|
| 34 |
+
.post('/api/datasets/listImages', { datasetName: dbName })
|
| 35 |
+
.then((res: any) => {
|
| 36 |
+
const data = res.data;
|
| 37 |
+
console.log('Images:', data.images);
|
| 38 |
+
// sort
|
| 39 |
+
data.images.sort((a: { img_path: string }, b: { img_path: string }) => a.img_path.localeCompare(b.img_path));
|
| 40 |
+
setImgList(data.images);
|
| 41 |
+
setStatus('success');
|
| 42 |
+
})
|
| 43 |
+
.catch(error => {
|
| 44 |
+
console.error('Error fetching images:', error);
|
| 45 |
+
setStatus('error');
|
| 46 |
+
});
|
| 47 |
+
};
|
| 48 |
+
useEffect(() => {
|
| 49 |
+
if (!datasetName) {
|
| 50 |
+
return;
|
| 51 |
+
}
|
| 52 |
+
|
| 53 |
+
if (!isAuthenticated) {
|
| 54 |
+
return;
|
| 55 |
+
}
|
| 56 |
+
|
| 57 |
+
if (!hasDatasetEntry) {
|
| 58 |
+
setImgList([]);
|
| 59 |
+
setStatus('error');
|
| 60 |
+
router.replace('/datasets');
|
| 61 |
+
return;
|
| 62 |
+
}
|
| 63 |
+
|
| 64 |
+
refreshImageList(datasetName);
|
| 65 |
+
}, [datasetName, hasDatasetEntry, isAuthenticated, router]);
|
| 66 |
+
|
| 67 |
+
if (!allowAccess) {
|
| 68 |
+
return (
|
| 69 |
+
<>
|
| 70 |
+
<TopBar>
|
| 71 |
+
<div>
|
| 72 |
+
<Button className="text-gray-500 dark:text-gray-300 px-3 mt-1" onClick={() => history.back()}>
|
| 73 |
+
<FaChevronLeft />
|
| 74 |
+
</Button>
|
| 75 |
+
</div>
|
| 76 |
+
<div>
|
| 77 |
+
<h1 className="text-lg">Dataset: {datasetName}</h1>
|
| 78 |
+
</div>
|
| 79 |
+
<div className="flex-1"></div>
|
| 80 |
+
</TopBar>
|
| 81 |
+
<MainContent>
|
| 82 |
+
<div className="border border-gray-800 rounded-lg p-6 bg-gray-900 text-gray-400 text-sm flex flex-col gap-4">
|
| 83 |
+
<p>You need to sign in with Hugging Face or provide a valid token to view this dataset.</p>
|
| 84 |
+
<div className="flex items-center gap-3">
|
| 85 |
+
<HFLoginButton size="sm" />
|
| 86 |
+
<Link href="/settings" className="text-xs text-blue-400 hover:text-blue-300">
|
| 87 |
+
Manage authentication in Settings
|
| 88 |
+
</Link>
|
| 89 |
+
</div>
|
| 90 |
+
</div>
|
| 91 |
+
</MainContent>
|
| 92 |
+
</>
|
| 93 |
+
);
|
| 94 |
+
}
|
| 95 |
+
|
| 96 |
+
const PageInfoContent = useMemo(() => {
|
| 97 |
+
let icon = null;
|
| 98 |
+
let text = '';
|
| 99 |
+
let subtitle = '';
|
| 100 |
+
let showIt = false;
|
| 101 |
+
let bgColor = '';
|
| 102 |
+
let textColor = '';
|
| 103 |
+
let iconColor = '';
|
| 104 |
+
|
| 105 |
+
if (status == 'loading') {
|
| 106 |
+
icon = <LuLoader className="animate-spin w-8 h-8" />;
|
| 107 |
+
text = 'Loading Images';
|
| 108 |
+
subtitle = 'Please wait while we fetch your dataset images...';
|
| 109 |
+
showIt = true;
|
| 110 |
+
bgColor = 'bg-gray-50 dark:bg-gray-800/50';
|
| 111 |
+
textColor = 'text-gray-900 dark:text-gray-100';
|
| 112 |
+
iconColor = 'text-gray-500 dark:text-gray-400';
|
| 113 |
+
}
|
| 114 |
+
if (status == 'error') {
|
| 115 |
+
icon = <LuBan className="w-8 h-8" />;
|
| 116 |
+
text = 'Error Loading Images';
|
| 117 |
+
subtitle = 'There was a problem fetching the images. Please try refreshing the page.';
|
| 118 |
+
showIt = true;
|
| 119 |
+
bgColor = 'bg-red-50 dark:bg-red-950/20';
|
| 120 |
+
textColor = 'text-red-900 dark:text-red-100';
|
| 121 |
+
iconColor = 'text-red-600 dark:text-red-400';
|
| 122 |
+
}
|
| 123 |
+
if (status == 'success' && imgList.length === 0) {
|
| 124 |
+
icon = <LuImageOff className="w-8 h-8" />;
|
| 125 |
+
text = 'No Images Found';
|
| 126 |
+
subtitle = 'This dataset is empty. Click "Add Images" to get started.';
|
| 127 |
+
showIt = true;
|
| 128 |
+
bgColor = 'bg-gray-50 dark:bg-gray-800/50';
|
| 129 |
+
textColor = 'text-gray-900 dark:text-gray-100';
|
| 130 |
+
iconColor = 'text-gray-500 dark:text-gray-400';
|
| 131 |
+
}
|
| 132 |
+
|
| 133 |
+
if (!showIt) return null;
|
| 134 |
+
|
| 135 |
+
return (
|
| 136 |
+
<div
|
| 137 |
+
className={`mt-10 flex flex-col items-center justify-center py-16 px-8 rounded-xl border-2 border-gray-700 border-dashed ${bgColor} ${textColor} mx-auto max-w-md text-center`}
|
| 138 |
+
>
|
| 139 |
+
<div className={`${iconColor} mb-4`}>{icon}</div>
|
| 140 |
+
<h3 className="text-lg font-semibold mb-2">{text}</h3>
|
| 141 |
+
<p className="text-sm opacity-75 leading-relaxed">{subtitle}</p>
|
| 142 |
+
</div>
|
| 143 |
+
);
|
| 144 |
+
}, [status, imgList.length]);
|
| 145 |
+
|
| 146 |
+
return (
|
| 147 |
+
<>
|
| 148 |
+
{/* Fixed top bar */}
|
| 149 |
+
<TopBar>
|
| 150 |
+
<div>
|
| 151 |
+
<Button className="text-gray-500 dark:text-gray-300 px-3 mt-1" onClick={() => history.back()}>
|
| 152 |
+
<FaChevronLeft />
|
| 153 |
+
</Button>
|
| 154 |
+
</div>
|
| 155 |
+
<div>
|
| 156 |
+
<h1 className="text-lg">Dataset: {datasetName}</h1>
|
| 157 |
+
</div>
|
| 158 |
+
<div className="flex-1"></div>
|
| 159 |
+
<div>
|
| 160 |
+
<Button
|
| 161 |
+
className="text-gray-200 bg-slate-600 px-3 py-1 rounded-md"
|
| 162 |
+
onClick={() => openImagesModal(datasetName, () => refreshImageList(datasetName))}
|
| 163 |
+
>
|
| 164 |
+
Add Images
|
| 165 |
+
</Button>
|
| 166 |
+
</div>
|
| 167 |
+
</TopBar>
|
| 168 |
+
<MainContent>
|
| 169 |
+
{PageInfoContent}
|
| 170 |
+
{status === 'success' && imgList.length > 0 && (
|
| 171 |
+
<div className="grid grid-cols-1 sm:grid-cols-2 md:grid-cols-3 lg:grid-cols-4 gap-4">
|
| 172 |
+
{imgList.map(img => (
|
| 173 |
+
<DatasetImageCard
|
| 174 |
+
key={img.img_path}
|
| 175 |
+
alt="image"
|
| 176 |
+
imageUrl={img.img_path}
|
| 177 |
+
onDelete={() => refreshImageList(datasetName)}
|
| 178 |
+
/>
|
| 179 |
+
))}
|
| 180 |
+
</div>
|
| 181 |
+
)}
|
| 182 |
+
</MainContent>
|
| 183 |
+
<AddImagesModal />
|
| 184 |
+
<FullscreenDropOverlay
|
| 185 |
+
datasetName={datasetName}
|
| 186 |
+
onComplete={() => refreshImageList(datasetName)}
|
| 187 |
+
/>
|
| 188 |
+
</>
|
| 189 |
+
);
|
| 190 |
+
}
|
src/app/datasets/page.tsx
ADDED
|
@@ -0,0 +1,217 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
'use client';
|
| 2 |
+
|
| 3 |
+
import { useState } from 'react';
|
| 4 |
+
import { Modal } from '@/components/Modal';
|
| 5 |
+
import Link from 'next/link';
|
| 6 |
+
import { TextInput } from '@/components/formInputs';
|
| 7 |
+
import useDatasetList from '@/hooks/useDatasetList';
|
| 8 |
+
import { Button } from '@headlessui/react';
|
| 9 |
+
import { FaRegTrashAlt } from 'react-icons/fa';
|
| 10 |
+
import { openConfirm } from '@/components/ConfirmModal';
|
| 11 |
+
import { TopBar, MainContent } from '@/components/layout';
|
| 12 |
+
import UniversalTable, { TableColumn } from '@/components/UniversalTable';
|
| 13 |
+
import { apiClient } from '@/utils/api';
|
| 14 |
+
import { useRouter } from 'next/navigation';
|
| 15 |
+
import { usingBrowserDb } from '@/utils/env';
|
| 16 |
+
import { addUserDataset, removeUserDataset } from '@/utils/storage/datasetStorage';
|
| 17 |
+
import { useAuth } from '@/contexts/AuthContext';
|
| 18 |
+
import HFLoginButton from '@/components/HFLoginButton';
|
| 19 |
+
|
| 20 |
+
export default function Datasets() {
|
| 21 |
+
const router = useRouter();
|
| 22 |
+
const { datasets, status, refreshDatasets } = useDatasetList();
|
| 23 |
+
const [newDatasetName, setNewDatasetName] = useState('');
|
| 24 |
+
const [isNewDatasetModalOpen, setIsNewDatasetModalOpen] = useState(false);
|
| 25 |
+
const { status: authStatus } = useAuth();
|
| 26 |
+
const isAuthenticated = authStatus === 'authenticated';
|
| 27 |
+
|
| 28 |
+
// Transform datasets array into rows with objects
|
| 29 |
+
const tableRows = datasets.map(dataset => ({
|
| 30 |
+
name: dataset,
|
| 31 |
+
actions: dataset, // Pass full dataset name for actions
|
| 32 |
+
}));
|
| 33 |
+
|
| 34 |
+
const columns: TableColumn[] = [
|
| 35 |
+
{
|
| 36 |
+
title: 'Dataset Name',
|
| 37 |
+
key: 'name',
|
| 38 |
+
render: row => (
|
| 39 |
+
<Link href={`/datasets/${row.name}`} className="text-gray-200 hover:text-gray-100">
|
| 40 |
+
{row.name}
|
| 41 |
+
</Link>
|
| 42 |
+
),
|
| 43 |
+
},
|
| 44 |
+
{
|
| 45 |
+
title: 'Actions',
|
| 46 |
+
key: 'actions',
|
| 47 |
+
className: 'w-20 text-right',
|
| 48 |
+
render: row => (
|
| 49 |
+
<button
|
| 50 |
+
className="text-gray-200 hover:bg-red-600 p-2 rounded-full transition-colors"
|
| 51 |
+
onClick={() => handleDeleteDataset(row.name)}
|
| 52 |
+
>
|
| 53 |
+
<FaRegTrashAlt />
|
| 54 |
+
</button>
|
| 55 |
+
),
|
| 56 |
+
},
|
| 57 |
+
];
|
| 58 |
+
|
| 59 |
+
const handleDeleteDataset = (datasetName: string) => {
|
| 60 |
+
openConfirm({
|
| 61 |
+
title: 'Delete Dataset',
|
| 62 |
+
message: `Are you sure you want to delete the dataset "${datasetName}"? This action cannot be undone.`,
|
| 63 |
+
type: 'warning',
|
| 64 |
+
confirmText: 'Delete',
|
| 65 |
+
onConfirm: () => {
|
| 66 |
+
apiClient
|
| 67 |
+
.post('/api/datasets/delete', { name: datasetName })
|
| 68 |
+
.then(() => {
|
| 69 |
+
console.log('Dataset deleted:', datasetName);
|
| 70 |
+
if (usingBrowserDb) {
|
| 71 |
+
removeUserDataset(datasetName);
|
| 72 |
+
}
|
| 73 |
+
refreshDatasets();
|
| 74 |
+
})
|
| 75 |
+
.catch(error => {
|
| 76 |
+
console.error('Error deleting dataset:', error);
|
| 77 |
+
});
|
| 78 |
+
},
|
| 79 |
+
});
|
| 80 |
+
};
|
| 81 |
+
|
| 82 |
+
const handleCreateDataset = async (e: React.FormEvent) => {
|
| 83 |
+
e.preventDefault();
|
| 84 |
+
if (!isAuthenticated) {
|
| 85 |
+
return;
|
| 86 |
+
}
|
| 87 |
+
try {
|
| 88 |
+
const data = await apiClient.post('/api/datasets/create', { name: newDatasetName }).then(res => res.data);
|
| 89 |
+
console.log('New dataset created:', data);
|
| 90 |
+
if (usingBrowserDb && data?.name) {
|
| 91 |
+
addUserDataset(data.name, data?.path || '');
|
| 92 |
+
}
|
| 93 |
+
refreshDatasets();
|
| 94 |
+
setNewDatasetName('');
|
| 95 |
+
setIsNewDatasetModalOpen(false);
|
| 96 |
+
} catch (error) {
|
| 97 |
+
console.error('Error creating new dataset:', error);
|
| 98 |
+
}
|
| 99 |
+
};
|
| 100 |
+
|
| 101 |
+
const openNewDatasetModal = () => {
|
| 102 |
+
if (!isAuthenticated) {
|
| 103 |
+
return;
|
| 104 |
+
}
|
| 105 |
+
openConfirm({
|
| 106 |
+
title: 'New Dataset',
|
| 107 |
+
message: 'Enter the name of the new dataset:',
|
| 108 |
+
type: 'info',
|
| 109 |
+
confirmText: 'Create',
|
| 110 |
+
inputTitle: 'Dataset Name',
|
| 111 |
+
onConfirm: async (name?: string) => {
|
| 112 |
+
if (!name) {
|
| 113 |
+
console.error('Dataset name is required.');
|
| 114 |
+
return;
|
| 115 |
+
}
|
| 116 |
+
if (!isAuthenticated) {
|
| 117 |
+
return;
|
| 118 |
+
}
|
| 119 |
+
try {
|
| 120 |
+
const data = await apiClient.post('/api/datasets/create', { name }).then(res => res.data);
|
| 121 |
+
console.log('New dataset created:', data);
|
| 122 |
+
if (usingBrowserDb && data?.name) {
|
| 123 |
+
addUserDataset(data.name, data?.path || '');
|
| 124 |
+
}
|
| 125 |
+
if (data.name) {
|
| 126 |
+
router.push(`/datasets/${data.name}`);
|
| 127 |
+
} else {
|
| 128 |
+
refreshDatasets();
|
| 129 |
+
}
|
| 130 |
+
} catch (error) {
|
| 131 |
+
console.error('Error creating new dataset:', error);
|
| 132 |
+
}
|
| 133 |
+
},
|
| 134 |
+
});
|
| 135 |
+
};
|
| 136 |
+
|
| 137 |
+
return (
|
| 138 |
+
<>
|
| 139 |
+
<TopBar>
|
| 140 |
+
<div>
|
| 141 |
+
<h1 className="text-2xl font-semibold text-gray-100">Datasets</h1>
|
| 142 |
+
</div>
|
| 143 |
+
<div className="flex-1"></div>
|
| 144 |
+
<div>
|
| 145 |
+
{isAuthenticated ? (
|
| 146 |
+
<Button
|
| 147 |
+
className="text-gray-200 bg-slate-600 px-4 py-2 rounded-md hover:bg-slate-500 transition-colors"
|
| 148 |
+
onClick={() => openNewDatasetModal()}
|
| 149 |
+
>
|
| 150 |
+
New Dataset
|
| 151 |
+
</Button>
|
| 152 |
+
) : (
|
| 153 |
+
<span className="text-gray-600 bg-gray-900 px-3 py-1 rounded-md border border-gray-800">
|
| 154 |
+
Sign in to add datasets
|
| 155 |
+
</span>
|
| 156 |
+
)}
|
| 157 |
+
</div>
|
| 158 |
+
</TopBar>
|
| 159 |
+
|
| 160 |
+
<MainContent>
|
| 161 |
+
{isAuthenticated ? (
|
| 162 |
+
<UniversalTable
|
| 163 |
+
columns={columns}
|
| 164 |
+
rows={tableRows}
|
| 165 |
+
isLoading={status === 'loading'}
|
| 166 |
+
onRefresh={refreshDatasets}
|
| 167 |
+
/>
|
| 168 |
+
) : (
|
| 169 |
+
<div className="border border-gray-800 rounded-lg p-6 bg-gray-900 text-gray-400 text-sm flex flex-col gap-4">
|
| 170 |
+
<p>Sign in with Hugging Face or add an access token to manage datasets.</p>
|
| 171 |
+
<div className="flex items-center gap-3">
|
| 172 |
+
<HFLoginButton size="sm" />
|
| 173 |
+
<Link href="/settings" className="text-xs text-blue-400 hover:text-blue-300">
|
| 174 |
+
Manage authentication in Settings
|
| 175 |
+
</Link>
|
| 176 |
+
</div>
|
| 177 |
+
</div>
|
| 178 |
+
)}
|
| 179 |
+
</MainContent>
|
| 180 |
+
|
| 181 |
+
<Modal
|
| 182 |
+
isOpen={isNewDatasetModalOpen}
|
| 183 |
+
onClose={() => setIsNewDatasetModalOpen(false)}
|
| 184 |
+
title="New Dataset"
|
| 185 |
+
size="md"
|
| 186 |
+
>
|
| 187 |
+
<div className="space-y-4 text-gray-200">
|
| 188 |
+
<form onSubmit={handleCreateDataset}>
|
| 189 |
+
<div className="text-sm text-gray-400">
|
| 190 |
+
This will create a new folder with the name below in your dataset folder.
|
| 191 |
+
</div>
|
| 192 |
+
<div className="mt-4">
|
| 193 |
+
<TextInput label="Dataset Name" value={newDatasetName} onChange={value => setNewDatasetName(value)} />
|
| 194 |
+
</div>
|
| 195 |
+
|
| 196 |
+
<div className="mt-6 flex justify-end space-x-3">
|
| 197 |
+
<button
|
| 198 |
+
type="button"
|
| 199 |
+
className="rounded-md bg-gray-700 px-4 py-2 text-gray-200 hover:bg-gray-600 focus:outline-none focus:ring-2 focus:ring-gray-500"
|
| 200 |
+
onClick={() => setIsNewDatasetModalOpen(false)}
|
| 201 |
+
>
|
| 202 |
+
Cancel
|
| 203 |
+
</button>
|
| 204 |
+
<button
|
| 205 |
+
type="submit"
|
| 206 |
+
className="rounded-md bg-blue-600 px-4 py-2 text-white hover:bg-blue-700 focus:outline-none focus:ring-2 focus:ring-blue-500 disabled:opacity-50 disabled:cursor-not-allowed"
|
| 207 |
+
disabled={!isAuthenticated}
|
| 208 |
+
>
|
| 209 |
+
Confirm
|
| 210 |
+
</button>
|
| 211 |
+
</div>
|
| 212 |
+
</form>
|
| 213 |
+
</div>
|
| 214 |
+
</Modal>
|
| 215 |
+
</>
|
| 216 |
+
);
|
| 217 |
+
}
|
src/app/favicon.ico
ADDED
|
|
src/app/globals.css
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
@tailwind base;
|
| 2 |
+
@tailwind components;
|
| 3 |
+
@tailwind utilities;
|
| 4 |
+
|
| 5 |
+
:root {
|
| 6 |
+
--background: #ffffff;
|
| 7 |
+
--foreground: #171717;
|
| 8 |
+
}
|
| 9 |
+
|
| 10 |
+
@media (prefers-color-scheme: dark) {
|
| 11 |
+
:root {
|
| 12 |
+
--background: #0a0a0a;
|
| 13 |
+
--foreground: #ededed;
|
| 14 |
+
}
|
| 15 |
+
}
|
| 16 |
+
|
| 17 |
+
body {
|
| 18 |
+
color: var(--foreground);
|
| 19 |
+
background: var(--background);
|
| 20 |
+
font-family: Arial, Helvetica, sans-serif;
|
| 21 |
+
}
|
| 22 |
+
|
| 23 |
+
@layer components {
|
| 24 |
+
/* control */
|
| 25 |
+
.aitk-react-select-container .aitk-react-select__control {
|
| 26 |
+
@apply flex w-full h-8 min-h-0 px-0 text-sm bg-gray-800 border border-gray-700 rounded-sm hover:border-gray-600 items-center;
|
| 27 |
+
}
|
| 28 |
+
|
| 29 |
+
/* selected label */
|
| 30 |
+
.aitk-react-select-container .aitk-react-select__single-value {
|
| 31 |
+
@apply flex-1 min-w-0 truncate text-sm text-neutral-200;
|
| 32 |
+
}
|
| 33 |
+
|
| 34 |
+
/* invisible input (keeps focus & typing, never wraps) */
|
| 35 |
+
.aitk-react-select-container .aitk-react-select__input-container {
|
| 36 |
+
@apply text-neutral-200;
|
| 37 |
+
}
|
| 38 |
+
|
| 39 |
+
/* focus */
|
| 40 |
+
.aitk-react-select-container .aitk-react-select__control--is-focused {
|
| 41 |
+
@apply ring-2 ring-gray-600 border-transparent hover:border-transparent shadow-none;
|
| 42 |
+
}
|
| 43 |
+
|
| 44 |
+
/* menu */
|
| 45 |
+
.aitk-react-select-container .aitk-react-select__menu {
|
| 46 |
+
@apply bg-gray-800 border border-gray-700;
|
| 47 |
+
}
|
| 48 |
+
|
| 49 |
+
/* options */
|
| 50 |
+
.aitk-react-select-container .aitk-react-select__option {
|
| 51 |
+
@apply text-sm text-neutral-200 bg-gray-800 hover:bg-gray-700;
|
| 52 |
+
}
|
| 53 |
+
|
| 54 |
+
/* indicator separator */
|
| 55 |
+
.aitk-react-select-container .aitk-react-select__indicator-separator {
|
| 56 |
+
@apply bg-gray-600;
|
| 57 |
+
}
|
| 58 |
+
|
| 59 |
+
/* indicators */
|
| 60 |
+
.aitk-react-select-container .aitk-react-select__indicators,
|
| 61 |
+
.aitk-react-select-container .aitk-react-select__indicator {
|
| 62 |
+
@apply py-0 flex items-center;
|
| 63 |
+
}
|
| 64 |
+
|
| 65 |
+
/* placeholder */
|
| 66 |
+
.aitk-react-select-container .aitk-react-select__placeholder {
|
| 67 |
+
@apply text-sm text-neutral-200;
|
| 68 |
+
}
|
| 69 |
+
}
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
|
src/app/icon.png
ADDED
|
|
src/app/icon.svg
ADDED
|
|
src/app/jobs/[jobID]/page.tsx
ADDED
|
@@ -0,0 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
'use client';
|
| 2 |
+
|
| 3 |
+
import { useState, use } from 'react';
|
| 4 |
+
import { FaChevronLeft } from 'react-icons/fa';
|
| 5 |
+
import { Button } from '@headlessui/react';
|
| 6 |
+
import { TopBar, MainContent } from '@/components/layout';
|
| 7 |
+
import useJob from '@/hooks/useJob';
|
| 8 |
+
import SampleImages, {SampleImagesMenu} from '@/components/SampleImages';
|
| 9 |
+
import JobOverview from '@/components/JobOverview';
|
| 10 |
+
import { redirect } from 'next/navigation';
|
| 11 |
+
import { useAuth } from '@/contexts/AuthContext';
|
| 12 |
+
import HFLoginButton from '@/components/HFLoginButton';
|
| 13 |
+
import Link from 'next/link';
|
| 14 |
+
import JobActionBar from '@/components/JobActionBar';
|
| 15 |
+
import JobConfigViewer from '@/components/JobConfigViewer';
|
| 16 |
+
import { JobRecord } from '@/types';
|
| 17 |
+
|
| 18 |
+
type PageKey = 'overview' | 'samples' | 'config';
|
| 19 |
+
|
| 20 |
+
interface Page {
|
| 21 |
+
name: string;
|
| 22 |
+
value: PageKey;
|
| 23 |
+
component: React.ComponentType<{ job: JobRecord }>;
|
| 24 |
+
menuItem?: React.ComponentType<{ job?: JobRecord | null }> | null;
|
| 25 |
+
mainCss?: string;
|
| 26 |
+
}
|
| 27 |
+
|
| 28 |
+
const pages: Page[] = [
|
| 29 |
+
{
|
| 30 |
+
name: 'Overview',
|
| 31 |
+
value: 'overview',
|
| 32 |
+
component: JobOverview,
|
| 33 |
+
mainCss: 'pt-24',
|
| 34 |
+
},
|
| 35 |
+
{
|
| 36 |
+
name: 'Samples',
|
| 37 |
+
value: 'samples',
|
| 38 |
+
component: SampleImages,
|
| 39 |
+
menuItem: SampleImagesMenu,
|
| 40 |
+
mainCss: 'pt-24',
|
| 41 |
+
},
|
| 42 |
+
{
|
| 43 |
+
name: 'Config File',
|
| 44 |
+
value: 'config',
|
| 45 |
+
component: JobConfigViewer,
|
| 46 |
+
mainCss: 'pt-[80px] px-0 pb-0',
|
| 47 |
+
},
|
| 48 |
+
];
|
| 49 |
+
|
| 50 |
+
export default function JobPage({ params }: { params: { jobID: string } }) {
|
| 51 |
+
const usableParams = use(params as any) as { jobID: string };
|
| 52 |
+
const jobID = usableParams.jobID;
|
| 53 |
+
const { job, status, refreshJob } = useJob(jobID, 5000);
|
| 54 |
+
const [pageKey, setPageKey] = useState<PageKey>('overview');
|
| 55 |
+
const { status: authStatus } = useAuth();
|
| 56 |
+
const isAuthenticated = authStatus === 'authenticated';
|
| 57 |
+
|
| 58 |
+
const page = pages.find(p => p.value === pageKey);
|
| 59 |
+
|
| 60 |
+
if (!isAuthenticated) {
|
| 61 |
+
return (
|
| 62 |
+
<>
|
| 63 |
+
<TopBar>
|
| 64 |
+
<div>
|
| 65 |
+
<Button className="text-gray-500 dark:text-gray-300 px-3 mt-1" onClick={() => redirect('/jobs')}>
|
| 66 |
+
<FaChevronLeft />
|
| 67 |
+
</Button>
|
| 68 |
+
</div>
|
| 69 |
+
<div>
|
| 70 |
+
<h1 className="text-lg">Job Details</h1>
|
| 71 |
+
</div>
|
| 72 |
+
<div className="flex-1"></div>
|
| 73 |
+
</TopBar>
|
| 74 |
+
<MainContent>
|
| 75 |
+
<div className="border border-gray-800 rounded-lg p-6 bg-gray-900 text-gray-400 text-sm flex flex-col gap-4">
|
| 76 |
+
<p>Sign in with Hugging Face or add an access token to view job details.</p>
|
| 77 |
+
<div className="flex items-center gap-3">
|
| 78 |
+
<HFLoginButton size="sm" />
|
| 79 |
+
<Link href="/settings" className="text-xs text-blue-400 hover:text-blue-300">
|
| 80 |
+
Manage authentication in Settings
|
| 81 |
+
</Link>
|
| 82 |
+
</div>
|
| 83 |
+
</div>
|
| 84 |
+
</MainContent>
|
| 85 |
+
</>
|
| 86 |
+
);
|
| 87 |
+
}
|
| 88 |
+
|
| 89 |
+
return (
|
| 90 |
+
<>
|
| 91 |
+
{/* Fixed top bar */}
|
| 92 |
+
<TopBar>
|
| 93 |
+
<div>
|
| 94 |
+
<Button className="text-gray-500 dark:text-gray-300 px-3 mt-1" onClick={() => redirect('/jobs')}>
|
| 95 |
+
<FaChevronLeft />
|
| 96 |
+
</Button>
|
| 97 |
+
</div>
|
| 98 |
+
<div>
|
| 99 |
+
<h1 className="text-lg">Job: {job?.name}</h1>
|
| 100 |
+
</div>
|
| 101 |
+
<div className="flex-1"></div>
|
| 102 |
+
{job && (
|
| 103 |
+
<JobActionBar
|
| 104 |
+
job={job}
|
| 105 |
+
onRefresh={refreshJob}
|
| 106 |
+
hideView
|
| 107 |
+
afterDelete={() => {
|
| 108 |
+
redirect('/jobs');
|
| 109 |
+
}}
|
| 110 |
+
/>
|
| 111 |
+
)}
|
| 112 |
+
</TopBar>
|
| 113 |
+
<MainContent className={pages.find(page => page.value === pageKey)?.mainCss}>
|
| 114 |
+
{status === 'loading' && job == null && <p>Loading...</p>}
|
| 115 |
+
{status === 'error' && job == null && <p>Error fetching job</p>}
|
| 116 |
+
{job && (
|
| 117 |
+
<>
|
| 118 |
+
{pages.map(page => {
|
| 119 |
+
const Component = page.component;
|
| 120 |
+
return page.value === pageKey ? <Component key={page.value} job={job} /> : null;
|
| 121 |
+
})}
|
| 122 |
+
</>
|
| 123 |
+
)}
|
| 124 |
+
</MainContent>
|
| 125 |
+
<div className="bg-gray-800 absolute top-12 left-0 w-full h-8 flex items-center px-2 text-sm">
|
| 126 |
+
{pages.map(page => (
|
| 127 |
+
<Button
|
| 128 |
+
key={page.value}
|
| 129 |
+
onClick={() => setPageKey(page.value)}
|
| 130 |
+
className={`px-4 py-1 h-8 ${page.value === pageKey ? 'bg-gray-300 dark:bg-gray-700' : ''}`}
|
| 131 |
+
>
|
| 132 |
+
{page.name}
|
| 133 |
+
</Button>
|
| 134 |
+
))}
|
| 135 |
+
{
|
| 136 |
+
page?.menuItem && (
|
| 137 |
+
<>
|
| 138 |
+
<div className='flex-grow'>
|
| 139 |
+
</div>
|
| 140 |
+
<page.menuItem job={job} />
|
| 141 |
+
</>
|
| 142 |
+
)
|
| 143 |
+
}
|
| 144 |
+
</div>
|
| 145 |
+
</>
|
| 146 |
+
);
|
| 147 |
+
}
|
src/app/jobs/new/AdvancedJob.tsx
ADDED
|
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
'use client';
|
| 2 |
+
import { useEffect, useState, useRef } from 'react';
|
| 3 |
+
import { JobConfig } from '@/types';
|
| 4 |
+
import YAML from 'yaml';
|
| 5 |
+
import Editor, { OnMount } from '@monaco-editor/react';
|
| 6 |
+
import type { editor } from 'monaco-editor';
|
| 7 |
+
import { SettingsData } from '@/types';
|
| 8 |
+
import { migrateJobConfig } from './jobConfig';
|
| 9 |
+
|
| 10 |
+
type Props = {
|
| 11 |
+
jobConfig: JobConfig;
|
| 12 |
+
setJobConfig: (value: any, key?: string) => void;
|
| 13 |
+
status: 'idle' | 'saving' | 'success' | 'error';
|
| 14 |
+
handleSubmit: (event: React.FormEvent<HTMLFormElement>) => void;
|
| 15 |
+
runId: string | null;
|
| 16 |
+
gpuIDs: string | null;
|
| 17 |
+
setGpuIDs: (value: string | null) => void;
|
| 18 |
+
gpuList: any;
|
| 19 |
+
datasetOptions: any;
|
| 20 |
+
settings: SettingsData;
|
| 21 |
+
};
|
| 22 |
+
|
| 23 |
+
const isDev = process.env.NODE_ENV === 'development';
|
| 24 |
+
|
| 25 |
+
const yamlConfig: YAML.DocumentOptions &
|
| 26 |
+
YAML.SchemaOptions &
|
| 27 |
+
YAML.ParseOptions &
|
| 28 |
+
YAML.CreateNodeOptions &
|
| 29 |
+
YAML.ToStringOptions = {
|
| 30 |
+
indent: 2,
|
| 31 |
+
lineWidth: 999999999999,
|
| 32 |
+
defaultStringType: 'QUOTE_DOUBLE',
|
| 33 |
+
defaultKeyType: 'PLAIN',
|
| 34 |
+
directives: true,
|
| 35 |
+
};
|
| 36 |
+
|
| 37 |
+
export default function AdvancedJob({ jobConfig, setJobConfig, settings }: Props) {
|
| 38 |
+
const [editorValue, setEditorValue] = useState<string>('');
|
| 39 |
+
const lastJobConfigUpdateStringRef = useRef('');
|
| 40 |
+
const editorRef = useRef<editor.IStandaloneCodeEditor | null>(null);
|
| 41 |
+
|
| 42 |
+
// Track if the editor has been mounted
|
| 43 |
+
const isEditorMounted = useRef(false);
|
| 44 |
+
|
| 45 |
+
// Handler for editor mounting
|
| 46 |
+
const handleEditorDidMount: OnMount = editor => {
|
| 47 |
+
editorRef.current = editor;
|
| 48 |
+
isEditorMounted.current = true;
|
| 49 |
+
|
| 50 |
+
// Initial content setup
|
| 51 |
+
try {
|
| 52 |
+
const yamlContent = YAML.stringify(jobConfig, yamlConfig);
|
| 53 |
+
setEditorValue(yamlContent);
|
| 54 |
+
lastJobConfigUpdateStringRef.current = JSON.stringify(jobConfig);
|
| 55 |
+
} catch (e) {
|
| 56 |
+
console.warn(e);
|
| 57 |
+
}
|
| 58 |
+
};
|
| 59 |
+
|
| 60 |
+
useEffect(() => {
|
| 61 |
+
const lastUpdate = lastJobConfigUpdateStringRef.current;
|
| 62 |
+
const currentUpdate = JSON.stringify(jobConfig);
|
| 63 |
+
|
| 64 |
+
// Skip if no changes or editor not yet mounted
|
| 65 |
+
if (lastUpdate === currentUpdate || !isEditorMounted.current) {
|
| 66 |
+
return;
|
| 67 |
+
}
|
| 68 |
+
|
| 69 |
+
try {
|
| 70 |
+
// Preserve cursor position and selection
|
| 71 |
+
const editor = editorRef.current;
|
| 72 |
+
if (editor) {
|
| 73 |
+
// Save current editor state
|
| 74 |
+
const position = editor.getPosition();
|
| 75 |
+
const selection = editor.getSelection();
|
| 76 |
+
const scrollTop = editor.getScrollTop();
|
| 77 |
+
|
| 78 |
+
// Update content
|
| 79 |
+
const yamlContent = YAML.stringify(jobConfig, yamlConfig);
|
| 80 |
+
|
| 81 |
+
// Only update if the content is actually different
|
| 82 |
+
if (yamlContent !== editor.getValue()) {
|
| 83 |
+
// Set value directly on the editor model instead of using React state
|
| 84 |
+
editor.getModel()?.setValue(yamlContent);
|
| 85 |
+
|
| 86 |
+
// Restore cursor position and selection
|
| 87 |
+
if (position) editor.setPosition(position);
|
| 88 |
+
if (selection) editor.setSelection(selection);
|
| 89 |
+
editor.setScrollTop(scrollTop);
|
| 90 |
+
}
|
| 91 |
+
|
| 92 |
+
lastJobConfigUpdateStringRef.current = currentUpdate;
|
| 93 |
+
}
|
| 94 |
+
} catch (e) {
|
| 95 |
+
console.warn(e);
|
| 96 |
+
}
|
| 97 |
+
}, [jobConfig]);
|
| 98 |
+
|
| 99 |
+
const handleChange = (value: string | undefined) => {
|
| 100 |
+
if (value === undefined) return;
|
| 101 |
+
|
| 102 |
+
try {
|
| 103 |
+
const parsed = YAML.parse(value);
|
| 104 |
+
// Don't update jobConfig if the change came from the editor itself
|
| 105 |
+
// to avoid a circular update loop
|
| 106 |
+
if (JSON.stringify(parsed) !== lastJobConfigUpdateStringRef.current) {
|
| 107 |
+
lastJobConfigUpdateStringRef.current = JSON.stringify(parsed);
|
| 108 |
+
|
| 109 |
+
// We have to ensure certain things are always set
|
| 110 |
+
try {
|
| 111 |
+
parsed.config.process[0].type = 'ui_trainer';
|
| 112 |
+
parsed.config.process[0].sqlite_db_path = './aitk_db.db';
|
| 113 |
+
parsed.config.process[0].training_folder = settings.TRAINING_FOLDER;
|
| 114 |
+
parsed.config.process[0].device = 'cuda';
|
| 115 |
+
parsed.config.process[0].performance_log_every = 10;
|
| 116 |
+
} catch (e) {
|
| 117 |
+
console.warn(e);
|
| 118 |
+
}
|
| 119 |
+
migrateJobConfig(parsed);
|
| 120 |
+
setJobConfig(parsed);
|
| 121 |
+
}
|
| 122 |
+
} catch (e) {
|
| 123 |
+
// Don't update on parsing errors
|
| 124 |
+
console.warn(e);
|
| 125 |
+
}
|
| 126 |
+
};
|
| 127 |
+
|
| 128 |
+
return (
|
| 129 |
+
<>
|
| 130 |
+
<Editor
|
| 131 |
+
height="100%"
|
| 132 |
+
width="100%"
|
| 133 |
+
defaultLanguage="yaml"
|
| 134 |
+
value={editorValue}
|
| 135 |
+
theme="vs-dark"
|
| 136 |
+
onChange={handleChange}
|
| 137 |
+
onMount={handleEditorDidMount}
|
| 138 |
+
options={{
|
| 139 |
+
minimap: { enabled: true },
|
| 140 |
+
scrollBeyondLastLine: false,
|
| 141 |
+
automaticLayout: true,
|
| 142 |
+
}}
|
| 143 |
+
/>
|
| 144 |
+
</>
|
| 145 |
+
);
|
| 146 |
+
}
|
src/app/jobs/new/SimpleJob.tsx
ADDED
|
@@ -0,0 +1,973 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
'use client';
|
| 2 |
+
import { useMemo, useState } from 'react';
|
| 3 |
+
import { modelArchs, ModelArch, groupedModelOptions, quantizationOptions, defaultQtype } from './options';
|
| 4 |
+
import { defaultDatasetConfig } from './jobConfig';
|
| 5 |
+
import { GroupedSelectOption, JobConfig, SelectOption } from '@/types';
|
| 6 |
+
import { objectCopy } from '@/utils/basic';
|
| 7 |
+
import { TextInput, SelectInput, Checkbox, FormGroup, NumberInput } from '@/components/formInputs';
|
| 8 |
+
import Card from '@/components/Card';
|
| 9 |
+
import { X } from 'lucide-react';
|
| 10 |
+
import AddSingleImageModal, { openAddImageModal } from '@/components/AddSingleImageModal';
|
| 11 |
+
import {FlipHorizontal2, FlipVertical2} from "lucide-react";
|
| 12 |
+
import HFJobsWorkflow from '@/components/HFJobsWorkflow';
|
| 13 |
+
|
| 14 |
+
type Props = {
|
| 15 |
+
jobConfig: JobConfig;
|
| 16 |
+
setJobConfig: (value: any, key: string) => void;
|
| 17 |
+
status: 'idle' | 'saving' | 'success' | 'error';
|
| 18 |
+
handleSubmit: (event: React.FormEvent<HTMLFormElement>) => void;
|
| 19 |
+
runId: string | null;
|
| 20 |
+
gpuIDs: string | null;
|
| 21 |
+
setGpuIDs: (value: string | null) => void;
|
| 22 |
+
gpuList: any;
|
| 23 |
+
datasetOptions: any;
|
| 24 |
+
trainingBackend?: 'local' | 'hf-jobs';
|
| 25 |
+
setTrainingBackend?: (backend: 'local' | 'hf-jobs') => void;
|
| 26 |
+
hfJobSubmitted?: boolean;
|
| 27 |
+
onHFJobComplete?: (jobId: string, localJobId?: string) => void;
|
| 28 |
+
forceHFBackend?: boolean;
|
| 29 |
+
};
|
| 30 |
+
|
| 31 |
+
const isDev = process.env.NODE_ENV === 'development';
|
| 32 |
+
|
| 33 |
+
export default function SimpleJob({
|
| 34 |
+
jobConfig,
|
| 35 |
+
setJobConfig,
|
| 36 |
+
handleSubmit,
|
| 37 |
+
status,
|
| 38 |
+
runId,
|
| 39 |
+
gpuIDs,
|
| 40 |
+
setGpuIDs,
|
| 41 |
+
gpuList,
|
| 42 |
+
datasetOptions,
|
| 43 |
+
trainingBackend: parentTrainingBackend,
|
| 44 |
+
setTrainingBackend: parentSetTrainingBackend,
|
| 45 |
+
hfJobSubmitted,
|
| 46 |
+
onHFJobComplete,
|
| 47 |
+
forceHFBackend = false,
|
| 48 |
+
}: Props) {
|
| 49 |
+
const [localTrainingBackend, setLocalTrainingBackend] = useState(forceHFBackend ? 'hf-jobs' : 'local');
|
| 50 |
+
const trainingBackend = parentTrainingBackend || localTrainingBackend;
|
| 51 |
+
const setTrainingBackend = forceHFBackend
|
| 52 |
+
? (_: 'local' | 'hf-jobs') => undefined
|
| 53 |
+
: parentSetTrainingBackend || setLocalTrainingBackend;
|
| 54 |
+
const backendOptions = forceHFBackend
|
| 55 |
+
? [{ value: 'hf-jobs', label: 'HF Jobs (Cloud)' }]
|
| 56 |
+
: [
|
| 57 |
+
{ value: 'local', label: 'Local GPU' },
|
| 58 |
+
{ value: 'hf-jobs', label: 'HF Jobs (Cloud)' },
|
| 59 |
+
];
|
| 60 |
+
const modelArch = useMemo(() => {
|
| 61 |
+
return modelArchs.find(a => a.name === jobConfig.config.process[0].model.arch) as ModelArch;
|
| 62 |
+
}, [jobConfig.config.process[0].model.arch]);
|
| 63 |
+
|
| 64 |
+
const isVideoModel = !!(modelArch?.group === 'video');
|
| 65 |
+
|
| 66 |
+
const numTopCards = useMemo(() => {
|
| 67 |
+
let count = 4; // job settings, model config, target config, save config
|
| 68 |
+
if (modelArch?.additionalSections?.includes('model.multistage')) {
|
| 69 |
+
count += 1; // add multistage card
|
| 70 |
+
}
|
| 71 |
+
if (!modelArch?.disableSections?.includes('model.quantize')) {
|
| 72 |
+
count += 1; // add quantization card
|
| 73 |
+
}
|
| 74 |
+
return count;
|
| 75 |
+
|
| 76 |
+
}, [modelArch]);
|
| 77 |
+
|
| 78 |
+
let topBarClass = 'grid grid-cols-1 md:grid-cols-2 lg:grid-cols-4 xl:grid-cols-4 gap-6';
|
| 79 |
+
|
| 80 |
+
if (numTopCards == 5) {
|
| 81 |
+
topBarClass = 'grid grid-cols-1 md:grid-cols-2 lg:grid-cols-3 xl:grid-cols-5 gap-6';
|
| 82 |
+
}
|
| 83 |
+
if (numTopCards == 6) {
|
| 84 |
+
topBarClass = 'grid grid-cols-1 md:grid-cols-2 lg:grid-cols-3 xl:grid-cols-3 2xl:grid-cols-6 gap-6';
|
| 85 |
+
}
|
| 86 |
+
|
| 87 |
+
const transformerQuantizationOptions: GroupedSelectOption[] | SelectOption[] = useMemo(() => {
|
| 88 |
+
const hasARA = modelArch?.accuracyRecoveryAdapters && Object.keys(modelArch.accuracyRecoveryAdapters).length > 0;
|
| 89 |
+
if (!hasARA) {
|
| 90 |
+
return quantizationOptions;
|
| 91 |
+
}
|
| 92 |
+
let newQuantizationOptions = [
|
| 93 |
+
{
|
| 94 |
+
label: 'Standard',
|
| 95 |
+
options: [quantizationOptions[0], quantizationOptions[1]],
|
| 96 |
+
},
|
| 97 |
+
];
|
| 98 |
+
|
| 99 |
+
// add ARAs if they exist for the model
|
| 100 |
+
let ARAs: SelectOption[] = [];
|
| 101 |
+
if (modelArch.accuracyRecoveryAdapters) {
|
| 102 |
+
for (const [label, value] of Object.entries(modelArch.accuracyRecoveryAdapters)) {
|
| 103 |
+
ARAs.push({ value, label });
|
| 104 |
+
}
|
| 105 |
+
}
|
| 106 |
+
if (ARAs.length > 0) {
|
| 107 |
+
newQuantizationOptions.push({
|
| 108 |
+
label: 'Accuracy Recovery Adapters',
|
| 109 |
+
options: ARAs,
|
| 110 |
+
});
|
| 111 |
+
}
|
| 112 |
+
|
| 113 |
+
let additionalQuantizationOptions: SelectOption[] = [];
|
| 114 |
+
// add the quantization options if they are not already included
|
| 115 |
+
for (let i = 2; i < quantizationOptions.length; i++) {
|
| 116 |
+
const option = quantizationOptions[i];
|
| 117 |
+
additionalQuantizationOptions.push(option);
|
| 118 |
+
}
|
| 119 |
+
if (additionalQuantizationOptions.length > 0) {
|
| 120 |
+
newQuantizationOptions.push({
|
| 121 |
+
label: 'Additional Quantization Options',
|
| 122 |
+
options: additionalQuantizationOptions,
|
| 123 |
+
});
|
| 124 |
+
}
|
| 125 |
+
return newQuantizationOptions;
|
| 126 |
+
}, [modelArch]);
|
| 127 |
+
|
| 128 |
+
return (
|
| 129 |
+
<>
|
| 130 |
+
<form onSubmit={handleSubmit} className="space-y-8">
|
| 131 |
+
<div className={topBarClass}>
|
| 132 |
+
<Card title="Job">
|
| 133 |
+
<TextInput
|
| 134 |
+
label="Training Name"
|
| 135 |
+
value={jobConfig.config.name}
|
| 136 |
+
docKey="config.name"
|
| 137 |
+
onChange={value => setJobConfig(value, 'config.name')}
|
| 138 |
+
placeholder="Enter training name"
|
| 139 |
+
disabled={runId !== null}
|
| 140 |
+
required
|
| 141 |
+
/>
|
| 142 |
+
<SelectInput
|
| 143 |
+
label="Training Backend"
|
| 144 |
+
value={trainingBackend}
|
| 145 |
+
onChange={(value) => {
|
| 146 |
+
setTrainingBackend(value);
|
| 147 |
+
}}
|
| 148 |
+
options={backendOptions}
|
| 149 |
+
disabled={forceHFBackend}
|
| 150 |
+
/>
|
| 151 |
+
{trainingBackend === 'local' && (
|
| 152 |
+
<SelectInput
|
| 153 |
+
label="GPU ID"
|
| 154 |
+
value={`${gpuIDs}`}
|
| 155 |
+
docKey="gpuids"
|
| 156 |
+
onChange={value => setGpuIDs(value)}
|
| 157 |
+
options={gpuList.map((gpu: any) => ({ value: `${gpu.index}`, label: `GPU #${gpu.index}` }))}
|
| 158 |
+
/>
|
| 159 |
+
)}
|
| 160 |
+
<TextInput
|
| 161 |
+
label="Trigger Word"
|
| 162 |
+
value={jobConfig.config.process[0].trigger_word || ''}
|
| 163 |
+
docKey="config.process[0].trigger_word"
|
| 164 |
+
onChange={(value: string | null) => {
|
| 165 |
+
if (value?.trim() === '') {
|
| 166 |
+
value = null;
|
| 167 |
+
}
|
| 168 |
+
setJobConfig(value, 'config.process[0].trigger_word');
|
| 169 |
+
}}
|
| 170 |
+
placeholder=""
|
| 171 |
+
required
|
| 172 |
+
/>
|
| 173 |
+
{trainingBackend === 'hf-jobs' && (
|
| 174 |
+
<div className={`mt-4 p-3 rounded ${
|
| 175 |
+
hfJobSubmitted
|
| 176 |
+
? 'bg-green-900/20 border border-green-700'
|
| 177 |
+
: 'bg-yellow-900/20 border border-yellow-700'
|
| 178 |
+
}`}>
|
| 179 |
+
<p className={`text-sm ${
|
| 180 |
+
hfJobSubmitted ? 'text-green-400' : 'text-yellow-400'
|
| 181 |
+
}`}>
|
| 182 |
+
{hfJobSubmitted
|
| 183 |
+
? '✓ HF Job already submitted! You can modify settings and resubmit if needed.'
|
| 184 |
+
: '⏳ HF Job ready for submission. Submit to the cloud below.'
|
| 185 |
+
}
|
| 186 |
+
</p>
|
| 187 |
+
</div>
|
| 188 |
+
)}
|
| 189 |
+
</Card>
|
| 190 |
+
|
| 191 |
+
{/* Model Configuration Section */}
|
| 192 |
+
<Card title="Model">
|
| 193 |
+
<SelectInput
|
| 194 |
+
label="Model Architecture"
|
| 195 |
+
value={jobConfig.config.process[0].model.arch}
|
| 196 |
+
onChange={value => {
|
| 197 |
+
const currentArch = modelArchs.find(a => a.name === jobConfig.config.process[0].model.arch);
|
| 198 |
+
if (!currentArch || currentArch.name === value) {
|
| 199 |
+
return;
|
| 200 |
+
}
|
| 201 |
+
// update the defaults when a model is selected
|
| 202 |
+
const newArch = modelArchs.find(model => model.name === value);
|
| 203 |
+
|
| 204 |
+
// update vram setting
|
| 205 |
+
if (!newArch?.additionalSections?.includes('model.low_vram')) {
|
| 206 |
+
setJobConfig(false, 'config.process[0].model.low_vram');
|
| 207 |
+
}
|
| 208 |
+
|
| 209 |
+
// revert defaults from previous model
|
| 210 |
+
for (const key in currentArch.defaults) {
|
| 211 |
+
setJobConfig(currentArch.defaults[key][1], key);
|
| 212 |
+
}
|
| 213 |
+
|
| 214 |
+
if (newArch?.defaults) {
|
| 215 |
+
for (const key in newArch.defaults) {
|
| 216 |
+
setJobConfig(newArch.defaults[key][0], key);
|
| 217 |
+
}
|
| 218 |
+
}
|
| 219 |
+
// set new model
|
| 220 |
+
setJobConfig(value, 'config.process[0].model.arch');
|
| 221 |
+
|
| 222 |
+
// update datasets
|
| 223 |
+
const hasControlPath = newArch?.additionalSections?.includes('datasets.control_path') || false;
|
| 224 |
+
const hasNumFrames = newArch?.additionalSections?.includes('datasets.num_frames') || false;
|
| 225 |
+
const controls = newArch?.controls ?? [];
|
| 226 |
+
const datasets = jobConfig.config.process[0].datasets.map(dataset => {
|
| 227 |
+
const newDataset = objectCopy(dataset);
|
| 228 |
+
newDataset.controls = controls;
|
| 229 |
+
if (!hasControlPath) {
|
| 230 |
+
newDataset.control_path = null; // reset control path if not applicable
|
| 231 |
+
}
|
| 232 |
+
if (!hasNumFrames) {
|
| 233 |
+
newDataset.num_frames = 1; // reset num_frames if not applicable
|
| 234 |
+
}
|
| 235 |
+
return newDataset;
|
| 236 |
+
});
|
| 237 |
+
setJobConfig(datasets, 'config.process[0].datasets');
|
| 238 |
+
|
| 239 |
+
// update samples
|
| 240 |
+
const hasSampleCtrlImg = newArch?.additionalSections?.includes('sample.ctrl_img') || false;
|
| 241 |
+
const samples = jobConfig.config.process[0].sample.samples.map(sample => {
|
| 242 |
+
const newSample = objectCopy(sample);
|
| 243 |
+
if (!hasSampleCtrlImg) {
|
| 244 |
+
delete newSample.ctrl_img; // remove ctrl_img if not applicable
|
| 245 |
+
}
|
| 246 |
+
return newSample;
|
| 247 |
+
});
|
| 248 |
+
setJobConfig(samples, 'config.process[0].sample.samples');
|
| 249 |
+
}}
|
| 250 |
+
options={groupedModelOptions}
|
| 251 |
+
/>
|
| 252 |
+
<TextInput
|
| 253 |
+
label="Name or Path"
|
| 254 |
+
value={jobConfig.config.process[0].model.name_or_path}
|
| 255 |
+
docKey="config.process[0].model.name_or_path"
|
| 256 |
+
onChange={(value: string | null) => {
|
| 257 |
+
if (value?.trim() === '') {
|
| 258 |
+
value = null;
|
| 259 |
+
}
|
| 260 |
+
setJobConfig(value, 'config.process[0].model.name_or_path');
|
| 261 |
+
}}
|
| 262 |
+
placeholder=""
|
| 263 |
+
required
|
| 264 |
+
/>
|
| 265 |
+
{modelArch?.additionalSections?.includes('model.low_vram') && (
|
| 266 |
+
<FormGroup label="Options">
|
| 267 |
+
<Checkbox
|
| 268 |
+
label="Low VRAM"
|
| 269 |
+
checked={jobConfig.config.process[0].model.low_vram}
|
| 270 |
+
onChange={value => setJobConfig(value, 'config.process[0].model.low_vram')}
|
| 271 |
+
/>
|
| 272 |
+
</FormGroup>
|
| 273 |
+
)}
|
| 274 |
+
</Card>
|
| 275 |
+
{modelArch?.disableSections?.includes('model.quantize') ? null : (
|
| 276 |
+
<Card title="Quantization">
|
| 277 |
+
<SelectInput
|
| 278 |
+
label="Transformer"
|
| 279 |
+
value={jobConfig.config.process[0].model.quantize ? jobConfig.config.process[0].model.qtype : ''}
|
| 280 |
+
onChange={value => {
|
| 281 |
+
if (value === '') {
|
| 282 |
+
setJobConfig(false, 'config.process[0].model.quantize');
|
| 283 |
+
value = defaultQtype;
|
| 284 |
+
} else {
|
| 285 |
+
setJobConfig(true, 'config.process[0].model.quantize');
|
| 286 |
+
}
|
| 287 |
+
setJobConfig(value, 'config.process[0].model.qtype');
|
| 288 |
+
}}
|
| 289 |
+
options={transformerQuantizationOptions}
|
| 290 |
+
/>
|
| 291 |
+
<SelectInput
|
| 292 |
+
label="Text Encoder"
|
| 293 |
+
value={jobConfig.config.process[0].model.quantize_te ? jobConfig.config.process[0].model.qtype_te : ''}
|
| 294 |
+
onChange={value => {
|
| 295 |
+
if (value === '') {
|
| 296 |
+
setJobConfig(false, 'config.process[0].model.quantize_te');
|
| 297 |
+
value = defaultQtype;
|
| 298 |
+
} else {
|
| 299 |
+
setJobConfig(true, 'config.process[0].model.quantize_te');
|
| 300 |
+
}
|
| 301 |
+
setJobConfig(value, 'config.process[0].model.qtype_te');
|
| 302 |
+
}}
|
| 303 |
+
options={quantizationOptions}
|
| 304 |
+
/>
|
| 305 |
+
</Card>
|
| 306 |
+
)}
|
| 307 |
+
{modelArch?.additionalSections?.includes('model.multistage') && (
|
| 308 |
+
<Card title="Multistage">
|
| 309 |
+
<FormGroup label="Stages to Train" docKey={'model.multistage'}>
|
| 310 |
+
<Checkbox
|
| 311 |
+
label="High Noise"
|
| 312 |
+
checked={jobConfig.config.process[0].model.model_kwargs?.train_high_noise || false}
|
| 313 |
+
onChange={value => setJobConfig(value, 'config.process[0].model.model_kwargs.train_high_noise')}
|
| 314 |
+
/>
|
| 315 |
+
<Checkbox
|
| 316 |
+
label="Low Noise"
|
| 317 |
+
checked={jobConfig.config.process[0].model.model_kwargs?.train_low_noise || false}
|
| 318 |
+
onChange={value => setJobConfig(value, 'config.process[0].model.model_kwargs.train_low_noise')}
|
| 319 |
+
/>
|
| 320 |
+
</FormGroup>
|
| 321 |
+
<NumberInput
|
| 322 |
+
label="Switch Every"
|
| 323 |
+
value={jobConfig.config.process[0].train.switch_boundary_every}
|
| 324 |
+
onChange={value => setJobConfig(value, 'config.process[0].train.switch_boundary_every')}
|
| 325 |
+
placeholder="eg. 1"
|
| 326 |
+
docKey={'train.switch_boundary_every'}
|
| 327 |
+
min={1}
|
| 328 |
+
required
|
| 329 |
+
/>
|
| 330 |
+
</Card>
|
| 331 |
+
)}
|
| 332 |
+
<Card title="Target">
|
| 333 |
+
<SelectInput
|
| 334 |
+
label="Target Type"
|
| 335 |
+
value={jobConfig.config.process[0].network?.type ?? 'lora'}
|
| 336 |
+
onChange={value => setJobConfig(value, 'config.process[0].network.type')}
|
| 337 |
+
options={[
|
| 338 |
+
{ value: 'lora', label: 'LoRA' },
|
| 339 |
+
{ value: 'lokr', label: 'LoKr' },
|
| 340 |
+
]}
|
| 341 |
+
/>
|
| 342 |
+
{jobConfig.config.process[0].network?.type == 'lokr' && (
|
| 343 |
+
<SelectInput
|
| 344 |
+
label="LoKr Factor"
|
| 345 |
+
value={`${jobConfig.config.process[0].network?.lokr_factor ?? -1}`}
|
| 346 |
+
onChange={value => setJobConfig(parseInt(value), 'config.process[0].network.lokr_factor')}
|
| 347 |
+
options={[
|
| 348 |
+
{ value: '-1', label: 'Auto' },
|
| 349 |
+
{ value: '4', label: '4' },
|
| 350 |
+
{ value: '8', label: '8' },
|
| 351 |
+
{ value: '16', label: '16' },
|
| 352 |
+
{ value: '32', label: '32' },
|
| 353 |
+
]}
|
| 354 |
+
/>
|
| 355 |
+
)}
|
| 356 |
+
{jobConfig.config.process[0].network?.type == 'lora' && (
|
| 357 |
+
<>
|
| 358 |
+
<NumberInput
|
| 359 |
+
label="Linear Rank"
|
| 360 |
+
value={jobConfig.config.process[0].network.linear}
|
| 361 |
+
onChange={value => {
|
| 362 |
+
console.log('onChange', value);
|
| 363 |
+
setJobConfig(value, 'config.process[0].network.linear');
|
| 364 |
+
setJobConfig(value, 'config.process[0].network.linear_alpha');
|
| 365 |
+
}}
|
| 366 |
+
placeholder="eg. 16"
|
| 367 |
+
min={0}
|
| 368 |
+
max={1024}
|
| 369 |
+
required
|
| 370 |
+
/>
|
| 371 |
+
{modelArch?.disableSections?.includes('network.conv') ? null : (
|
| 372 |
+
<NumberInput
|
| 373 |
+
label="Conv Rank"
|
| 374 |
+
value={jobConfig.config.process[0].network.conv}
|
| 375 |
+
onChange={value => {
|
| 376 |
+
console.log('onChange', value);
|
| 377 |
+
setJobConfig(value, 'config.process[0].network.conv');
|
| 378 |
+
setJobConfig(value, 'config.process[0].network.conv_alpha');
|
| 379 |
+
}}
|
| 380 |
+
placeholder="eg. 16"
|
| 381 |
+
min={0}
|
| 382 |
+
max={1024}
|
| 383 |
+
/>
|
| 384 |
+
)}
|
| 385 |
+
</>
|
| 386 |
+
)}
|
| 387 |
+
</Card>
|
| 388 |
+
<Card title="Save">
|
| 389 |
+
<SelectInput
|
| 390 |
+
label="Data Type"
|
| 391 |
+
value={jobConfig.config.process[0].save.dtype}
|
| 392 |
+
onChange={value => setJobConfig(value, 'config.process[0].save.dtype')}
|
| 393 |
+
options={[
|
| 394 |
+
{ value: 'bf16', label: 'BF16' },
|
| 395 |
+
{ value: 'fp16', label: 'FP16' },
|
| 396 |
+
{ value: 'fp32', label: 'FP32' },
|
| 397 |
+
]}
|
| 398 |
+
/>
|
| 399 |
+
<NumberInput
|
| 400 |
+
label="Save Every"
|
| 401 |
+
value={jobConfig.config.process[0].save.save_every}
|
| 402 |
+
onChange={value => setJobConfig(value, 'config.process[0].save.save_every')}
|
| 403 |
+
placeholder="eg. 250"
|
| 404 |
+
min={1}
|
| 405 |
+
required
|
| 406 |
+
/>
|
| 407 |
+
<NumberInput
|
| 408 |
+
label="Max Step Saves to Keep"
|
| 409 |
+
value={jobConfig.config.process[0].save.max_step_saves_to_keep}
|
| 410 |
+
onChange={value => setJobConfig(value, 'config.process[0].save.max_step_saves_to_keep')}
|
| 411 |
+
placeholder="eg. 4"
|
| 412 |
+
min={1}
|
| 413 |
+
required
|
| 414 |
+
/>
|
| 415 |
+
</Card>
|
| 416 |
+
</div>
|
| 417 |
+
<div>
|
| 418 |
+
<Card title="Training">
|
| 419 |
+
<div className="grid grid-cols-1 md:grid-cols-3 lg:grid-cols-5 gap-6">
|
| 420 |
+
<div>
|
| 421 |
+
<NumberInput
|
| 422 |
+
label="Batch Size"
|
| 423 |
+
value={jobConfig.config.process[0].train.batch_size}
|
| 424 |
+
onChange={value => setJobConfig(value, 'config.process[0].train.batch_size')}
|
| 425 |
+
placeholder="eg. 4"
|
| 426 |
+
min={1}
|
| 427 |
+
required
|
| 428 |
+
/>
|
| 429 |
+
<NumberInput
|
| 430 |
+
label="Gradient Accumulation"
|
| 431 |
+
className="pt-2"
|
| 432 |
+
value={jobConfig.config.process[0].train.gradient_accumulation}
|
| 433 |
+
onChange={value => setJobConfig(value, 'config.process[0].train.gradient_accumulation')}
|
| 434 |
+
placeholder="eg. 1"
|
| 435 |
+
min={1}
|
| 436 |
+
required
|
| 437 |
+
/>
|
| 438 |
+
<NumberInput
|
| 439 |
+
label="Steps"
|
| 440 |
+
className="pt-2"
|
| 441 |
+
value={jobConfig.config.process[0].train.steps}
|
| 442 |
+
onChange={value => setJobConfig(value, 'config.process[0].train.steps')}
|
| 443 |
+
placeholder="eg. 2000"
|
| 444 |
+
min={1}
|
| 445 |
+
required
|
| 446 |
+
/>
|
| 447 |
+
</div>
|
| 448 |
+
<div>
|
| 449 |
+
<SelectInput
|
| 450 |
+
label="Optimizer"
|
| 451 |
+
value={jobConfig.config.process[0].train.optimizer}
|
| 452 |
+
onChange={value => setJobConfig(value, 'config.process[0].train.optimizer')}
|
| 453 |
+
options={[
|
| 454 |
+
{ value: 'adamw8bit', label: 'AdamW8Bit' },
|
| 455 |
+
{ value: 'adafactor', label: 'Adafactor' },
|
| 456 |
+
]}
|
| 457 |
+
/>
|
| 458 |
+
<NumberInput
|
| 459 |
+
label="Learning Rate"
|
| 460 |
+
className="pt-2"
|
| 461 |
+
value={jobConfig.config.process[0].train.lr}
|
| 462 |
+
onChange={value => setJobConfig(value, 'config.process[0].train.lr')}
|
| 463 |
+
placeholder="eg. 0.0001"
|
| 464 |
+
min={0}
|
| 465 |
+
required
|
| 466 |
+
/>
|
| 467 |
+
<NumberInput
|
| 468 |
+
label="Weight Decay"
|
| 469 |
+
className="pt-2"
|
| 470 |
+
value={jobConfig.config.process[0].train.optimizer_params.weight_decay}
|
| 471 |
+
onChange={value => setJobConfig(value, 'config.process[0].train.optimizer_params.weight_decay')}
|
| 472 |
+
placeholder="eg. 0.0001"
|
| 473 |
+
min={0}
|
| 474 |
+
required
|
| 475 |
+
/>
|
| 476 |
+
</div>
|
| 477 |
+
<div>
|
| 478 |
+
{modelArch?.disableSections?.includes('train.timestep_type') ? null : (
|
| 479 |
+
<SelectInput
|
| 480 |
+
label="Timestep Type"
|
| 481 |
+
value={jobConfig.config.process[0].train.timestep_type}
|
| 482 |
+
disabled={modelArch?.disableSections?.includes('train.timestep_type') || false}
|
| 483 |
+
onChange={value => setJobConfig(value, 'config.process[0].train.timestep_type')}
|
| 484 |
+
options={[
|
| 485 |
+
{ value: 'sigmoid', label: 'Sigmoid' },
|
| 486 |
+
{ value: 'linear', label: 'Linear' },
|
| 487 |
+
{ value: 'shift', label: 'Shift' },
|
| 488 |
+
{ value: 'weighted', label: 'Weighted' },
|
| 489 |
+
]}
|
| 490 |
+
/>
|
| 491 |
+
)}
|
| 492 |
+
<SelectInput
|
| 493 |
+
label="Timestep Bias"
|
| 494 |
+
className="pt-2"
|
| 495 |
+
value={jobConfig.config.process[0].train.content_or_style}
|
| 496 |
+
onChange={value => setJobConfig(value, 'config.process[0].train.content_or_style')}
|
| 497 |
+
options={[
|
| 498 |
+
{ value: 'balanced', label: 'Balanced' },
|
| 499 |
+
{ value: 'content', label: 'High Noise' },
|
| 500 |
+
{ value: 'style', label: 'Low Noise' },
|
| 501 |
+
]}
|
| 502 |
+
/>
|
| 503 |
+
<SelectInput
|
| 504 |
+
label="Noise Scheduler"
|
| 505 |
+
className="pt-2"
|
| 506 |
+
value={jobConfig.config.process[0].train.noise_scheduler}
|
| 507 |
+
onChange={value => setJobConfig(value, 'config.process[0].train.noise_scheduler')}
|
| 508 |
+
options={[
|
| 509 |
+
{ value: 'flowmatch', label: 'FlowMatch' },
|
| 510 |
+
{ value: 'ddpm', label: 'DDPM' },
|
| 511 |
+
]}
|
| 512 |
+
/>
|
| 513 |
+
</div>
|
| 514 |
+
<div>
|
| 515 |
+
<FormGroup label="EMA (Exponential Moving Average)">
|
| 516 |
+
<Checkbox
|
| 517 |
+
label="Use EMA"
|
| 518 |
+
className="pt-1"
|
| 519 |
+
checked={jobConfig.config.process[0].train.ema_config?.use_ema || false}
|
| 520 |
+
onChange={value => setJobConfig(value, 'config.process[0].train.ema_config.use_ema')}
|
| 521 |
+
/>
|
| 522 |
+
</FormGroup>
|
| 523 |
+
{jobConfig.config.process[0].train.ema_config?.use_ema && (
|
| 524 |
+
<NumberInput
|
| 525 |
+
label="EMA Decay"
|
| 526 |
+
className="pt-2"
|
| 527 |
+
value={jobConfig.config.process[0].train.ema_config?.ema_decay as number}
|
| 528 |
+
onChange={value => setJobConfig(value, 'config.process[0].train.ema_config?.ema_decay')}
|
| 529 |
+
placeholder="eg. 0.99"
|
| 530 |
+
min={0}
|
| 531 |
+
/>
|
| 532 |
+
)}
|
| 533 |
+
|
| 534 |
+
<FormGroup label="Text Encoder Optimizations" className="pt-2">
|
| 535 |
+
<Checkbox
|
| 536 |
+
label="Unload TE"
|
| 537 |
+
checked={jobConfig.config.process[0].train.unload_text_encoder || false}
|
| 538 |
+
docKey={'train.unload_text_encoder'}
|
| 539 |
+
onChange={value => {
|
| 540 |
+
setJobConfig(value, 'config.process[0].train.unload_text_encoder');
|
| 541 |
+
if (value) {
|
| 542 |
+
setJobConfig(false, 'config.process[0].train.cache_text_embeddings');
|
| 543 |
+
}
|
| 544 |
+
}}
|
| 545 |
+
/>
|
| 546 |
+
<Checkbox
|
| 547 |
+
label="Cache Text Embeddings"
|
| 548 |
+
checked={jobConfig.config.process[0].train.cache_text_embeddings || false}
|
| 549 |
+
docKey={'train.cache_text_embeddings'}
|
| 550 |
+
onChange={value => {
|
| 551 |
+
setJobConfig(value, 'config.process[0].train.cache_text_embeddings');
|
| 552 |
+
if (value) {
|
| 553 |
+
setJobConfig(false, 'config.process[0].train.unload_text_encoder');
|
| 554 |
+
}
|
| 555 |
+
}}
|
| 556 |
+
/>
|
| 557 |
+
</FormGroup>
|
| 558 |
+
</div>
|
| 559 |
+
<div>
|
| 560 |
+
<FormGroup label="Regularization">
|
| 561 |
+
<Checkbox
|
| 562 |
+
label="Differtial Output Preservation"
|
| 563 |
+
className="pt-1"
|
| 564 |
+
checked={jobConfig.config.process[0].train.diff_output_preservation || false}
|
| 565 |
+
onChange={value => setJobConfig(value, 'config.process[0].train.diff_output_preservation')}
|
| 566 |
+
/>
|
| 567 |
+
</FormGroup>
|
| 568 |
+
{jobConfig.config.process[0].train.diff_output_preservation && (
|
| 569 |
+
<>
|
| 570 |
+
<NumberInput
|
| 571 |
+
label="DOP Loss Multiplier"
|
| 572 |
+
className="pt-2"
|
| 573 |
+
value={jobConfig.config.process[0].train.diff_output_preservation_multiplier as number}
|
| 574 |
+
onChange={value =>
|
| 575 |
+
setJobConfig(value, 'config.process[0].train.diff_output_preservation_multiplier')
|
| 576 |
+
}
|
| 577 |
+
placeholder="eg. 1.0"
|
| 578 |
+
min={0}
|
| 579 |
+
/>
|
| 580 |
+
<TextInput
|
| 581 |
+
label="DOP Preservation Class"
|
| 582 |
+
className="pt-2"
|
| 583 |
+
value={jobConfig.config.process[0].train.diff_output_preservation_class as string}
|
| 584 |
+
onChange={value => setJobConfig(value, 'config.process[0].train.diff_output_preservation_class')}
|
| 585 |
+
placeholder="eg. woman"
|
| 586 |
+
/>
|
| 587 |
+
</>
|
| 588 |
+
)}
|
| 589 |
+
</div>
|
| 590 |
+
</div>
|
| 591 |
+
</Card>
|
| 592 |
+
</div>
|
| 593 |
+
<div>
|
| 594 |
+
<Card title="Datasets">
|
| 595 |
+
<>
|
| 596 |
+
{jobConfig.config.process[0].datasets.map((dataset, i) => (
|
| 597 |
+
<div key={i} className="p-4 rounded-lg bg-gray-800 relative">
|
| 598 |
+
<button
|
| 599 |
+
type="button"
|
| 600 |
+
onClick={() =>
|
| 601 |
+
setJobConfig(
|
| 602 |
+
jobConfig.config.process[0].datasets.filter((_, index) => index !== i),
|
| 603 |
+
'config.process[0].datasets',
|
| 604 |
+
)
|
| 605 |
+
}
|
| 606 |
+
className="absolute top-2 right-2 bg-red-800 hover:bg-red-700 rounded-full p-1 text-sm transition-colors"
|
| 607 |
+
>
|
| 608 |
+
<X />
|
| 609 |
+
</button>
|
| 610 |
+
<h2 className="text-lg font-bold mb-4">Dataset {i + 1}</h2>
|
| 611 |
+
<div className="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-4 gap-6">
|
| 612 |
+
<div>
|
| 613 |
+
<SelectInput
|
| 614 |
+
label="Dataset"
|
| 615 |
+
value={dataset.folder_path}
|
| 616 |
+
onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].folder_path`)}
|
| 617 |
+
options={datasetOptions}
|
| 618 |
+
/>
|
| 619 |
+
{modelArch?.additionalSections?.includes('datasets.control_path') && (
|
| 620 |
+
<SelectInput
|
| 621 |
+
label="Control Dataset"
|
| 622 |
+
docKey="datasets.control_path"
|
| 623 |
+
value={dataset.control_path ?? ''}
|
| 624 |
+
className="pt-2"
|
| 625 |
+
onChange={value =>
|
| 626 |
+
setJobConfig(value == '' ? null : value, `config.process[0].datasets[${i}].control_path`)
|
| 627 |
+
}
|
| 628 |
+
options={[{ value: '', label: <> </> }, ...datasetOptions]}
|
| 629 |
+
/>
|
| 630 |
+
)}
|
| 631 |
+
<NumberInput
|
| 632 |
+
label="LoRA Weight"
|
| 633 |
+
value={dataset.network_weight}
|
| 634 |
+
className="pt-2"
|
| 635 |
+
onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].network_weight`)}
|
| 636 |
+
placeholder="eg. 1.0"
|
| 637 |
+
/>
|
| 638 |
+
</div>
|
| 639 |
+
<div>
|
| 640 |
+
<TextInput
|
| 641 |
+
label="Default Caption"
|
| 642 |
+
value={dataset.default_caption}
|
| 643 |
+
onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].default_caption`)}
|
| 644 |
+
placeholder="eg. A photo of a cat"
|
| 645 |
+
/>
|
| 646 |
+
<NumberInput
|
| 647 |
+
label="Caption Dropout Rate"
|
| 648 |
+
className="pt-2"
|
| 649 |
+
value={dataset.caption_dropout_rate}
|
| 650 |
+
onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].caption_dropout_rate`)}
|
| 651 |
+
placeholder="eg. 0.05"
|
| 652 |
+
min={0}
|
| 653 |
+
required
|
| 654 |
+
/>
|
| 655 |
+
{modelArch?.additionalSections?.includes('datasets.num_frames') && (
|
| 656 |
+
<NumberInput
|
| 657 |
+
label="Num Frames"
|
| 658 |
+
className="pt-2"
|
| 659 |
+
docKey="datasets.num_frames"
|
| 660 |
+
value={dataset.num_frames}
|
| 661 |
+
onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].num_frames`)}
|
| 662 |
+
placeholder="eg. 41"
|
| 663 |
+
min={1}
|
| 664 |
+
required
|
| 665 |
+
/>
|
| 666 |
+
)}
|
| 667 |
+
</div>
|
| 668 |
+
<div>
|
| 669 |
+
<FormGroup label="Settings" className="">
|
| 670 |
+
<Checkbox
|
| 671 |
+
label="Cache Latents"
|
| 672 |
+
checked={dataset.cache_latents_to_disk || false}
|
| 673 |
+
onChange={value =>
|
| 674 |
+
setJobConfig(value, `config.process[0].datasets[${i}].cache_latents_to_disk`)
|
| 675 |
+
}
|
| 676 |
+
/>
|
| 677 |
+
<Checkbox
|
| 678 |
+
label="Is Regularization"
|
| 679 |
+
checked={dataset.is_reg || false}
|
| 680 |
+
onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].is_reg`)}
|
| 681 |
+
/>
|
| 682 |
+
{modelArch?.additionalSections?.includes('datasets.do_i2v') && (
|
| 683 |
+
<Checkbox
|
| 684 |
+
label="Do I2V"
|
| 685 |
+
checked={dataset.do_i2v || false}
|
| 686 |
+
onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].do_i2v`)}
|
| 687 |
+
docKey="datasets.do_i2v"
|
| 688 |
+
/>
|
| 689 |
+
)}
|
| 690 |
+
</FormGroup>
|
| 691 |
+
<FormGroup label="Flipping" docKey={'datasets.flip'} className="mt-2">
|
| 692 |
+
<Checkbox
|
| 693 |
+
label={<>Flip X <FlipHorizontal2 className="inline-block w-4 h-4 ml-1" /></>}
|
| 694 |
+
checked={dataset.flip_x || false}
|
| 695 |
+
onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].flip_x`)}
|
| 696 |
+
/>
|
| 697 |
+
<Checkbox
|
| 698 |
+
label={<>Flip Y <FlipVertical2 className="inline-block w-4 h-4 ml-1" /></>}
|
| 699 |
+
checked={dataset.flip_y || false}
|
| 700 |
+
onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].flip_y`)}
|
| 701 |
+
/>
|
| 702 |
+
</FormGroup>
|
| 703 |
+
</div>
|
| 704 |
+
<div>
|
| 705 |
+
<FormGroup label="Resolutions" className="pt-2">
|
| 706 |
+
<div className="grid grid-cols-2 gap-2">
|
| 707 |
+
{[
|
| 708 |
+
[256, 512, 768],
|
| 709 |
+
[1024, 1280, 1536],
|
| 710 |
+
].map(resGroup => (
|
| 711 |
+
<div key={resGroup[0]} className="space-y-2">
|
| 712 |
+
{resGroup.map(res => (
|
| 713 |
+
<Checkbox
|
| 714 |
+
key={res}
|
| 715 |
+
label={res.toString()}
|
| 716 |
+
checked={dataset.resolution.includes(res)}
|
| 717 |
+
onChange={value => {
|
| 718 |
+
const resolutions = dataset.resolution.includes(res)
|
| 719 |
+
? dataset.resolution.filter(r => r !== res)
|
| 720 |
+
: [...dataset.resolution, res];
|
| 721 |
+
setJobConfig(resolutions, `config.process[0].datasets[${i}].resolution`);
|
| 722 |
+
}}
|
| 723 |
+
/>
|
| 724 |
+
))}
|
| 725 |
+
</div>
|
| 726 |
+
))}
|
| 727 |
+
</div>
|
| 728 |
+
</FormGroup>
|
| 729 |
+
</div>
|
| 730 |
+
</div>
|
| 731 |
+
</div>
|
| 732 |
+
))}
|
| 733 |
+
<button
|
| 734 |
+
type="button"
|
| 735 |
+
onClick={() => {
|
| 736 |
+
const newDataset = objectCopy(defaultDatasetConfig);
|
| 737 |
+
// automaticallt add the controls for a new dataset
|
| 738 |
+
const controls = modelArch?.controls ?? [];
|
| 739 |
+
newDataset.controls = controls;
|
| 740 |
+
setJobConfig([...jobConfig.config.process[0].datasets, newDataset], 'config.process[0].datasets');
|
| 741 |
+
}}
|
| 742 |
+
className="w-full px-4 py-2 bg-gray-700 hover:bg-gray-600 rounded-lg transition-colors"
|
| 743 |
+
>
|
| 744 |
+
Add Dataset
|
| 745 |
+
</button>
|
| 746 |
+
</>
|
| 747 |
+
</Card>
|
| 748 |
+
</div>
|
| 749 |
+
<div>
|
| 750 |
+
<Card title="Sample">
|
| 751 |
+
<div
|
| 752 |
+
className={
|
| 753 |
+
isVideoModel
|
| 754 |
+
? 'grid grid-cols-1 md:grid-cols-3 lg:grid-cols-5 gap-6'
|
| 755 |
+
: 'grid grid-cols-1 md:grid-cols-2 lg:grid-cols-4 gap-6'
|
| 756 |
+
}
|
| 757 |
+
>
|
| 758 |
+
<div>
|
| 759 |
+
<NumberInput
|
| 760 |
+
label="Sample Every"
|
| 761 |
+
value={jobConfig.config.process[0].sample.sample_every}
|
| 762 |
+
onChange={value => setJobConfig(value, 'config.process[0].sample.sample_every')}
|
| 763 |
+
placeholder="eg. 250"
|
| 764 |
+
min={1}
|
| 765 |
+
required
|
| 766 |
+
/>
|
| 767 |
+
<SelectInput
|
| 768 |
+
label="Sampler"
|
| 769 |
+
className="pt-2"
|
| 770 |
+
value={jobConfig.config.process[0].sample.sampler}
|
| 771 |
+
onChange={value => setJobConfig(value, 'config.process[0].sample.sampler')}
|
| 772 |
+
options={[
|
| 773 |
+
{ value: 'flowmatch', label: 'FlowMatch' },
|
| 774 |
+
{ value: 'ddpm', label: 'DDPM' },
|
| 775 |
+
]}
|
| 776 |
+
/>
|
| 777 |
+
<NumberInput
|
| 778 |
+
label="Guidance Scale"
|
| 779 |
+
value={jobConfig.config.process[0].sample.guidance_scale}
|
| 780 |
+
onChange={value => setJobConfig(value, 'config.process[0].sample.guidance_scale')}
|
| 781 |
+
placeholder="eg. 1.0"
|
| 782 |
+
className="pt-2"
|
| 783 |
+
min={0}
|
| 784 |
+
required
|
| 785 |
+
/>
|
| 786 |
+
<NumberInput
|
| 787 |
+
label="Sample Steps"
|
| 788 |
+
value={jobConfig.config.process[0].sample.sample_steps}
|
| 789 |
+
onChange={value => setJobConfig(value, 'config.process[0].sample.sample_steps')}
|
| 790 |
+
placeholder="eg. 1"
|
| 791 |
+
className="pt-2"
|
| 792 |
+
min={1}
|
| 793 |
+
required
|
| 794 |
+
/>
|
| 795 |
+
</div>
|
| 796 |
+
<div>
|
| 797 |
+
<NumberInput
|
| 798 |
+
label="Width"
|
| 799 |
+
value={jobConfig.config.process[0].sample.width}
|
| 800 |
+
onChange={value => setJobConfig(value, 'config.process[0].sample.width')}
|
| 801 |
+
placeholder="eg. 1024"
|
| 802 |
+
min={0}
|
| 803 |
+
required
|
| 804 |
+
/>
|
| 805 |
+
<NumberInput
|
| 806 |
+
label="Height"
|
| 807 |
+
value={jobConfig.config.process[0].sample.height}
|
| 808 |
+
onChange={value => setJobConfig(value, 'config.process[0].sample.height')}
|
| 809 |
+
placeholder="eg. 1024"
|
| 810 |
+
className="pt-2"
|
| 811 |
+
min={0}
|
| 812 |
+
required
|
| 813 |
+
/>
|
| 814 |
+
{isVideoModel && (
|
| 815 |
+
<div>
|
| 816 |
+
<NumberInput
|
| 817 |
+
label="Num Frames"
|
| 818 |
+
value={jobConfig.config.process[0].sample.num_frames}
|
| 819 |
+
onChange={value => setJobConfig(value, 'config.process[0].sample.num_frames')}
|
| 820 |
+
placeholder="eg. 0"
|
| 821 |
+
className="pt-2"
|
| 822 |
+
min={0}
|
| 823 |
+
required
|
| 824 |
+
/>
|
| 825 |
+
<NumberInput
|
| 826 |
+
label="FPS"
|
| 827 |
+
value={jobConfig.config.process[0].sample.fps}
|
| 828 |
+
onChange={value => setJobConfig(value, 'config.process[0].sample.fps')}
|
| 829 |
+
placeholder="eg. 0"
|
| 830 |
+
className="pt-2"
|
| 831 |
+
min={0}
|
| 832 |
+
required
|
| 833 |
+
/>
|
| 834 |
+
</div>
|
| 835 |
+
)}
|
| 836 |
+
</div>
|
| 837 |
+
|
| 838 |
+
<div>
|
| 839 |
+
<NumberInput
|
| 840 |
+
label="Seed"
|
| 841 |
+
value={jobConfig.config.process[0].sample.seed}
|
| 842 |
+
onChange={value => setJobConfig(value, 'config.process[0].sample.seed')}
|
| 843 |
+
placeholder="eg. 0"
|
| 844 |
+
min={0}
|
| 845 |
+
required
|
| 846 |
+
/>
|
| 847 |
+
<Checkbox
|
| 848 |
+
label="Walk Seed"
|
| 849 |
+
className="pt-4 pl-2"
|
| 850 |
+
checked={jobConfig.config.process[0].sample.walk_seed}
|
| 851 |
+
onChange={value => setJobConfig(value, 'config.process[0].sample.walk_seed')}
|
| 852 |
+
/>
|
| 853 |
+
</div>
|
| 854 |
+
<div>
|
| 855 |
+
<FormGroup label="Advanced Sampling" className="pt-2">
|
| 856 |
+
<div>
|
| 857 |
+
<Checkbox
|
| 858 |
+
label="Skip First Sample"
|
| 859 |
+
className="pt-4"
|
| 860 |
+
checked={jobConfig.config.process[0].train.skip_first_sample || false}
|
| 861 |
+
onChange={value => setJobConfig(value, 'config.process[0].train.skip_first_sample')}
|
| 862 |
+
/>
|
| 863 |
+
</div>
|
| 864 |
+
<div>
|
| 865 |
+
<Checkbox
|
| 866 |
+
label="Disable Sampling"
|
| 867 |
+
className="pt-1"
|
| 868 |
+
checked={jobConfig.config.process[0].train.disable_sampling || false}
|
| 869 |
+
onChange={value => setJobConfig(value, 'config.process[0].train.disable_sampling')}
|
| 870 |
+
/>
|
| 871 |
+
</div>
|
| 872 |
+
</FormGroup>
|
| 873 |
+
</div>
|
| 874 |
+
</div>
|
| 875 |
+
<FormGroup label={`Sample Prompts (${jobConfig.config.process[0].sample.samples.length})`} className="pt-2">
|
| 876 |
+
<div></div>
|
| 877 |
+
</FormGroup>
|
| 878 |
+
{jobConfig.config.process[0].sample.samples.map((sample, i) => (
|
| 879 |
+
<div key={i} className="rounded-lg pl-4 pr-1 mb-4 bg-gray-950">
|
| 880 |
+
<div className="flex items-center space-x-2">
|
| 881 |
+
<div className="flex-1">
|
| 882 |
+
<div className="flex">
|
| 883 |
+
<div className="flex-1">
|
| 884 |
+
<TextInput
|
| 885 |
+
label={`Prompt`}
|
| 886 |
+
value={sample.prompt}
|
| 887 |
+
onChange={value => setJobConfig(value, `config.process[0].sample.samples[${i}].prompt`)}
|
| 888 |
+
placeholder="Enter prompt"
|
| 889 |
+
required
|
| 890 |
+
/>
|
| 891 |
+
</div>
|
| 892 |
+
|
| 893 |
+
{modelArch?.additionalSections?.includes('sample.ctrl_img') && (
|
| 894 |
+
<div
|
| 895 |
+
className="h-14 w-14 mt-2 ml-4 border border-gray-500 flex items-center justify-center rounded cursor-pointer hover:bg-gray-700 transition-colors"
|
| 896 |
+
style={{
|
| 897 |
+
backgroundImage: sample.ctrl_img
|
| 898 |
+
? `url(${`/api/img/${encodeURIComponent(sample.ctrl_img)}`})`
|
| 899 |
+
: 'none',
|
| 900 |
+
backgroundSize: 'cover',
|
| 901 |
+
backgroundPosition: 'center',
|
| 902 |
+
marginBottom: '-1rem',
|
| 903 |
+
}}
|
| 904 |
+
onClick={() => {
|
| 905 |
+
openAddImageModal(imagePath => {
|
| 906 |
+
console.log('Selected image path:', imagePath);
|
| 907 |
+
if (!imagePath) return;
|
| 908 |
+
setJobConfig(imagePath, `config.process[0].sample.samples[${i}].ctrl_img`);
|
| 909 |
+
});
|
| 910 |
+
}}
|
| 911 |
+
>
|
| 912 |
+
{!sample.ctrl_img && (
|
| 913 |
+
<div className="text-gray-400 text-xs text-center font-bold">Add Control Image</div>
|
| 914 |
+
)}
|
| 915 |
+
</div>
|
| 916 |
+
)}
|
| 917 |
+
</div>
|
| 918 |
+
<div className="pb-4"></div>
|
| 919 |
+
</div>
|
| 920 |
+
<div>
|
| 921 |
+
<button
|
| 922 |
+
type="button"
|
| 923 |
+
onClick={() =>
|
| 924 |
+
setJobConfig(
|
| 925 |
+
jobConfig.config.process[0].sample.samples.filter((_, index) => index !== i),
|
| 926 |
+
'config.process[0].sample.samples',
|
| 927 |
+
)
|
| 928 |
+
}
|
| 929 |
+
className="rounded-full p-1 text-sm"
|
| 930 |
+
>
|
| 931 |
+
<X />
|
| 932 |
+
</button>
|
| 933 |
+
</div>
|
| 934 |
+
</div>
|
| 935 |
+
</div>
|
| 936 |
+
))}
|
| 937 |
+
<button
|
| 938 |
+
type="button"
|
| 939 |
+
onClick={() =>
|
| 940 |
+
setJobConfig(
|
| 941 |
+
[...jobConfig.config.process[0].sample.samples, { prompt: '' }],
|
| 942 |
+
'config.process[0].sample.samples',
|
| 943 |
+
)
|
| 944 |
+
}
|
| 945 |
+
className="w-full px-4 py-2 bg-gray-700 hover:bg-gray-600 rounded-lg transition-colors"
|
| 946 |
+
>
|
| 947 |
+
Add Prompt
|
| 948 |
+
</button>
|
| 949 |
+
</Card>
|
| 950 |
+
</div>
|
| 951 |
+
|
| 952 |
+
{status === 'success' && <p className="text-green-500 text-center">Training saved successfully!</p>}
|
| 953 |
+
{status === 'error' && <p className="text-red-500 text-center">Error saving training. Please try again.</p>}
|
| 954 |
+
</form>
|
| 955 |
+
|
| 956 |
+
{trainingBackend === 'hf-jobs' && (
|
| 957 |
+
<div className="mt-8">
|
| 958 |
+
<HFJobsWorkflow
|
| 959 |
+
jobConfig={jobConfig}
|
| 960 |
+
onComplete={(jobId, localJobId) => {
|
| 961 |
+
console.log('HF Job submitted:', jobId, 'Local job ID:', localJobId);
|
| 962 |
+
if (onHFJobComplete) {
|
| 963 |
+
onHFJobComplete(jobId, localJobId);
|
| 964 |
+
}
|
| 965 |
+
}}
|
| 966 |
+
/>
|
| 967 |
+
</div>
|
| 968 |
+
)}
|
| 969 |
+
|
| 970 |
+
<AddSingleImageModal />
|
| 971 |
+
</>
|
| 972 |
+
);
|
| 973 |
+
}
|
src/app/jobs/new/jobConfig.ts
ADDED
|
@@ -0,0 +1,167 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import { JobConfig, DatasetConfig } from '@/types';
|
| 2 |
+
|
| 3 |
+
export const defaultDatasetConfig: DatasetConfig = {
|
| 4 |
+
folder_path: '/path/to/images/folder',
|
| 5 |
+
control_path: null,
|
| 6 |
+
mask_path: null,
|
| 7 |
+
mask_min_value: 0.1,
|
| 8 |
+
default_caption: '',
|
| 9 |
+
caption_ext: 'txt',
|
| 10 |
+
caption_dropout_rate: 0.05,
|
| 11 |
+
cache_latents_to_disk: false,
|
| 12 |
+
is_reg: false,
|
| 13 |
+
network_weight: 1,
|
| 14 |
+
resolution: [512, 768, 1024],
|
| 15 |
+
controls: [],
|
| 16 |
+
shrink_video_to_frames: true,
|
| 17 |
+
num_frames: 1,
|
| 18 |
+
do_i2v: true,
|
| 19 |
+
flip_x: false,
|
| 20 |
+
flip_y: false,
|
| 21 |
+
};
|
| 22 |
+
|
| 23 |
+
export const defaultJobConfig: JobConfig = {
|
| 24 |
+
job: 'extension',
|
| 25 |
+
config: {
|
| 26 |
+
name: 'my_first_lora_v1',
|
| 27 |
+
process: [
|
| 28 |
+
{
|
| 29 |
+
type: 'ui_trainer',
|
| 30 |
+
training_folder: 'output',
|
| 31 |
+
sqlite_db_path: './aitk_db.db',
|
| 32 |
+
device: 'cuda',
|
| 33 |
+
trigger_word: null,
|
| 34 |
+
performance_log_every: 10,
|
| 35 |
+
network: {
|
| 36 |
+
type: 'lora',
|
| 37 |
+
linear: 32,
|
| 38 |
+
linear_alpha: 32,
|
| 39 |
+
conv: 16,
|
| 40 |
+
conv_alpha: 16,
|
| 41 |
+
lokr_full_rank: true,
|
| 42 |
+
lokr_factor: -1,
|
| 43 |
+
network_kwargs: {
|
| 44 |
+
ignore_if_contains: [],
|
| 45 |
+
},
|
| 46 |
+
},
|
| 47 |
+
save: {
|
| 48 |
+
dtype: 'bf16',
|
| 49 |
+
save_every: 250,
|
| 50 |
+
max_step_saves_to_keep: 4,
|
| 51 |
+
save_format: 'diffusers',
|
| 52 |
+
push_to_hub: false,
|
| 53 |
+
},
|
| 54 |
+
datasets: [defaultDatasetConfig],
|
| 55 |
+
train: {
|
| 56 |
+
batch_size: 1,
|
| 57 |
+
bypass_guidance_embedding: true,
|
| 58 |
+
steps: 3000,
|
| 59 |
+
gradient_accumulation: 1,
|
| 60 |
+
train_unet: true,
|
| 61 |
+
train_text_encoder: false,
|
| 62 |
+
gradient_checkpointing: true,
|
| 63 |
+
noise_scheduler: 'flowmatch',
|
| 64 |
+
optimizer: 'adamw8bit',
|
| 65 |
+
timestep_type: 'sigmoid',
|
| 66 |
+
content_or_style: 'balanced',
|
| 67 |
+
optimizer_params: {
|
| 68 |
+
weight_decay: 1e-4,
|
| 69 |
+
},
|
| 70 |
+
unload_text_encoder: false,
|
| 71 |
+
cache_text_embeddings: false,
|
| 72 |
+
lr: 0.0001,
|
| 73 |
+
ema_config: {
|
| 74 |
+
use_ema: false,
|
| 75 |
+
ema_decay: 0.99,
|
| 76 |
+
},
|
| 77 |
+
skip_first_sample: false,
|
| 78 |
+
disable_sampling: false,
|
| 79 |
+
dtype: 'bf16',
|
| 80 |
+
diff_output_preservation: false,
|
| 81 |
+
diff_output_preservation_multiplier: 1.0,
|
| 82 |
+
diff_output_preservation_class: 'person',
|
| 83 |
+
switch_boundary_every: 1,
|
| 84 |
+
},
|
| 85 |
+
model: {
|
| 86 |
+
name_or_path: 'ostris/Flex.1-alpha',
|
| 87 |
+
quantize: true,
|
| 88 |
+
qtype: 'qfloat8',
|
| 89 |
+
quantize_te: true,
|
| 90 |
+
qtype_te: 'qfloat8',
|
| 91 |
+
arch: 'flex1',
|
| 92 |
+
low_vram: false,
|
| 93 |
+
model_kwargs: {},
|
| 94 |
+
},
|
| 95 |
+
sample: {
|
| 96 |
+
sampler: 'flowmatch',
|
| 97 |
+
sample_every: 250,
|
| 98 |
+
width: 1024,
|
| 99 |
+
height: 1024,
|
| 100 |
+
samples: [
|
| 101 |
+
{
|
| 102 |
+
prompt: 'woman with red hair, playing chess at the park, bomb going off in the background'
|
| 103 |
+
},
|
| 104 |
+
{
|
| 105 |
+
prompt: 'a woman holding a coffee cup, in a beanie, sitting at a cafe',
|
| 106 |
+
},
|
| 107 |
+
{
|
| 108 |
+
prompt: 'a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini',
|
| 109 |
+
},
|
| 110 |
+
{
|
| 111 |
+
prompt: 'a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background',
|
| 112 |
+
},
|
| 113 |
+
{
|
| 114 |
+
prompt: 'a bear building a log cabin in the snow covered mountains',
|
| 115 |
+
},
|
| 116 |
+
{
|
| 117 |
+
prompt: 'woman playing the guitar, on stage, singing a song, laser lights, punk rocker',
|
| 118 |
+
},
|
| 119 |
+
{
|
| 120 |
+
prompt: 'hipster man with a beard, building a chair, in a wood shop',
|
| 121 |
+
},
|
| 122 |
+
{
|
| 123 |
+
prompt: 'photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop',
|
| 124 |
+
},
|
| 125 |
+
{
|
| 126 |
+
prompt: "a man holding a sign that says, 'this is a sign'",
|
| 127 |
+
},
|
| 128 |
+
{
|
| 129 |
+
prompt: 'a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle',
|
| 130 |
+
},
|
| 131 |
+
],
|
| 132 |
+
neg: '',
|
| 133 |
+
seed: 42,
|
| 134 |
+
walk_seed: true,
|
| 135 |
+
guidance_scale: 4,
|
| 136 |
+
sample_steps: 25,
|
| 137 |
+
num_frames: 1,
|
| 138 |
+
fps: 1,
|
| 139 |
+
},
|
| 140 |
+
},
|
| 141 |
+
],
|
| 142 |
+
},
|
| 143 |
+
meta: {
|
| 144 |
+
name: '[name]',
|
| 145 |
+
version: '1.0',
|
| 146 |
+
},
|
| 147 |
+
};
|
| 148 |
+
|
| 149 |
+
export const migrateJobConfig = (jobConfig: JobConfig): JobConfig => {
|
| 150 |
+
// upgrade prompt strings to samples
|
| 151 |
+
if (
|
| 152 |
+
jobConfig?.config?.process &&
|
| 153 |
+
jobConfig.config.process[0]?.sample &&
|
| 154 |
+
Array.isArray(jobConfig.config.process[0].sample.prompts) &&
|
| 155 |
+
jobConfig.config.process[0].sample.prompts.length > 0
|
| 156 |
+
) {
|
| 157 |
+
let newSamples = [];
|
| 158 |
+
for (const prompt of jobConfig.config.process[0].sample.prompts) {
|
| 159 |
+
newSamples.push({
|
| 160 |
+
prompt: prompt,
|
| 161 |
+
});
|
| 162 |
+
}
|
| 163 |
+
jobConfig.config.process[0].sample.samples = newSamples;
|
| 164 |
+
delete jobConfig.config.process[0].sample.prompts;
|
| 165 |
+
}
|
| 166 |
+
return jobConfig;
|
| 167 |
+
};
|
src/app/jobs/new/options.ts
ADDED
|
@@ -0,0 +1,441 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import { GroupedSelectOption, SelectOption } from '@/types';
|
| 2 |
+
|
| 3 |
+
type Control = 'depth' | 'line' | 'pose' | 'inpaint';
|
| 4 |
+
|
| 5 |
+
type DisableableSections = 'model.quantize' | 'train.timestep_type' | 'network.conv';
|
| 6 |
+
type AdditionalSections =
|
| 7 |
+
| 'datasets.control_path'
|
| 8 |
+
| 'datasets.do_i2v'
|
| 9 |
+
| 'sample.ctrl_img'
|
| 10 |
+
| 'datasets.num_frames'
|
| 11 |
+
| 'model.multistage'
|
| 12 |
+
| 'model.low_vram';
|
| 13 |
+
type ModelGroup = 'image' | 'instruction' | 'video';
|
| 14 |
+
|
| 15 |
+
export interface ModelArch {
|
| 16 |
+
name: string;
|
| 17 |
+
label: string;
|
| 18 |
+
group: ModelGroup;
|
| 19 |
+
controls?: Control[];
|
| 20 |
+
isVideoModel?: boolean;
|
| 21 |
+
defaults?: { [key: string]: any };
|
| 22 |
+
disableSections?: DisableableSections[];
|
| 23 |
+
additionalSections?: AdditionalSections[];
|
| 24 |
+
accuracyRecoveryAdapters?: { [key: string]: string };
|
| 25 |
+
}
|
| 26 |
+
|
| 27 |
+
const defaultNameOrPath = '';
|
| 28 |
+
|
| 29 |
+
export const modelArchs: ModelArch[] = [
|
| 30 |
+
{
|
| 31 |
+
name: 'flux',
|
| 32 |
+
label: 'FLUX.1',
|
| 33 |
+
group: 'image',
|
| 34 |
+
defaults: {
|
| 35 |
+
// default updates when [selected, unselected] in the UI
|
| 36 |
+
'config.process[0].model.name_or_path': ['black-forest-labs/FLUX.1-dev', defaultNameOrPath],
|
| 37 |
+
'config.process[0].model.quantize': [true, false],
|
| 38 |
+
'config.process[0].model.quantize_te': [true, false],
|
| 39 |
+
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
|
| 40 |
+
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
|
| 41 |
+
},
|
| 42 |
+
disableSections: ['network.conv'],
|
| 43 |
+
},
|
| 44 |
+
{
|
| 45 |
+
name: 'flux_kontext',
|
| 46 |
+
label: 'FLUX.1-Kontext-dev',
|
| 47 |
+
group: 'instruction',
|
| 48 |
+
defaults: {
|
| 49 |
+
// default updates when [selected, unselected] in the UI
|
| 50 |
+
'config.process[0].model.name_or_path': ['black-forest-labs/FLUX.1-Kontext-dev', defaultNameOrPath],
|
| 51 |
+
'config.process[0].model.quantize': [true, false],
|
| 52 |
+
'config.process[0].model.quantize_te': [true, false],
|
| 53 |
+
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
|
| 54 |
+
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
|
| 55 |
+
'config.process[0].train.timestep_type': ['weighted', 'sigmoid'],
|
| 56 |
+
},
|
| 57 |
+
disableSections: ['network.conv'],
|
| 58 |
+
additionalSections: ['datasets.control_path', 'sample.ctrl_img'],
|
| 59 |
+
},
|
| 60 |
+
{
|
| 61 |
+
name: 'flex1',
|
| 62 |
+
label: 'Flex.1',
|
| 63 |
+
group: 'image',
|
| 64 |
+
defaults: {
|
| 65 |
+
// default updates when [selected, unselected] in the UI
|
| 66 |
+
'config.process[0].model.name_or_path': ['ostris/Flex.1-alpha', defaultNameOrPath],
|
| 67 |
+
'config.process[0].model.quantize': [true, false],
|
| 68 |
+
'config.process[0].model.quantize_te': [true, false],
|
| 69 |
+
'config.process[0].train.bypass_guidance_embedding': [true, false],
|
| 70 |
+
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
|
| 71 |
+
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
|
| 72 |
+
},
|
| 73 |
+
disableSections: ['network.conv'],
|
| 74 |
+
},
|
| 75 |
+
{
|
| 76 |
+
name: 'flex2',
|
| 77 |
+
label: 'Flex.2',
|
| 78 |
+
group: 'image',
|
| 79 |
+
controls: ['depth', 'line', 'pose', 'inpaint'],
|
| 80 |
+
defaults: {
|
| 81 |
+
// default updates when [selected, unselected] in the UI
|
| 82 |
+
'config.process[0].model.name_or_path': ['ostris/Flex.2-preview', defaultNameOrPath],
|
| 83 |
+
'config.process[0].model.quantize': [true, false],
|
| 84 |
+
'config.process[0].model.quantize_te': [true, false],
|
| 85 |
+
'config.process[0].model.model_kwargs': [
|
| 86 |
+
{
|
| 87 |
+
invert_inpaint_mask_chance: 0.2,
|
| 88 |
+
inpaint_dropout: 0.5,
|
| 89 |
+
control_dropout: 0.5,
|
| 90 |
+
inpaint_random_chance: 0.2,
|
| 91 |
+
do_random_inpainting: true,
|
| 92 |
+
random_blur_mask: true,
|
| 93 |
+
random_dialate_mask: true,
|
| 94 |
+
},
|
| 95 |
+
{},
|
| 96 |
+
],
|
| 97 |
+
'config.process[0].train.bypass_guidance_embedding': [true, false],
|
| 98 |
+
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
|
| 99 |
+
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
|
| 100 |
+
},
|
| 101 |
+
disableSections: ['network.conv'],
|
| 102 |
+
},
|
| 103 |
+
{
|
| 104 |
+
name: 'chroma',
|
| 105 |
+
label: 'Chroma',
|
| 106 |
+
group: 'image',
|
| 107 |
+
defaults: {
|
| 108 |
+
// default updates when [selected, unselected] in the UI
|
| 109 |
+
'config.process[0].model.name_or_path': ['lodestones/Chroma1-Base', defaultNameOrPath],
|
| 110 |
+
'config.process[0].model.quantize': [true, false],
|
| 111 |
+
'config.process[0].model.quantize_te': [true, false],
|
| 112 |
+
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
|
| 113 |
+
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
|
| 114 |
+
},
|
| 115 |
+
disableSections: ['network.conv'],
|
| 116 |
+
},
|
| 117 |
+
{
|
| 118 |
+
name: 'wan21:1b',
|
| 119 |
+
label: 'Wan 2.1 (1.3B)',
|
| 120 |
+
group: 'video',
|
| 121 |
+
isVideoModel: true,
|
| 122 |
+
defaults: {
|
| 123 |
+
// default updates when [selected, unselected] in the UI
|
| 124 |
+
'config.process[0].model.name_or_path': ['Wan-AI/Wan2.1-T2V-1.3B-Diffusers', defaultNameOrPath],
|
| 125 |
+
'config.process[0].model.quantize': [false, false],
|
| 126 |
+
'config.process[0].model.quantize_te': [true, false],
|
| 127 |
+
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
|
| 128 |
+
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
|
| 129 |
+
'config.process[0].sample.num_frames': [41, 1],
|
| 130 |
+
'config.process[0].sample.fps': [16, 1],
|
| 131 |
+
},
|
| 132 |
+
disableSections: ['network.conv'],
|
| 133 |
+
additionalSections: ['datasets.num_frames', 'model.low_vram'],
|
| 134 |
+
},
|
| 135 |
+
{
|
| 136 |
+
name: 'wan21_i2v:14b480p',
|
| 137 |
+
label: 'Wan 2.1 I2V (14B-480P)',
|
| 138 |
+
group: 'video',
|
| 139 |
+
isVideoModel: true,
|
| 140 |
+
defaults: {
|
| 141 |
+
// default updates when [selected, unselected] in the UI
|
| 142 |
+
'config.process[0].model.name_or_path': ['Wan-AI/Wan2.1-I2V-14B-480P-Diffusers', defaultNameOrPath],
|
| 143 |
+
'config.process[0].model.quantize': [true, false],
|
| 144 |
+
'config.process[0].model.quantize_te': [true, false],
|
| 145 |
+
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
|
| 146 |
+
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
|
| 147 |
+
'config.process[0].sample.num_frames': [41, 1],
|
| 148 |
+
'config.process[0].sample.fps': [16, 1],
|
| 149 |
+
'config.process[0].train.timestep_type': ['weighted', 'sigmoid'],
|
| 150 |
+
},
|
| 151 |
+
disableSections: ['network.conv'],
|
| 152 |
+
additionalSections: ['sample.ctrl_img', 'datasets.num_frames', 'model.low_vram'],
|
| 153 |
+
},
|
| 154 |
+
{
|
| 155 |
+
name: 'wan21_i2v:14b',
|
| 156 |
+
label: 'Wan 2.1 I2V (14B-720P)',
|
| 157 |
+
group: 'video',
|
| 158 |
+
isVideoModel: true,
|
| 159 |
+
defaults: {
|
| 160 |
+
// default updates when [selected, unselected] in the UI
|
| 161 |
+
'config.process[0].model.name_or_path': ['Wan-AI/Wan2.1-I2V-14B-720P-Diffusers', defaultNameOrPath],
|
| 162 |
+
'config.process[0].model.quantize': [true, false],
|
| 163 |
+
'config.process[0].model.quantize_te': [true, false],
|
| 164 |
+
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
|
| 165 |
+
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
|
| 166 |
+
'config.process[0].sample.num_frames': [41, 1],
|
| 167 |
+
'config.process[0].sample.fps': [16, 1],
|
| 168 |
+
'config.process[0].train.timestep_type': ['weighted', 'sigmoid'],
|
| 169 |
+
},
|
| 170 |
+
disableSections: ['network.conv'],
|
| 171 |
+
additionalSections: ['sample.ctrl_img', 'datasets.num_frames', 'model.low_vram'],
|
| 172 |
+
},
|
| 173 |
+
{
|
| 174 |
+
name: 'wan21:14b',
|
| 175 |
+
label: 'Wan 2.1 (14B)',
|
| 176 |
+
group: 'video',
|
| 177 |
+
isVideoModel: true,
|
| 178 |
+
defaults: {
|
| 179 |
+
// default updates when [selected, unselected] in the UI
|
| 180 |
+
'config.process[0].model.name_or_path': ['Wan-AI/Wan2.1-T2V-14B-Diffusers', defaultNameOrPath],
|
| 181 |
+
'config.process[0].model.quantize': [true, false],
|
| 182 |
+
'config.process[0].model.quantize_te': [true, false],
|
| 183 |
+
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
|
| 184 |
+
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
|
| 185 |
+
'config.process[0].sample.num_frames': [41, 1],
|
| 186 |
+
'config.process[0].sample.fps': [16, 1],
|
| 187 |
+
},
|
| 188 |
+
disableSections: ['network.conv'],
|
| 189 |
+
additionalSections: ['datasets.num_frames', 'model.low_vram'],
|
| 190 |
+
},
|
| 191 |
+
{
|
| 192 |
+
name: 'wan22_14b:t2v',
|
| 193 |
+
label: 'Wan 2.2 (14B)',
|
| 194 |
+
group: 'video',
|
| 195 |
+
isVideoModel: true,
|
| 196 |
+
defaults: {
|
| 197 |
+
// default updates when [selected, unselected] in the UI
|
| 198 |
+
'config.process[0].model.name_or_path': ['ai-toolkit/Wan2.2-T2V-A14B-Diffusers-bf16', defaultNameOrPath],
|
| 199 |
+
'config.process[0].model.quantize': [true, false],
|
| 200 |
+
'config.process[0].model.quantize_te': [true, false],
|
| 201 |
+
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
|
| 202 |
+
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
|
| 203 |
+
'config.process[0].sample.num_frames': [41, 1],
|
| 204 |
+
'config.process[0].sample.fps': [16, 1],
|
| 205 |
+
'config.process[0].model.low_vram': [true, false],
|
| 206 |
+
'config.process[0].train.timestep_type': ['linear', 'sigmoid'],
|
| 207 |
+
'config.process[0].model.model_kwargs': [
|
| 208 |
+
{
|
| 209 |
+
train_high_noise: true,
|
| 210 |
+
train_low_noise: true,
|
| 211 |
+
},
|
| 212 |
+
{},
|
| 213 |
+
],
|
| 214 |
+
},
|
| 215 |
+
disableSections: ['network.conv'],
|
| 216 |
+
additionalSections: ['datasets.num_frames', 'model.low_vram', 'model.multistage'],
|
| 217 |
+
accuracyRecoveryAdapters: {
|
| 218 |
+
// '3 bit with ARA': 'uint3|ostris/accuracy_recovery_adapters/wan22_14b_t2i_torchao_uint3.safetensors',
|
| 219 |
+
'4 bit with ARA': 'uint4|ostris/accuracy_recovery_adapters/wan22_14b_t2i_torchao_uint4.safetensors',
|
| 220 |
+
},
|
| 221 |
+
},
|
| 222 |
+
{
|
| 223 |
+
name: 'wan22_14b_i2v',
|
| 224 |
+
label: 'Wan 2.2 I2V (14B)',
|
| 225 |
+
group: 'video',
|
| 226 |
+
isVideoModel: true,
|
| 227 |
+
defaults: {
|
| 228 |
+
// default updates when [selected, unselected] in the UI
|
| 229 |
+
'config.process[0].model.name_or_path': ['ai-toolkit/Wan2.2-I2V-A14B-Diffusers-bf16', defaultNameOrPath],
|
| 230 |
+
'config.process[0].model.quantize': [true, false],
|
| 231 |
+
'config.process[0].model.quantize_te': [true, false],
|
| 232 |
+
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
|
| 233 |
+
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
|
| 234 |
+
'config.process[0].sample.num_frames': [41, 1],
|
| 235 |
+
'config.process[0].sample.fps': [16, 1],
|
| 236 |
+
'config.process[0].model.low_vram': [true, false],
|
| 237 |
+
'config.process[0].train.timestep_type': ['linear', 'sigmoid'],
|
| 238 |
+
'config.process[0].model.model_kwargs': [
|
| 239 |
+
{
|
| 240 |
+
train_high_noise: true,
|
| 241 |
+
train_low_noise: true,
|
| 242 |
+
},
|
| 243 |
+
{},
|
| 244 |
+
],
|
| 245 |
+
},
|
| 246 |
+
disableSections: ['network.conv'],
|
| 247 |
+
additionalSections: ['sample.ctrl_img', 'datasets.num_frames', 'model.low_vram', 'model.multistage'],
|
| 248 |
+
accuracyRecoveryAdapters: {
|
| 249 |
+
'4 bit with ARA': 'uint4|ostris/accuracy_recovery_adapters/wan22_14b_i2v_torchao_uint4.safetensors',
|
| 250 |
+
},
|
| 251 |
+
},
|
| 252 |
+
{
|
| 253 |
+
name: 'wan22_5b',
|
| 254 |
+
label: 'Wan 2.2 TI2V (5B)',
|
| 255 |
+
group: 'video',
|
| 256 |
+
isVideoModel: true,
|
| 257 |
+
defaults: {
|
| 258 |
+
// default updates when [selected, unselected] in the UI
|
| 259 |
+
'config.process[0].model.name_or_path': ['Wan-AI/Wan2.2-TI2V-5B-Diffusers', defaultNameOrPath],
|
| 260 |
+
'config.process[0].model.quantize': [true, false],
|
| 261 |
+
'config.process[0].model.quantize_te': [true, false],
|
| 262 |
+
'config.process[0].model.low_vram': [true, false],
|
| 263 |
+
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
|
| 264 |
+
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
|
| 265 |
+
'config.process[0].sample.num_frames': [121, 1],
|
| 266 |
+
'config.process[0].sample.fps': [24, 1],
|
| 267 |
+
'config.process[0].sample.width': [768, 1024],
|
| 268 |
+
'config.process[0].sample.height': [768, 1024],
|
| 269 |
+
'config.process[0].train.timestep_type': ['weighted', 'sigmoid'],
|
| 270 |
+
},
|
| 271 |
+
disableSections: ['network.conv'],
|
| 272 |
+
additionalSections: ['sample.ctrl_img', 'datasets.num_frames', 'model.low_vram', 'datasets.do_i2v'],
|
| 273 |
+
},
|
| 274 |
+
{
|
| 275 |
+
name: 'lumina2',
|
| 276 |
+
label: 'Lumina2',
|
| 277 |
+
group: 'image',
|
| 278 |
+
defaults: {
|
| 279 |
+
// default updates when [selected, unselected] in the UI
|
| 280 |
+
'config.process[0].model.name_or_path': ['Alpha-VLLM/Lumina-Image-2.0', defaultNameOrPath],
|
| 281 |
+
'config.process[0].model.quantize': [false, false],
|
| 282 |
+
'config.process[0].model.quantize_te': [true, false],
|
| 283 |
+
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
|
| 284 |
+
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
|
| 285 |
+
},
|
| 286 |
+
disableSections: ['network.conv'],
|
| 287 |
+
},
|
| 288 |
+
{
|
| 289 |
+
name: 'qwen_image',
|
| 290 |
+
label: 'Qwen-Image',
|
| 291 |
+
group: 'image',
|
| 292 |
+
defaults: {
|
| 293 |
+
// default updates when [selected, unselected] in the UI
|
| 294 |
+
'config.process[0].model.name_or_path': ['Qwen/Qwen-Image', defaultNameOrPath],
|
| 295 |
+
'config.process[0].model.quantize': [true, false],
|
| 296 |
+
'config.process[0].model.quantize_te': [true, false],
|
| 297 |
+
'config.process[0].model.low_vram': [true, false],
|
| 298 |
+
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
|
| 299 |
+
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
|
| 300 |
+
'config.process[0].train.timestep_type': ['weighted', 'sigmoid'],
|
| 301 |
+
'config.process[0].model.qtype': ['qfloat8', 'qfloat8'],
|
| 302 |
+
},
|
| 303 |
+
disableSections: ['network.conv'],
|
| 304 |
+
additionalSections: ['model.low_vram'],
|
| 305 |
+
accuracyRecoveryAdapters: {
|
| 306 |
+
'3 bit with ARA': 'uint3|ostris/accuracy_recovery_adapters/qwen_image_torchao_uint3.safetensors',
|
| 307 |
+
},
|
| 308 |
+
},
|
| 309 |
+
{
|
| 310 |
+
name: 'qwen_image_edit',
|
| 311 |
+
label: 'Qwen-Image-Edit',
|
| 312 |
+
group: 'instruction',
|
| 313 |
+
defaults: {
|
| 314 |
+
// default updates when [selected, unselected] in the UI
|
| 315 |
+
'config.process[0].model.name_or_path': ['Qwen/Qwen-Image-Edit', defaultNameOrPath],
|
| 316 |
+
'config.process[0].model.quantize': [true, false],
|
| 317 |
+
'config.process[0].model.quantize_te': [true, false],
|
| 318 |
+
'config.process[0].model.low_vram': [true, false],
|
| 319 |
+
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
|
| 320 |
+
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
|
| 321 |
+
'config.process[0].train.timestep_type': ['weighted', 'sigmoid'],
|
| 322 |
+
'config.process[0].model.qtype': ['qfloat8', 'qfloat8'],
|
| 323 |
+
},
|
| 324 |
+
disableSections: ['network.conv'],
|
| 325 |
+
additionalSections: ['datasets.control_path', 'sample.ctrl_img', 'model.low_vram'],
|
| 326 |
+
accuracyRecoveryAdapters: {
|
| 327 |
+
'3 bit with ARA': 'uint3|ostris/accuracy_recovery_adapters/qwen_image_edit_torchao_uint3.safetensors',
|
| 328 |
+
},
|
| 329 |
+
},
|
| 330 |
+
{
|
| 331 |
+
name: 'hidream',
|
| 332 |
+
label: 'HiDream',
|
| 333 |
+
group: 'image',
|
| 334 |
+
defaults: {
|
| 335 |
+
// default updates when [selected, unselected] in the UI
|
| 336 |
+
'config.process[0].model.name_or_path': ['HiDream-ai/HiDream-I1-Full', defaultNameOrPath],
|
| 337 |
+
'config.process[0].model.quantize': [true, false],
|
| 338 |
+
'config.process[0].model.quantize_te': [true, false],
|
| 339 |
+
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
|
| 340 |
+
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
|
| 341 |
+
'config.process[0].train.lr': [0.0002, 0.0001],
|
| 342 |
+
'config.process[0].train.timestep_type': ['shift', 'sigmoid'],
|
| 343 |
+
'config.process[0].network.network_kwargs.ignore_if_contains': [['ff_i.experts', 'ff_i.gate'], []],
|
| 344 |
+
},
|
| 345 |
+
disableSections: ['network.conv'],
|
| 346 |
+
additionalSections: ['model.low_vram'],
|
| 347 |
+
},
|
| 348 |
+
{
|
| 349 |
+
name: 'hidream_e1',
|
| 350 |
+
label: 'HiDream E1',
|
| 351 |
+
group: 'instruction',
|
| 352 |
+
defaults: {
|
| 353 |
+
// default updates when [selected, unselected] in the UI
|
| 354 |
+
'config.process[0].model.name_or_path': ['HiDream-ai/HiDream-E1-1', defaultNameOrPath],
|
| 355 |
+
'config.process[0].model.quantize': [true, false],
|
| 356 |
+
'config.process[0].model.quantize_te': [true, false],
|
| 357 |
+
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
|
| 358 |
+
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
|
| 359 |
+
'config.process[0].train.lr': [0.0001, 0.0001],
|
| 360 |
+
'config.process[0].train.timestep_type': ['weighted', 'sigmoid'],
|
| 361 |
+
'config.process[0].network.network_kwargs.ignore_if_contains': [['ff_i.experts', 'ff_i.gate'], []],
|
| 362 |
+
},
|
| 363 |
+
disableSections: ['network.conv'],
|
| 364 |
+
additionalSections: ['datasets.control_path', 'sample.ctrl_img', 'model.low_vram'],
|
| 365 |
+
},
|
| 366 |
+
{
|
| 367 |
+
name: 'sdxl',
|
| 368 |
+
label: 'SDXL',
|
| 369 |
+
group: 'image',
|
| 370 |
+
defaults: {
|
| 371 |
+
// default updates when [selected, unselected] in the UI
|
| 372 |
+
'config.process[0].model.name_or_path': ['stabilityai/stable-diffusion-xl-base-1.0', defaultNameOrPath],
|
| 373 |
+
'config.process[0].model.quantize': [false, false],
|
| 374 |
+
'config.process[0].model.quantize_te': [false, false],
|
| 375 |
+
'config.process[0].sample.sampler': ['ddpm', 'flowmatch'],
|
| 376 |
+
'config.process[0].train.noise_scheduler': ['ddpm', 'flowmatch'],
|
| 377 |
+
'config.process[0].sample.guidance_scale': [6, 4],
|
| 378 |
+
},
|
| 379 |
+
disableSections: ['model.quantize', 'train.timestep_type'],
|
| 380 |
+
},
|
| 381 |
+
{
|
| 382 |
+
name: 'sd15',
|
| 383 |
+
label: 'SD 1.5',
|
| 384 |
+
group: 'image',
|
| 385 |
+
defaults: {
|
| 386 |
+
// default updates when [selected, unselected] in the UI
|
| 387 |
+
'config.process[0].model.name_or_path': ['stable-diffusion-v1-5/stable-diffusion-v1-5', defaultNameOrPath],
|
| 388 |
+
'config.process[0].sample.sampler': ['ddpm', 'flowmatch'],
|
| 389 |
+
'config.process[0].train.noise_scheduler': ['ddpm', 'flowmatch'],
|
| 390 |
+
'config.process[0].sample.width': [512, 1024],
|
| 391 |
+
'config.process[0].sample.height': [512, 1024],
|
| 392 |
+
'config.process[0].sample.guidance_scale': [6, 4],
|
| 393 |
+
},
|
| 394 |
+
disableSections: ['model.quantize', 'train.timestep_type'],
|
| 395 |
+
},
|
| 396 |
+
{
|
| 397 |
+
name: 'omnigen2',
|
| 398 |
+
label: 'OmniGen2',
|
| 399 |
+
group: 'image',
|
| 400 |
+
defaults: {
|
| 401 |
+
// default updates when [selected, unselected] in the UI
|
| 402 |
+
'config.process[0].model.name_or_path': ['OmniGen2/OmniGen2', defaultNameOrPath],
|
| 403 |
+
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
|
| 404 |
+
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
|
| 405 |
+
'config.process[0].model.quantize': [false, false],
|
| 406 |
+
'config.process[0].model.quantize_te': [true, false],
|
| 407 |
+
},
|
| 408 |
+
disableSections: ['network.conv'],
|
| 409 |
+
additionalSections: ['datasets.control_path', 'sample.ctrl_img'],
|
| 410 |
+
},
|
| 411 |
+
].sort((a, b) => {
|
| 412 |
+
// Sort by label, case-insensitive
|
| 413 |
+
return a.label.localeCompare(b.label, undefined, { sensitivity: 'base' });
|
| 414 |
+
}) as any;
|
| 415 |
+
|
| 416 |
+
export const groupedModelOptions: GroupedSelectOption[] = modelArchs.reduce((acc, arch) => {
|
| 417 |
+
const group = acc.find(g => g.label === arch.group);
|
| 418 |
+
if (group) {
|
| 419 |
+
group.options.push({ value: arch.name, label: arch.label });
|
| 420 |
+
} else {
|
| 421 |
+
acc.push({
|
| 422 |
+
label: arch.group,
|
| 423 |
+
options: [{ value: arch.name, label: arch.label }],
|
| 424 |
+
});
|
| 425 |
+
}
|
| 426 |
+
return acc;
|
| 427 |
+
}, [] as GroupedSelectOption[]);
|
| 428 |
+
|
| 429 |
+
export const quantizationOptions: SelectOption[] = [
|
| 430 |
+
{ value: '', label: '- NONE -' },
|
| 431 |
+
{ value: 'qfloat8', label: 'float8 (default)' },
|
| 432 |
+
{ value: 'uint8', label: '8 bit' },
|
| 433 |
+
{ value: 'uint7', label: '7 bit' },
|
| 434 |
+
{ value: 'uint6', label: '6 bit' },
|
| 435 |
+
{ value: 'uint5', label: '5 bit' },
|
| 436 |
+
{ value: 'uint4', label: '4 bit' },
|
| 437 |
+
{ value: 'uint3', label: '3 bit' },
|
| 438 |
+
{ value: 'uint2', label: '2 bit' },
|
| 439 |
+
];
|
| 440 |
+
|
| 441 |
+
export const defaultQtype = 'qfloat8';
|
src/app/jobs/new/page.tsx
ADDED
|
@@ -0,0 +1,306 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
'use client';
|
| 2 |
+
|
| 3 |
+
import { useEffect, useState } from 'react';
|
| 4 |
+
import { useSearchParams, useRouter } from 'next/navigation';
|
| 5 |
+
import Link from 'next/link';
|
| 6 |
+
import { defaultJobConfig, defaultDatasetConfig, migrateJobConfig } from './jobConfig';
|
| 7 |
+
import { JobConfig } from '@/types';
|
| 8 |
+
import { objectCopy } from '@/utils/basic';
|
| 9 |
+
import { useNestedState } from '@/utils/hooks';
|
| 10 |
+
import { SelectInput } from '@/components/formInputs';
|
| 11 |
+
import useSettings from '@/hooks/useSettings';
|
| 12 |
+
import useGPUInfo from '@/hooks/useGPUInfo';
|
| 13 |
+
import useDatasetList from '@/hooks/useDatasetList';
|
| 14 |
+
import path from 'path';
|
| 15 |
+
import { TopBar, MainContent } from '@/components/layout';
|
| 16 |
+
import { Button } from '@headlessui/react';
|
| 17 |
+
import { FaChevronLeft } from 'react-icons/fa';
|
| 18 |
+
import SimpleJob from './SimpleJob';
|
| 19 |
+
import AdvancedJob from './AdvancedJob';
|
| 20 |
+
import ErrorBoundary from '@/components/ErrorBoundary';
|
| 21 |
+
import { getJob, upsertJob } from '@/utils/storage/jobStorage';
|
| 22 |
+
import { usingBrowserDb } from '@/utils/env';
|
| 23 |
+
import { getUserDatasetPath, updateUserDatasetPath } from '@/utils/storage/datasetStorage';
|
| 24 |
+
import { apiClient } from '@/utils/api';
|
| 25 |
+
import { useAuth } from '@/contexts/AuthContext';
|
| 26 |
+
import HFLoginButton from '@/components/HFLoginButton';
|
| 27 |
+
|
| 28 |
+
const isDev = process.env.NODE_ENV === 'development';
|
| 29 |
+
|
| 30 |
+
export default function TrainingForm() {
|
| 31 |
+
const router = useRouter();
|
| 32 |
+
const searchParams = useSearchParams();
|
| 33 |
+
const runId = searchParams.get('id');
|
| 34 |
+
const { status: authStatus } = useAuth();
|
| 35 |
+
const isAuthenticated = authStatus === 'authenticated';
|
| 36 |
+
const [gpuIDs, setGpuIDs] = useState<string | null>(null);
|
| 37 |
+
const { settings, isSettingsLoaded } = useSettings();
|
| 38 |
+
const { gpuList, isGPUInfoLoaded } = useGPUInfo();
|
| 39 |
+
const { datasets, status: datasetFetchStatus } = useDatasetList();
|
| 40 |
+
const [datasetOptions, setDatasetOptions] = useState<{ value: string; label: string }[]>([]);
|
| 41 |
+
const [showAdvancedView, setShowAdvancedView] = useState(false);
|
| 42 |
+
|
| 43 |
+
const [jobConfig, setJobConfig] = useNestedState<JobConfig>(objectCopy(defaultJobConfig));
|
| 44 |
+
const [status, setStatus] = useState<'idle' | 'saving' | 'success' | 'error'>('idle');
|
| 45 |
+
|
| 46 |
+
// Track HF Jobs backend state
|
| 47 |
+
const [trainingBackend, setTrainingBackend] = useState<'local' | 'hf-jobs'>(
|
| 48 |
+
usingBrowserDb ? 'hf-jobs' : 'local',
|
| 49 |
+
);
|
| 50 |
+
const [hfJobSubmitted, setHfJobSubmitted] = useState(false);
|
| 51 |
+
|
| 52 |
+
useEffect(() => {
|
| 53 |
+
if (!isSettingsLoaded || !isAuthenticated) return;
|
| 54 |
+
if (datasetFetchStatus !== 'success') return;
|
| 55 |
+
|
| 56 |
+
let isMounted = true;
|
| 57 |
+
|
| 58 |
+
const buildDatasetOptions = async () => {
|
| 59 |
+
const options = await Promise.all(
|
| 60 |
+
datasets.map(async name => {
|
| 61 |
+
let datasetPath = settings.DATASETS_FOLDER ? path.join(settings.DATASETS_FOLDER, name) : '';
|
| 62 |
+
|
| 63 |
+
if (usingBrowserDb) {
|
| 64 |
+
const storedPath = getUserDatasetPath(name);
|
| 65 |
+
if (storedPath) {
|
| 66 |
+
datasetPath = storedPath;
|
| 67 |
+
} else {
|
| 68 |
+
try {
|
| 69 |
+
const response = await apiClient
|
| 70 |
+
.post('/api/datasets/create', { name })
|
| 71 |
+
.then(res => res.data);
|
| 72 |
+
if (response?.path) {
|
| 73 |
+
datasetPath = response.path;
|
| 74 |
+
updateUserDatasetPath(name, datasetPath);
|
| 75 |
+
}
|
| 76 |
+
} catch (err) {
|
| 77 |
+
console.error('Error resolving dataset path:', err);
|
| 78 |
+
}
|
| 79 |
+
}
|
| 80 |
+
}
|
| 81 |
+
|
| 82 |
+
if (!datasetPath) {
|
| 83 |
+
datasetPath = name;
|
| 84 |
+
}
|
| 85 |
+
|
| 86 |
+
return { value: datasetPath, label: name };
|
| 87 |
+
}),
|
| 88 |
+
);
|
| 89 |
+
|
| 90 |
+
if (!isMounted) {
|
| 91 |
+
return;
|
| 92 |
+
}
|
| 93 |
+
|
| 94 |
+
setDatasetOptions(options);
|
| 95 |
+
const defaultDatasetPath = defaultDatasetConfig.folder_path;
|
| 96 |
+
|
| 97 |
+
for (let i = 0; i < jobConfig.config.process[0].datasets.length; i++) {
|
| 98 |
+
const dataset = jobConfig.config.process[0].datasets[i];
|
| 99 |
+
if (dataset.folder_path === defaultDatasetPath) {
|
| 100 |
+
if (options.length > 0) {
|
| 101 |
+
setJobConfig(options[0].value, `config.process[0].datasets[${i}].folder_path`);
|
| 102 |
+
}
|
| 103 |
+
}
|
| 104 |
+
}
|
| 105 |
+
};
|
| 106 |
+
|
| 107 |
+
buildDatasetOptions();
|
| 108 |
+
|
| 109 |
+
return () => {
|
| 110 |
+
isMounted = false;
|
| 111 |
+
};
|
| 112 |
+
}, [datasets, settings, isSettingsLoaded, datasetFetchStatus]);
|
| 113 |
+
|
| 114 |
+
useEffect(() => {
|
| 115 |
+
if (runId) {
|
| 116 |
+
getJob(runId)
|
| 117 |
+
.then(data => {
|
| 118 |
+
if (!data) {
|
| 119 |
+
throw new Error('Job not found');
|
| 120 |
+
}
|
| 121 |
+
setGpuIDs(data.gpu_ids);
|
| 122 |
+
const parsedJobConfig = migrateJobConfig(JSON.parse(data.job_config));
|
| 123 |
+
setJobConfig(parsedJobConfig);
|
| 124 |
+
|
| 125 |
+
if (parsedJobConfig.is_hf_job) {
|
| 126 |
+
setTrainingBackend('hf-jobs');
|
| 127 |
+
setHfJobSubmitted(true);
|
| 128 |
+
}
|
| 129 |
+
})
|
| 130 |
+
.catch(error => console.error('Error fetching training:', error));
|
| 131 |
+
}
|
| 132 |
+
}, [runId]);
|
| 133 |
+
|
| 134 |
+
useEffect(() => {
|
| 135 |
+
if (isGPUInfoLoaded) {
|
| 136 |
+
if (gpuIDs === null && gpuList.length > 0) {
|
| 137 |
+
setGpuIDs(`${gpuList[0].index}`);
|
| 138 |
+
}
|
| 139 |
+
}
|
| 140 |
+
}, [gpuList, isGPUInfoLoaded]);
|
| 141 |
+
|
| 142 |
+
useEffect(() => {
|
| 143 |
+
if (isSettingsLoaded) {
|
| 144 |
+
setJobConfig(settings.TRAINING_FOLDER, 'config.process[0].training_folder');
|
| 145 |
+
}
|
| 146 |
+
}, [settings, isSettingsLoaded]);
|
| 147 |
+
|
| 148 |
+
const saveJob = async () => {
|
| 149 |
+
if (!isAuthenticated) return;
|
| 150 |
+
if (status === 'saving') return;
|
| 151 |
+
setStatus('saving');
|
| 152 |
+
|
| 153 |
+
try {
|
| 154 |
+
const savedJob = await upsertJob({
|
| 155 |
+
id: runId || undefined,
|
| 156 |
+
name: jobConfig.config.name,
|
| 157 |
+
gpu_ids: gpuIDs,
|
| 158 |
+
job_config: {
|
| 159 |
+
...jobConfig,
|
| 160 |
+
is_hf_job: trainingBackend === 'hf-jobs',
|
| 161 |
+
hf_job_submitted: hfJobSubmitted,
|
| 162 |
+
training_backend: trainingBackend,
|
| 163 |
+
},
|
| 164 |
+
status: trainingBackend === 'hf-jobs' ? (hfJobSubmitted ? 'submitted' : 'stopped') : undefined,
|
| 165 |
+
});
|
| 166 |
+
|
| 167 |
+
setStatus('success');
|
| 168 |
+
router.push(`/jobs/${savedJob.id}`);
|
| 169 |
+
} catch (error: any) {
|
| 170 |
+
console.log('Error saving training:', error);
|
| 171 |
+
if (error?.code === 'P2002') {
|
| 172 |
+
alert('Training name already exists. Please choose a different name.');
|
| 173 |
+
} else {
|
| 174 |
+
alert('Failed to save job. Please try again.');
|
| 175 |
+
}
|
| 176 |
+
} finally {
|
| 177 |
+
setTimeout(() => {
|
| 178 |
+
setStatus('idle');
|
| 179 |
+
}, 2000);
|
| 180 |
+
}
|
| 181 |
+
};
|
| 182 |
+
|
| 183 |
+
const handleSubmit = async (e: React.FormEvent) => {
|
| 184 |
+
e.preventDefault();
|
| 185 |
+
saveJob();
|
| 186 |
+
};
|
| 187 |
+
|
| 188 |
+
return (
|
| 189 |
+
<>
|
| 190 |
+
<TopBar>
|
| 191 |
+
<div>
|
| 192 |
+
<Button className="text-gray-500 dark:text-gray-300 px-3 mt-1" onClick={() => history.back()}>
|
| 193 |
+
<FaChevronLeft />
|
| 194 |
+
</Button>
|
| 195 |
+
</div>
|
| 196 |
+
<div>
|
| 197 |
+
<h1 className="text-lg">{runId ? 'Edit Training Job' : 'New Training Job'}</h1>
|
| 198 |
+
</div>
|
| 199 |
+
<div className="flex-1"></div>
|
| 200 |
+
{showAdvancedView && isAuthenticated && (
|
| 201 |
+
<>
|
| 202 |
+
<div>
|
| 203 |
+
<SelectInput
|
| 204 |
+
value={`${gpuIDs}`}
|
| 205 |
+
onChange={value => setGpuIDs(value)}
|
| 206 |
+
options={gpuList.map((gpu: any) => ({ value: `${gpu.index}`, label: `GPU #${gpu.index}` }))}
|
| 207 |
+
/>
|
| 208 |
+
</div>
|
| 209 |
+
<div className="mx-4 bg-gray-200 dark:bg-gray-800 w-1 h-6"></div>
|
| 210 |
+
</>
|
| 211 |
+
)}
|
| 212 |
+
|
| 213 |
+
<div className="pr-2">
|
| 214 |
+
<Button
|
| 215 |
+
className="text-gray-200 bg-gray-800 px-3 py-1 rounded-md"
|
| 216 |
+
onClick={() => setShowAdvancedView(!showAdvancedView)}
|
| 217 |
+
>
|
| 218 |
+
{showAdvancedView ? 'Show Simple' : 'Show Advanced'}
|
| 219 |
+
</Button>
|
| 220 |
+
</div>
|
| 221 |
+
<div>
|
| 222 |
+
<Button
|
| 223 |
+
className="text-gray-200 bg-green-800 hover:bg-green-700 px-3 py-1 rounded-md"
|
| 224 |
+
onClick={() => saveJob()}
|
| 225 |
+
disabled={!isAuthenticated || status === 'saving'}
|
| 226 |
+
>
|
| 227 |
+
{status === 'saving'
|
| 228 |
+
? 'Saving...'
|
| 229 |
+
: runId
|
| 230 |
+
? 'Update Job'
|
| 231 |
+
: 'Create Job'}
|
| 232 |
+
</Button>
|
| 233 |
+
</div>
|
| 234 |
+
</TopBar>
|
| 235 |
+
|
| 236 |
+
{!isAuthenticated ? (
|
| 237 |
+
<MainContent>
|
| 238 |
+
<div className="border border-gray-800 rounded-lg p-6 bg-gray-900 text-gray-400 text-sm flex flex-col gap-4">
|
| 239 |
+
<p>You need to sign in with Hugging Face or provide a valid access token before creating or editing jobs.</p>
|
| 240 |
+
<div className="flex items-center gap-3">
|
| 241 |
+
<HFLoginButton size="sm" />
|
| 242 |
+
<Link href="/settings" className="text-xs text-blue-400 hover:text-blue-300">
|
| 243 |
+
Manage authentication in Settings
|
| 244 |
+
</Link>
|
| 245 |
+
</div>
|
| 246 |
+
</div>
|
| 247 |
+
</MainContent>
|
| 248 |
+
) : showAdvancedView ? (
|
| 249 |
+
<div className="pt-[48px] absolute top-0 left-0 w-full h-full overflow-auto">
|
| 250 |
+
<AdvancedJob
|
| 251 |
+
jobConfig={jobConfig}
|
| 252 |
+
setJobConfig={setJobConfig}
|
| 253 |
+
status={status}
|
| 254 |
+
handleSubmit={handleSubmit}
|
| 255 |
+
runId={runId}
|
| 256 |
+
gpuIDs={gpuIDs}
|
| 257 |
+
setGpuIDs={setGpuIDs}
|
| 258 |
+
gpuList={gpuList}
|
| 259 |
+
datasetOptions={datasetOptions}
|
| 260 |
+
settings={settings}
|
| 261 |
+
/>
|
| 262 |
+
</div>
|
| 263 |
+
) : (
|
| 264 |
+
<MainContent>
|
| 265 |
+
<ErrorBoundary
|
| 266 |
+
fallback={
|
| 267 |
+
<div className="flex items-center justify-center h-64 text-lg text-red-600 font-medium bg-red-100 dark:bg-red-900/20 dark:text-red-400 border border-red-300 dark:border-red-700 rounded-lg">
|
| 268 |
+
Advanced job detected. Please switch to advanced view to continue.
|
| 269 |
+
</div>
|
| 270 |
+
}
|
| 271 |
+
>
|
| 272 |
+
<SimpleJob
|
| 273 |
+
jobConfig={jobConfig}
|
| 274 |
+
setJobConfig={setJobConfig}
|
| 275 |
+
status={status}
|
| 276 |
+
handleSubmit={handleSubmit}
|
| 277 |
+
runId={runId}
|
| 278 |
+
gpuIDs={gpuIDs}
|
| 279 |
+
setGpuIDs={setGpuIDs}
|
| 280 |
+
gpuList={gpuList}
|
| 281 |
+
datasetOptions={datasetOptions}
|
| 282 |
+
trainingBackend={trainingBackend}
|
| 283 |
+
setTrainingBackend={usingBrowserDb ? undefined : setTrainingBackend}
|
| 284 |
+
hfJobSubmitted={hfJobSubmitted}
|
| 285 |
+
onHFJobComplete={(jobId: string, localJobId?: string) => {
|
| 286 |
+
setHfJobSubmitted(true);
|
| 287 |
+
// Redirect to the job detail page
|
| 288 |
+
if (localJobId) {
|
| 289 |
+
router.push(`/jobs/${localJobId}`);
|
| 290 |
+
}
|
| 291 |
+
}}
|
| 292 |
+
forceHFBackend={usingBrowserDb}
|
| 293 |
+
/>
|
| 294 |
+
</ErrorBoundary>
|
| 295 |
+
|
| 296 |
+
<div className="pt-20"></div>
|
| 297 |
+
</MainContent>
|
| 298 |
+
)}
|
| 299 |
+
</>
|
| 300 |
+
);
|
| 301 |
+
}
|
| 302 |
+
useEffect(() => {
|
| 303 |
+
if (!isAuthenticated) {
|
| 304 |
+
setDatasetOptions([]);
|
| 305 |
+
}
|
| 306 |
+
}, [isAuthenticated]);
|
src/app/jobs/page.tsx
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
'use client';
|
| 2 |
+
|
| 3 |
+
import JobsTable from '@/components/JobsTable';
|
| 4 |
+
import { TopBar, MainContent } from '@/components/layout';
|
| 5 |
+
import Link from 'next/link';
|
| 6 |
+
import { useAuth } from '@/contexts/AuthContext';
|
| 7 |
+
import HFLoginButton from '@/components/HFLoginButton';
|
| 8 |
+
|
| 9 |
+
export default function Dashboard() {
|
| 10 |
+
const { status: authStatus } = useAuth();
|
| 11 |
+
const isAuthenticated = authStatus === 'authenticated';
|
| 12 |
+
|
| 13 |
+
return (
|
| 14 |
+
<>
|
| 15 |
+
<TopBar>
|
| 16 |
+
<div>
|
| 17 |
+
<h1 className="text-lg">Training Jobs</h1>
|
| 18 |
+
</div>
|
| 19 |
+
<div className="flex-1"></div>
|
| 20 |
+
<div>
|
| 21 |
+
{isAuthenticated ? (
|
| 22 |
+
<Link href="/jobs/new" className="text-gray-200 bg-slate-600 px-3 py-1 rounded-md">
|
| 23 |
+
New Training Job
|
| 24 |
+
</Link>
|
| 25 |
+
) : (
|
| 26 |
+
<span className="text-gray-600 bg-gray-900 px-3 py-1 rounded-md border border-gray-800">
|
| 27 |
+
Sign in to create jobs
|
| 28 |
+
</span>
|
| 29 |
+
)}
|
| 30 |
+
</div>
|
| 31 |
+
</TopBar>
|
| 32 |
+
<MainContent>
|
| 33 |
+
{isAuthenticated ? (
|
| 34 |
+
<JobsTable />
|
| 35 |
+
) : (
|
| 36 |
+
<div className="border border-gray-800 rounded-lg p-6 bg-gray-900 text-gray-400 text-sm flex flex-col gap-4">
|
| 37 |
+
<p>Sign in with Hugging Face or add a personal access token to view and manage training jobs.</p>
|
| 38 |
+
<div className="flex items-center gap-3">
|
| 39 |
+
<HFLoginButton size="sm" />
|
| 40 |
+
<Link href="/settings" className="text-xs text-blue-400 hover:text-blue-300">
|
| 41 |
+
Manage tokens in Settings
|
| 42 |
+
</Link>
|
| 43 |
+
</div>
|
| 44 |
+
</div>
|
| 45 |
+
)}
|
| 46 |
+
</MainContent>
|
| 47 |
+
</>
|
| 48 |
+
);
|
| 49 |
+
}
|
src/app/layout.tsx
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import type { Metadata } from 'next';
|
| 2 |
+
import { Inter } from 'next/font/google';
|
| 3 |
+
import './globals.css';
|
| 4 |
+
import Sidebar from '@/components/Sidebar';
|
| 5 |
+
import { ThemeProvider } from '@/components/ThemeProvider';
|
| 6 |
+
import ConfirmModal from '@/components/ConfirmModal';
|
| 7 |
+
import SampleImageModal from '@/components/SampleImageModal';
|
| 8 |
+
import { Suspense } from 'react';
|
| 9 |
+
import AuthWrapper from '@/components/AuthWrapper';
|
| 10 |
+
import DocModal from '@/components/DocModal';
|
| 11 |
+
import { AuthProvider } from '@/contexts/AuthContext';
|
| 12 |
+
|
| 13 |
+
export const dynamic = 'force-dynamic';
|
| 14 |
+
|
| 15 |
+
const inter = Inter({ subsets: ['latin'] });
|
| 16 |
+
|
| 17 |
+
export const metadata: Metadata = {
|
| 18 |
+
title: 'Ostris - AI Toolkit',
|
| 19 |
+
description: 'A toolkit for building AI things.',
|
| 20 |
+
};
|
| 21 |
+
|
| 22 |
+
export default function RootLayout({ children }: { children: React.ReactNode }) {
|
| 23 |
+
// Check if the AI_TOOLKIT_AUTH environment variable is set
|
| 24 |
+
const authRequired = process.env.AI_TOOLKIT_AUTH ? true : false;
|
| 25 |
+
|
| 26 |
+
return (
|
| 27 |
+
<html lang="en" className="dark">
|
| 28 |
+
<head>
|
| 29 |
+
<meta name="apple-mobile-web-app-title" content="AI-Toolkit" />
|
| 30 |
+
</head>
|
| 31 |
+
<body className={inter.className} suppressHydrationWarning={true}>
|
| 32 |
+
<ThemeProvider>
|
| 33 |
+
<AuthProvider>
|
| 34 |
+
<AuthWrapper authRequired={authRequired}>
|
| 35 |
+
<div className="flex h-screen bg-gray-950">
|
| 36 |
+
<Sidebar />
|
| 37 |
+
<main className="flex-1 overflow-auto bg-gray-950 text-gray-100 relative">
|
| 38 |
+
<Suspense>{children}</Suspense>
|
| 39 |
+
</main>
|
| 40 |
+
</div>
|
| 41 |
+
</AuthWrapper>
|
| 42 |
+
</AuthProvider>
|
| 43 |
+
</ThemeProvider>
|
| 44 |
+
<ConfirmModal />
|
| 45 |
+
<DocModal />
|
| 46 |
+
<SampleImageModal />
|
| 47 |
+
</body>
|
| 48 |
+
</html>
|
| 49 |
+
);
|
| 50 |
+
}
|
src/app/manifest.json
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"name": "AI Toolkit",
|
| 3 |
+
"short_name": "AIToolkit",
|
| 4 |
+
"icons": [
|
| 5 |
+
{
|
| 6 |
+
"src": "/web-app-manifest-192x192.png",
|
| 7 |
+
"sizes": "192x192",
|
| 8 |
+
"type": "image/png",
|
| 9 |
+
"purpose": "maskable"
|
| 10 |
+
},
|
| 11 |
+
{
|
| 12 |
+
"src": "/web-app-manifest-512x512.png",
|
| 13 |
+
"sizes": "512x512",
|
| 14 |
+
"type": "image/png",
|
| 15 |
+
"purpose": "maskable"
|
| 16 |
+
}
|
| 17 |
+
],
|
| 18 |
+
"theme_color": "#000000",
|
| 19 |
+
"background_color": "#000000",
|
| 20 |
+
"display": "standalone"
|
| 21 |
+
}
|
src/app/page.tsx
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import { redirect } from 'next/navigation';
|
| 2 |
+
|
| 3 |
+
export default function Home() {
|
| 4 |
+
redirect('/dashboard');
|
| 5 |
+
}
|
src/app/settings/page.tsx
ADDED
|
@@ -0,0 +1,264 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
'use client';
|
| 2 |
+
|
| 3 |
+
import { useEffect, useState } from 'react';
|
| 4 |
+
import useSettings from '@/hooks/useSettings';
|
| 5 |
+
import { TopBar, MainContent } from '@/components/layout';
|
| 6 |
+
import { persistSettings } from '@/utils/storage/settingsStorage';
|
| 7 |
+
import { useAuth } from '@/contexts/AuthContext';
|
| 8 |
+
import HFLoginButton from '@/components/HFLoginButton';
|
| 9 |
+
import { useMemo } from 'react';
|
| 10 |
+
import Link from 'next/link';
|
| 11 |
+
|
| 12 |
+
export default function Settings() {
|
| 13 |
+
const { settings, setSettings } = useSettings();
|
| 14 |
+
const { status: authStatus, namespace, oauthAvailable, loginWithOAuth, logout, setManualToken, error: authError, token: authToken } = useAuth();
|
| 15 |
+
const [status, setStatus] = useState<'idle' | 'saving' | 'success' | 'error'>('idle');
|
| 16 |
+
const [manualToken, setManualTokenInput] = useState(settings.HF_TOKEN || '');
|
| 17 |
+
const isAuthenticated = authStatus === 'authenticated';
|
| 18 |
+
|
| 19 |
+
useEffect(() => {
|
| 20 |
+
setManualTokenInput(settings.HF_TOKEN || '');
|
| 21 |
+
}, [settings.HF_TOKEN]);
|
| 22 |
+
|
| 23 |
+
const handleSubmit = async (e: React.FormEvent) => {
|
| 24 |
+
e.preventDefault();
|
| 25 |
+
setStatus('saving');
|
| 26 |
+
|
| 27 |
+
persistSettings(settings)
|
| 28 |
+
.then(() => {
|
| 29 |
+
setStatus('success');
|
| 30 |
+
})
|
| 31 |
+
.catch(error => {
|
| 32 |
+
console.error('Error saving settings:', error);
|
| 33 |
+
setStatus('error');
|
| 34 |
+
})
|
| 35 |
+
.finally(() => {
|
| 36 |
+
setTimeout(() => setStatus('idle'), 2000);
|
| 37 |
+
});
|
| 38 |
+
};
|
| 39 |
+
|
| 40 |
+
const handleChange = (e: React.ChangeEvent<HTMLInputElement>) => {
|
| 41 |
+
const { name, value } = e.target;
|
| 42 |
+
setSettings(prev => ({ ...prev, [name]: value }));
|
| 43 |
+
};
|
| 44 |
+
|
| 45 |
+
const handleManualSubmit = async (e: React.FormEvent) => {
|
| 46 |
+
e.preventDefault();
|
| 47 |
+
await setManualToken(manualToken);
|
| 48 |
+
};
|
| 49 |
+
|
| 50 |
+
const authDescription = useMemo(() => {
|
| 51 |
+
if (authStatus === 'checking') {
|
| 52 |
+
return 'Checking your Hugging Face session…';
|
| 53 |
+
}
|
| 54 |
+
if (isAuthenticated) {
|
| 55 |
+
return `Connected as ${namespace}`;
|
| 56 |
+
}
|
| 57 |
+
return 'Sign in to use Hugging Face Jobs or submit your own access token.';
|
| 58 |
+
}, [authStatus, isAuthenticated, namespace]);
|
| 59 |
+
|
| 60 |
+
return (
|
| 61 |
+
<>
|
| 62 |
+
<TopBar>
|
| 63 |
+
<div>
|
| 64 |
+
<h1 className="text-lg">Settings</h1>
|
| 65 |
+
</div>
|
| 66 |
+
<div className="flex-1"></div>
|
| 67 |
+
<div className="flex items-center gap-3 pr-2 text-sm text-gray-400">
|
| 68 |
+
{isAuthenticated ? (
|
| 69 |
+
<span>Welcome, {namespace || 'user'}</span>
|
| 70 |
+
) : (
|
| 71 |
+
<span>Authenticate to unlock training features</span>
|
| 72 |
+
)}
|
| 73 |
+
</div>
|
| 74 |
+
</TopBar>
|
| 75 |
+
<MainContent>
|
| 76 |
+
<div className="grid gap-4 md:grid-cols-2 mb-6">
|
| 77 |
+
<div className="border border-gray-800 rounded-xl p-5 bg-gray-900">
|
| 78 |
+
<div className="flex items-center justify-between mb-4">
|
| 79 |
+
<div>
|
| 80 |
+
<h2 className="text-md font-semibold text-gray-100">Sign in with Hugging Face</h2>
|
| 81 |
+
<p className="text-sm text-gray-400 mt-1">{authDescription}</p>
|
| 82 |
+
</div>
|
| 83 |
+
{isAuthenticated && (
|
| 84 |
+
<span className="text-xs px-2 py-1 rounded-full bg-emerald-900 text-emerald-300">Authenticated</span>
|
| 85 |
+
)}
|
| 86 |
+
</div>
|
| 87 |
+
<div className="flex items-center gap-3">
|
| 88 |
+
{isAuthenticated ? (
|
| 89 |
+
<button
|
| 90 |
+
type="button"
|
| 91 |
+
onClick={logout}
|
| 92 |
+
className="px-4 py-2 rounded-md border border-gray-700 text-sm bg-gray-800 hover:bg-gray-700 transition-colors"
|
| 93 |
+
>
|
| 94 |
+
Sign out
|
| 95 |
+
</button>
|
| 96 |
+
) : (
|
| 97 |
+
<>
|
| 98 |
+
<HFLoginButton size="md" className="bg-transparent border-none p-0" />
|
| 99 |
+
{!oauthAvailable && (
|
| 100 |
+
<span className="text-xs text-yellow-500">
|
| 101 |
+
OAuth is unavailable. Set HF_OAUTH_CLIENT_ID/SECRET on the server.
|
| 102 |
+
</span>
|
| 103 |
+
)}
|
| 104 |
+
</>
|
| 105 |
+
)}
|
| 106 |
+
</div>
|
| 107 |
+
{!isAuthenticated && authError && (
|
| 108 |
+
<p className="mt-3 text-xs text-red-400">{authError}</p>
|
| 109 |
+
)}
|
| 110 |
+
</div>
|
| 111 |
+
|
| 112 |
+
<form onSubmit={handleManualSubmit} className="border border-gray-800 rounded-xl p-5 bg-gray-900">
|
| 113 |
+
<h2 className="text-md font-semibold text-gray-100">Manual Token</h2>
|
| 114 |
+
<p className="text-sm text-gray-400 mt-1">
|
| 115 |
+
Paste an access token created at{' '}
|
| 116 |
+
<a href="https://huggingface.co/settings/tokens" target="_blank" rel="noreferrer" className="text-blue-400 hover:text-blue-300">
|
| 117 |
+
huggingface.co/settings/tokens
|
| 118 |
+
</a>
|
| 119 |
+
.
|
| 120 |
+
</p>
|
| 121 |
+
<div className="mt-4">
|
| 122 |
+
<input
|
| 123 |
+
type="password"
|
| 124 |
+
value={manualToken}
|
| 125 |
+
onChange={event => setManualTokenInput(event.target.value)}
|
| 126 |
+
className="w-full px-4 py-2 bg-gray-800 border border-gray-700 rounded-lg focus:ring-2 focus:ring-gray-600 focus:border-transparent"
|
| 127 |
+
placeholder="Enter Hugging Face token"
|
| 128 |
+
/>
|
| 129 |
+
</div>
|
| 130 |
+
<div className="mt-4 flex items-center gap-3">
|
| 131 |
+
<button
|
| 132 |
+
type="submit"
|
| 133 |
+
className="px-4 py-2 rounded-md bg-blue-600 hover:bg-blue-500 text-sm text-white transition-colors disabled:opacity-50 disabled:cursor-not-allowed"
|
| 134 |
+
disabled={authStatus === 'checking' || manualToken.trim() === ''}
|
| 135 |
+
>
|
| 136 |
+
Validate Token
|
| 137 |
+
</button>
|
| 138 |
+
{isAuthenticated && authToken === manualToken && (
|
| 139 |
+
<span className="text-xs text-emerald-400">Active token</span>
|
| 140 |
+
)}
|
| 141 |
+
</div>
|
| 142 |
+
{authError && (
|
| 143 |
+
<p className="mt-3 text-xs text-red-400">{authError}</p>
|
| 144 |
+
)}
|
| 145 |
+
</form>
|
| 146 |
+
</div>
|
| 147 |
+
|
| 148 |
+
<form onSubmit={handleSubmit} className="space-y-6">
|
| 149 |
+
<div className="grid grid-cols-1 gap-6 sm:grid-cols-2">
|
| 150 |
+
<div>
|
| 151 |
+
<div className="space-y-4">
|
| 152 |
+
<div>
|
| 153 |
+
<label htmlFor="TRAINING_FOLDER" className="block text-sm font-medium mb-2">
|
| 154 |
+
Training Folder Path
|
| 155 |
+
<div className="text-gray-500 text-sm ml-1">
|
| 156 |
+
We will store your training information here. Must be an absolute path. If blank, it will default
|
| 157 |
+
to the output folder in the project root.
|
| 158 |
+
</div>
|
| 159 |
+
</label>
|
| 160 |
+
<input
|
| 161 |
+
type="text"
|
| 162 |
+
id="TRAINING_FOLDER"
|
| 163 |
+
name="TRAINING_FOLDER"
|
| 164 |
+
value={settings.TRAINING_FOLDER}
|
| 165 |
+
onChange={handleChange}
|
| 166 |
+
className="w-full px-4 py-2 bg-gray-800 border border-gray-700 rounded-lg focus:ring-2 focus:ring-gray-600 focus:border-transparent"
|
| 167 |
+
placeholder="Enter training folder path"
|
| 168 |
+
/>
|
| 169 |
+
</div>
|
| 170 |
+
|
| 171 |
+
<div>
|
| 172 |
+
<label htmlFor="DATASETS_FOLDER" className="block text-sm font-medium mb-2">
|
| 173 |
+
Dataset Folder Path
|
| 174 |
+
<div className="text-gray-500 text-sm ml-1">
|
| 175 |
+
Where we store and find your datasets.{' '}
|
| 176 |
+
<span className="text-orange-800">
|
| 177 |
+
Warning: This software may modify datasets so it is recommended you keep a backup somewhere else
|
| 178 |
+
or have a dedicated folder for this software.
|
| 179 |
+
</span>
|
| 180 |
+
</div>
|
| 181 |
+
</label>
|
| 182 |
+
<input
|
| 183 |
+
type="text"
|
| 184 |
+
id="DATASETS_FOLDER"
|
| 185 |
+
name="DATASETS_FOLDER"
|
| 186 |
+
value={settings.DATASETS_FOLDER}
|
| 187 |
+
onChange={handleChange}
|
| 188 |
+
className="w-full px-4 py-2 bg-gray-800 border border-gray-700 rounded-lg focus:ring-2 focus:ring-gray-600 focus:border-transparent"
|
| 189 |
+
placeholder="Enter datasets folder path"
|
| 190 |
+
/>
|
| 191 |
+
</div>
|
| 192 |
+
</div>
|
| 193 |
+
</div>
|
| 194 |
+
<div>
|
| 195 |
+
<div className="space-y-4">
|
| 196 |
+
<h3 className="text-lg font-medium mb-4">Hugging Face Jobs (Cloud Training)</h3>
|
| 197 |
+
|
| 198 |
+
<div>
|
| 199 |
+
<label htmlFor="HF_JOBS_NAMESPACE" className="block text-sm font-medium mb-2">
|
| 200 |
+
HF Jobs Namespace (optional)
|
| 201 |
+
<div className="text-gray-500 text-sm ml-1">
|
| 202 |
+
Leave blank to default to the account associated with your Hugging Face token.
|
| 203 |
+
</div>
|
| 204 |
+
</label>
|
| 205 |
+
<input
|
| 206 |
+
type="text"
|
| 207 |
+
id="HF_JOBS_NAMESPACE"
|
| 208 |
+
name="HF_JOBS_NAMESPACE"
|
| 209 |
+
value={settings.HF_JOBS_NAMESPACE}
|
| 210 |
+
onChange={handleChange}
|
| 211 |
+
className="w-full px-4 py-2 bg-gray-800 border border-gray-700 rounded-lg focus:ring-2 focus:ring-gray-600 focus:border-transparent"
|
| 212 |
+
placeholder="e.g. your-username or your-org"
|
| 213 |
+
/>
|
| 214 |
+
</div>
|
| 215 |
+
|
| 216 |
+
<div>
|
| 217 |
+
<label htmlFor="HF_JOBS_DEFAULT_HARDWARE" className="block text-sm font-medium mb-2">
|
| 218 |
+
Default Hardware
|
| 219 |
+
<div className="text-gray-500 text-sm ml-1">
|
| 220 |
+
Default hardware configuration for cloud training jobs.
|
| 221 |
+
</div>
|
| 222 |
+
</label>
|
| 223 |
+
<select
|
| 224 |
+
id="HF_JOBS_DEFAULT_HARDWARE"
|
| 225 |
+
name="HF_JOBS_DEFAULT_HARDWARE"
|
| 226 |
+
value={settings.HF_JOBS_DEFAULT_HARDWARE}
|
| 227 |
+
onChange={(e) => setSettings(prev => ({ ...prev, HF_JOBS_DEFAULT_HARDWARE: e.target.value }))}
|
| 228 |
+
className="w-full px-4 py-2 bg-gray-800 border border-gray-700 rounded-lg focus:ring-2 focus:ring-gray-600 focus:border-transparent"
|
| 229 |
+
>
|
| 230 |
+
<option value="cpu-basic">CPU Basic</option>
|
| 231 |
+
<option value="cpu-upgrade">CPU Upgrade</option>
|
| 232 |
+
<option value="t4-small">T4 Small</option>
|
| 233 |
+
<option value="t4-medium">T4 Medium</option>
|
| 234 |
+
<option value="l4x1">L4x1</option>
|
| 235 |
+
<option value="l4x4">L4x4</option>
|
| 236 |
+
<option value="a10g-small">A10G Small</option>
|
| 237 |
+
<option value="a10g-large">A10G Large</option>
|
| 238 |
+
<option value="a10g-largex2">A10G Large x2</option>
|
| 239 |
+
<option value="a10g-largex4">A10G Large x4</option>
|
| 240 |
+
<option value="a100-large">A100 Large</option>
|
| 241 |
+
<option value="v5e-1x1">TPU v5e-1x1</option>
|
| 242 |
+
<option value="v5e-2x2">TPU v5e-2x2</option>
|
| 243 |
+
<option value="v5e-2x4">TPU v5e-2x4</option>
|
| 244 |
+
</select>
|
| 245 |
+
</div>
|
| 246 |
+
</div>
|
| 247 |
+
</div>
|
| 248 |
+
</div>
|
| 249 |
+
|
| 250 |
+
<button
|
| 251 |
+
type="submit"
|
| 252 |
+
disabled={status === 'saving'}
|
| 253 |
+
className="w-full px-4 py-2 bg-gray-700 hover:bg-gray-600 rounded-lg transition-colors disabled:opacity-50 disabled:cursor-not-allowed"
|
| 254 |
+
>
|
| 255 |
+
{status === 'saving' ? 'Saving...' : 'Save Settings'}
|
| 256 |
+
</button>
|
| 257 |
+
|
| 258 |
+
{status === 'success' && <p className="text-green-500 text-center">Settings saved successfully!</p>}
|
| 259 |
+
{status === 'error' && <p className="text-red-500 text-center">Error saving settings. Please try again.</p>}
|
| 260 |
+
</form>
|
| 261 |
+
</MainContent>
|
| 262 |
+
</>
|
| 263 |
+
);
|
| 264 |
+
}
|
src/components/AddImagesModal.tsx
ADDED
|
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
'use client';
|
| 2 |
+
import { createGlobalState } from 'react-global-hooks';
|
| 3 |
+
import { Dialog, DialogBackdrop, DialogPanel, DialogTitle } from '@headlessui/react';
|
| 4 |
+
import { FaUpload } from 'react-icons/fa';
|
| 5 |
+
import { useCallback, useState } from 'react';
|
| 6 |
+
import { useDropzone } from 'react-dropzone';
|
| 7 |
+
import { apiClient } from '@/utils/api';
|
| 8 |
+
|
| 9 |
+
export interface AddImagesModalState {
|
| 10 |
+
datasetName: string;
|
| 11 |
+
onComplete?: () => void;
|
| 12 |
+
}
|
| 13 |
+
|
| 14 |
+
export const addImagesModalState = createGlobalState<AddImagesModalState | null>(null);
|
| 15 |
+
|
| 16 |
+
export const openImagesModal = (datasetName: string, onComplete: () => void) => {
|
| 17 |
+
addImagesModalState.set({ datasetName, onComplete });
|
| 18 |
+
};
|
| 19 |
+
|
| 20 |
+
export default function AddImagesModal() {
|
| 21 |
+
const [addImagesModalInfo, setAddImagesModalInfo] = addImagesModalState.use();
|
| 22 |
+
const [uploadProgress, setUploadProgress] = useState<number>(0);
|
| 23 |
+
const [isUploading, setIsUploading] = useState<boolean>(false);
|
| 24 |
+
const open = addImagesModalInfo !== null;
|
| 25 |
+
|
| 26 |
+
const onCancel = () => {
|
| 27 |
+
if (!isUploading) {
|
| 28 |
+
setAddImagesModalInfo(null);
|
| 29 |
+
}
|
| 30 |
+
};
|
| 31 |
+
|
| 32 |
+
const onDone = () => {
|
| 33 |
+
if (addImagesModalInfo?.onComplete && !isUploading) {
|
| 34 |
+
addImagesModalInfo.onComplete();
|
| 35 |
+
setAddImagesModalInfo(null);
|
| 36 |
+
}
|
| 37 |
+
};
|
| 38 |
+
|
| 39 |
+
const onDrop = useCallback(
|
| 40 |
+
async (acceptedFiles: File[]) => {
|
| 41 |
+
if (acceptedFiles.length === 0) return;
|
| 42 |
+
|
| 43 |
+
setIsUploading(true);
|
| 44 |
+
setUploadProgress(0);
|
| 45 |
+
|
| 46 |
+
const formData = new FormData();
|
| 47 |
+
acceptedFiles.forEach(file => {
|
| 48 |
+
formData.append('files', file);
|
| 49 |
+
});
|
| 50 |
+
formData.append('datasetName', addImagesModalInfo?.datasetName || '');
|
| 51 |
+
|
| 52 |
+
try {
|
| 53 |
+
await apiClient.post(`/api/datasets/upload`, formData, {
|
| 54 |
+
headers: {
|
| 55 |
+
'Content-Type': 'multipart/form-data',
|
| 56 |
+
},
|
| 57 |
+
onUploadProgress: progressEvent => {
|
| 58 |
+
const percentCompleted = Math.round((progressEvent.loaded * 100) / (progressEvent.total || 100));
|
| 59 |
+
setUploadProgress(percentCompleted);
|
| 60 |
+
},
|
| 61 |
+
timeout: 0, // Disable timeout
|
| 62 |
+
});
|
| 63 |
+
|
| 64 |
+
onDone();
|
| 65 |
+
} catch (error) {
|
| 66 |
+
console.error('Upload failed:', error);
|
| 67 |
+
} finally {
|
| 68 |
+
setIsUploading(false);
|
| 69 |
+
setUploadProgress(0);
|
| 70 |
+
}
|
| 71 |
+
},
|
| 72 |
+
[addImagesModalInfo],
|
| 73 |
+
);
|
| 74 |
+
|
| 75 |
+
const { getRootProps, getInputProps, isDragActive } = useDropzone({
|
| 76 |
+
onDrop,
|
| 77 |
+
accept: {
|
| 78 |
+
'image/*': ['.png', '.jpg', '.jpeg', '.gif', '.bmp', '.webp'],
|
| 79 |
+
'video/*': ['.mp4', '.avi', '.mov', '.mkv', '.wmv', '.m4v', '.flv'],
|
| 80 |
+
'text/*': ['.txt'],
|
| 81 |
+
},
|
| 82 |
+
multiple: true,
|
| 83 |
+
});
|
| 84 |
+
|
| 85 |
+
return (
|
| 86 |
+
<Dialog open={open} onClose={onCancel} className="relative z-10">
|
| 87 |
+
<DialogBackdrop
|
| 88 |
+
transition
|
| 89 |
+
className="fixed inset-0 bg-gray-900/75 transition-opacity data-closed:opacity-0 data-enter:duration-300 data-enter:ease-out data-leave:duration-200 data-leave:ease-in"
|
| 90 |
+
/>
|
| 91 |
+
|
| 92 |
+
<div className="fixed inset-0 z-10 w-screen overflow-y-auto">
|
| 93 |
+
<div className="flex min-h-full items-end justify-center p-4 text-center sm:items-center sm:p-0">
|
| 94 |
+
<DialogPanel
|
| 95 |
+
transition
|
| 96 |
+
className="relative transform overflow-hidden rounded-lg bg-gray-800 text-left shadow-xl transition-all data-closed:translate-y-4 data-closed:opacity-0 data-enter:duration-300 data-enter:ease-out data-leave:duration-200 data-leave:ease-in sm:my-8 sm:w-full sm:max-w-lg data-closed:sm:translate-y-0 data-closed:sm:scale-95"
|
| 97 |
+
>
|
| 98 |
+
<div className="bg-gray-800 px-4 pt-5 pb-4 sm:p-6 sm:pb-4">
|
| 99 |
+
<div className="text-center">
|
| 100 |
+
<DialogTitle as="h3" className="text-base font-semibold text-gray-200 mb-4">
|
| 101 |
+
Add Images to: {addImagesModalInfo?.datasetName}
|
| 102 |
+
</DialogTitle>
|
| 103 |
+
<div className="w-full">
|
| 104 |
+
<div
|
| 105 |
+
{...getRootProps()}
|
| 106 |
+
className={`h-40 w-full flex flex-col items-center justify-center border-2 border-dashed rounded-lg cursor-pointer transition-colors duration-200
|
| 107 |
+
${isDragActive ? 'border-blue-500 bg-blue-50/10' : 'border-gray-600'}`}
|
| 108 |
+
>
|
| 109 |
+
<input {...getInputProps()} />
|
| 110 |
+
<FaUpload className="size-8 mb-3 text-gray-400" />
|
| 111 |
+
<p className="text-sm text-gray-200 text-center">
|
| 112 |
+
{isDragActive ? 'Drop the files here...' : 'Drag & drop files here, or click to select files'}
|
| 113 |
+
</p>
|
| 114 |
+
</div>
|
| 115 |
+
{isUploading && (
|
| 116 |
+
<div className="mt-4">
|
| 117 |
+
<div className="w-full bg-gray-700 rounded-full h-2.5">
|
| 118 |
+
<div className="bg-blue-600 h-2.5 rounded-full" style={{ width: `${uploadProgress}%` }}></div>
|
| 119 |
+
</div>
|
| 120 |
+
<p className="text-sm text-gray-300 mt-2 text-center">Uploading... {uploadProgress}%</p>
|
| 121 |
+
</div>
|
| 122 |
+
)}
|
| 123 |
+
</div>
|
| 124 |
+
</div>
|
| 125 |
+
</div>
|
| 126 |
+
<div className="bg-gray-700 px-4 py-3 sm:flex sm:flex-row-reverse sm:px-6">
|
| 127 |
+
<button
|
| 128 |
+
type="button"
|
| 129 |
+
onClick={onDone}
|
| 130 |
+
disabled={isUploading}
|
| 131 |
+
className={`inline-flex w-full justify-center rounded-md bg-slate-600 px-3 py-2 text-sm font-semibold text-white shadow-xs sm:ml-3 sm:w-auto
|
| 132 |
+
${isUploading ? 'opacity-50 cursor-not-allowed' : ''}`}
|
| 133 |
+
>
|
| 134 |
+
Done
|
| 135 |
+
</button>
|
| 136 |
+
<button
|
| 137 |
+
type="button"
|
| 138 |
+
data-autofocus
|
| 139 |
+
onClick={onCancel}
|
| 140 |
+
disabled={isUploading}
|
| 141 |
+
className={`mt-3 inline-flex w-full justify-center rounded-md bg-gray-800 px-3 py-2 text-sm font-semibold text-gray-200 hover:bg-gray-800 sm:mt-0 sm:w-auto ring-0
|
| 142 |
+
${isUploading ? 'opacity-50 cursor-not-allowed' : ''}`}
|
| 143 |
+
>
|
| 144 |
+
Cancel
|
| 145 |
+
</button>
|
| 146 |
+
</div>
|
| 147 |
+
</DialogPanel>
|
| 148 |
+
</div>
|
| 149 |
+
</div>
|
| 150 |
+
</Dialog>
|
| 151 |
+
);
|
| 152 |
+
}
|
src/components/AddSingleImageModal.tsx
ADDED
|
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
'use client';
|
| 2 |
+
import { createGlobalState } from 'react-global-hooks';
|
| 3 |
+
import { Dialog, DialogBackdrop, DialogPanel, DialogTitle } from '@headlessui/react';
|
| 4 |
+
import { FaUpload } from 'react-icons/fa';
|
| 5 |
+
import { useCallback, useState } from 'react';
|
| 6 |
+
import { useDropzone } from 'react-dropzone';
|
| 7 |
+
import { apiClient } from '@/utils/api';
|
| 8 |
+
|
| 9 |
+
export interface AddSingleImageModalState {
|
| 10 |
+
|
| 11 |
+
onComplete?: (imagePath: string|null) => void;
|
| 12 |
+
}
|
| 13 |
+
|
| 14 |
+
export const addSingleImageModalState = createGlobalState<AddSingleImageModalState | null>(null);
|
| 15 |
+
|
| 16 |
+
export const openAddImageModal = (onComplete: (imagePath: string|null) => void) => {
|
| 17 |
+
addSingleImageModalState.set({onComplete });
|
| 18 |
+
};
|
| 19 |
+
|
| 20 |
+
export default function AddSingleImageModal() {
|
| 21 |
+
const [addSingleImageModalInfo, setAddSingleImageModalInfo] = addSingleImageModalState.use();
|
| 22 |
+
const [uploadProgress, setUploadProgress] = useState<number>(0);
|
| 23 |
+
const [isUploading, setIsUploading] = useState<boolean>(false);
|
| 24 |
+
const open = addSingleImageModalInfo !== null;
|
| 25 |
+
|
| 26 |
+
const onCancel = () => {
|
| 27 |
+
if (!isUploading) {
|
| 28 |
+
setAddSingleImageModalInfo(null);
|
| 29 |
+
}
|
| 30 |
+
};
|
| 31 |
+
|
| 32 |
+
const onDone = (imagePath: string|null) => {
|
| 33 |
+
if (addSingleImageModalInfo?.onComplete && !isUploading) {
|
| 34 |
+
addSingleImageModalInfo.onComplete(imagePath);
|
| 35 |
+
setAddSingleImageModalInfo(null);
|
| 36 |
+
}
|
| 37 |
+
};
|
| 38 |
+
|
| 39 |
+
const onDrop = useCallback(
|
| 40 |
+
async (acceptedFiles: File[]) => {
|
| 41 |
+
if (acceptedFiles.length === 0) return;
|
| 42 |
+
|
| 43 |
+
setIsUploading(true);
|
| 44 |
+
setUploadProgress(0);
|
| 45 |
+
|
| 46 |
+
const formData = new FormData();
|
| 47 |
+
acceptedFiles.forEach(file => {
|
| 48 |
+
formData.append('files', file);
|
| 49 |
+
});
|
| 50 |
+
|
| 51 |
+
try {
|
| 52 |
+
const resp = await apiClient.post(`/api/img/upload`, formData, {
|
| 53 |
+
headers: {
|
| 54 |
+
'Content-Type': 'multipart/form-data',
|
| 55 |
+
},
|
| 56 |
+
onUploadProgress: progressEvent => {
|
| 57 |
+
const percentCompleted = Math.round((progressEvent.loaded * 100) / (progressEvent.total || 100));
|
| 58 |
+
setUploadProgress(percentCompleted);
|
| 59 |
+
},
|
| 60 |
+
timeout: 0, // Disable timeout
|
| 61 |
+
});
|
| 62 |
+
console.log('Upload successful:', resp.data);
|
| 63 |
+
|
| 64 |
+
onDone(resp.data.files[0] || null);
|
| 65 |
+
} catch (error) {
|
| 66 |
+
console.error('Upload failed:', error);
|
| 67 |
+
} finally {
|
| 68 |
+
setIsUploading(false);
|
| 69 |
+
setUploadProgress(0);
|
| 70 |
+
}
|
| 71 |
+
},
|
| 72 |
+
[addSingleImageModalInfo],
|
| 73 |
+
);
|
| 74 |
+
|
| 75 |
+
const { getRootProps, getInputProps, isDragActive } = useDropzone({
|
| 76 |
+
onDrop,
|
| 77 |
+
accept: {
|
| 78 |
+
'image/*': ['.png', '.jpg', '.jpeg', '.gif', '.bmp', '.webp'],
|
| 79 |
+
},
|
| 80 |
+
multiple: false,
|
| 81 |
+
});
|
| 82 |
+
|
| 83 |
+
return (
|
| 84 |
+
<Dialog open={open} onClose={onCancel} className="relative z-10">
|
| 85 |
+
<DialogBackdrop
|
| 86 |
+
transition
|
| 87 |
+
className="fixed inset-0 bg-gray-900/75 transition-opacity data-closed:opacity-0 data-enter:duration-300 data-enter:ease-out data-leave:duration-200 data-leave:ease-in"
|
| 88 |
+
/>
|
| 89 |
+
|
| 90 |
+
<div className="fixed inset-0 z-10 w-screen overflow-y-auto">
|
| 91 |
+
<div className="flex min-h-full items-end justify-center p-4 text-center sm:items-center sm:p-0">
|
| 92 |
+
<DialogPanel
|
| 93 |
+
transition
|
| 94 |
+
className="relative transform overflow-hidden rounded-lg bg-gray-800 text-left shadow-xl transition-all data-closed:translate-y-4 data-closed:opacity-0 data-enter:duration-300 data-enter:ease-out data-leave:duration-200 data-leave:ease-in sm:my-8 sm:w-full sm:max-w-lg data-closed:sm:translate-y-0 data-closed:sm:scale-95"
|
| 95 |
+
>
|
| 96 |
+
<div className="bg-gray-800 px-4 pt-5 pb-4 sm:p-6 sm:pb-4">
|
| 97 |
+
<div className="text-center">
|
| 98 |
+
<DialogTitle as="h3" className="text-base font-semibold text-gray-200 mb-4">
|
| 99 |
+
Add Control Image
|
| 100 |
+
</DialogTitle>
|
| 101 |
+
<div className="w-full">
|
| 102 |
+
<div
|
| 103 |
+
{...getRootProps()}
|
| 104 |
+
className={`h-40 w-full flex flex-col items-center justify-center border-2 border-dashed rounded-lg cursor-pointer transition-colors duration-200
|
| 105 |
+
${isDragActive ? 'border-blue-500 bg-blue-50/10' : 'border-gray-600'}`}
|
| 106 |
+
>
|
| 107 |
+
<input {...getInputProps()} />
|
| 108 |
+
<FaUpload className="size-8 mb-3 text-gray-400" />
|
| 109 |
+
<p className="text-sm text-gray-200 text-center">
|
| 110 |
+
{isDragActive ? 'Drop the image here...' : 'Drag & drop an image here, or click to select one'}
|
| 111 |
+
</p>
|
| 112 |
+
</div>
|
| 113 |
+
{isUploading && (
|
| 114 |
+
<div className="mt-4">
|
| 115 |
+
<div className="w-full bg-gray-700 rounded-full h-2.5">
|
| 116 |
+
<div className="bg-blue-600 h-2.5 rounded-full" style={{ width: `${uploadProgress}%` }}></div>
|
| 117 |
+
</div>
|
| 118 |
+
<p className="text-sm text-gray-300 mt-2 text-center">Uploading... {uploadProgress}%</p>
|
| 119 |
+
</div>
|
| 120 |
+
)}
|
| 121 |
+
</div>
|
| 122 |
+
</div>
|
| 123 |
+
</div>
|
| 124 |
+
<div className="bg-gray-700 px-4 py-3 sm:flex sm:flex-row-reverse sm:px-6">
|
| 125 |
+
<button
|
| 126 |
+
type="button"
|
| 127 |
+
data-autofocus
|
| 128 |
+
onClick={onCancel}
|
| 129 |
+
disabled={isUploading}
|
| 130 |
+
className={`mt-3 inline-flex w-full justify-center rounded-md bg-gray-800 px-3 py-2 text-sm font-semibold text-gray-200 hover:bg-gray-800 sm:mt-0 sm:w-auto ring-0
|
| 131 |
+
${isUploading ? 'opacity-50 cursor-not-allowed' : ''}`}
|
| 132 |
+
>
|
| 133 |
+
Cancel
|
| 134 |
+
</button>
|
| 135 |
+
</div>
|
| 136 |
+
</DialogPanel>
|
| 137 |
+
</div>
|
| 138 |
+
</div>
|
| 139 |
+
</Dialog>
|
| 140 |
+
);
|
| 141 |
+
}
|
src/components/AuthWrapper.tsx
ADDED
|
@@ -0,0 +1,166 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
'use client';
|
| 2 |
+
|
| 3 |
+
import { useState, useEffect, useRef } from 'react';
|
| 4 |
+
import { apiClient, isAuthorizedState } from '@/utils/api';
|
| 5 |
+
import { createGlobalState } from 'react-global-hooks';
|
| 6 |
+
|
| 7 |
+
interface AuthWrapperProps {
|
| 8 |
+
authRequired: boolean;
|
| 9 |
+
children: React.ReactNode | React.ReactNode[];
|
| 10 |
+
}
|
| 11 |
+
|
| 12 |
+
export default function AuthWrapper({ authRequired, children }: AuthWrapperProps) {
|
| 13 |
+
const [token, setToken] = useState('');
|
| 14 |
+
// start with true, and deauth if needed
|
| 15 |
+
const [isAuthorizedGlobal, setIsAuthorized] = isAuthorizedState.use();
|
| 16 |
+
const [isLoading, setIsLoading] = useState(false);
|
| 17 |
+
const [error, setError] = useState('');
|
| 18 |
+
const [isBrowser, setIsBrowser] = useState(false);
|
| 19 |
+
const inputRef = useRef<HTMLInputElement>(null);
|
| 20 |
+
|
| 21 |
+
const isAuthorized = authRequired ? isAuthorizedGlobal : true;
|
| 22 |
+
|
| 23 |
+
// Set isBrowser to true when component mounts
|
| 24 |
+
useEffect(() => {
|
| 25 |
+
setIsBrowser(true);
|
| 26 |
+
// Get token from localStorage only after component has mounted
|
| 27 |
+
const storedToken = localStorage.getItem('AI_TOOLKIT_AUTH') || '';
|
| 28 |
+
setToken(storedToken);
|
| 29 |
+
checkAuth();
|
| 30 |
+
}, []);
|
| 31 |
+
|
| 32 |
+
// auto focus on input when not authorized
|
| 33 |
+
useEffect(() => {
|
| 34 |
+
if (isAuthorized) {
|
| 35 |
+
return;
|
| 36 |
+
}
|
| 37 |
+
setTimeout(() => {
|
| 38 |
+
if (inputRef.current) {
|
| 39 |
+
inputRef.current.focus();
|
| 40 |
+
}
|
| 41 |
+
}, 100);
|
| 42 |
+
}, [isAuthorized]);
|
| 43 |
+
|
| 44 |
+
const checkAuth = async () => {
|
| 45 |
+
// always get current stored token here to avoid state race conditions
|
| 46 |
+
const currentToken = localStorage.getItem('AI_TOOLKIT_AUTH') || '';
|
| 47 |
+
if (!authRequired || isLoading || currentToken === '') {
|
| 48 |
+
return;
|
| 49 |
+
}
|
| 50 |
+
setIsLoading(true);
|
| 51 |
+
setError('');
|
| 52 |
+
try {
|
| 53 |
+
const response = await apiClient.get('/api/auth');
|
| 54 |
+
if (response.data.isAuthenticated) {
|
| 55 |
+
setIsAuthorized(true);
|
| 56 |
+
} else {
|
| 57 |
+
setIsAuthorized(false);
|
| 58 |
+
setError('Invalid token. Please try again.');
|
| 59 |
+
}
|
| 60 |
+
} catch (err) {
|
| 61 |
+
setIsAuthorized(false);
|
| 62 |
+
console.log(err);
|
| 63 |
+
setError('Invalid token. Please try again.');
|
| 64 |
+
}
|
| 65 |
+
setIsLoading(false);
|
| 66 |
+
};
|
| 67 |
+
|
| 68 |
+
const handleSubmit = async (e: React.FormEvent) => {
|
| 69 |
+
e.preventDefault();
|
| 70 |
+
setError('');
|
| 71 |
+
|
| 72 |
+
if (!token.trim()) {
|
| 73 |
+
setError('Please enter your token');
|
| 74 |
+
return;
|
| 75 |
+
}
|
| 76 |
+
|
| 77 |
+
if (isBrowser) {
|
| 78 |
+
localStorage.setItem('AI_TOOLKIT_AUTH', token);
|
| 79 |
+
checkAuth();
|
| 80 |
+
}
|
| 81 |
+
};
|
| 82 |
+
|
| 83 |
+
if (isAuthorized) {
|
| 84 |
+
return <>{children}</>;
|
| 85 |
+
}
|
| 86 |
+
|
| 87 |
+
return (
|
| 88 |
+
<div className="flex min-h-screen bg-gray-900 text-gray-100 absolute top-0 left-0 right-0 bottom-0 scroll-auto">
|
| 89 |
+
{/* Left side - decorative or brand area */}
|
| 90 |
+
<div className="hidden lg:flex lg:w-1/2 bg-gray-800 flex-col justify-center items-center p-12">
|
| 91 |
+
<div className="mb-4">
|
| 92 |
+
{/* Replace with your own logo */}
|
| 93 |
+
<div className="flex items-center justify-center">
|
| 94 |
+
<img src="/ostris_logo.png" alt="Ostris AI Toolkit" className="w-auto h-24 inline" />
|
| 95 |
+
</div>
|
| 96 |
+
</div>
|
| 97 |
+
<h1 className="text-4xl mb-6">AI Toolkit</h1>
|
| 98 |
+
</div>
|
| 99 |
+
|
| 100 |
+
{/* Right side - login form */}
|
| 101 |
+
<div className="w-full lg:w-1/2 flex flex-col justify-center items-center p-8 sm:p-12">
|
| 102 |
+
<div className="w-full max-w-md">
|
| 103 |
+
<div className="lg:hidden flex justify-center mb-4">
|
| 104 |
+
{/* Mobile logo */}
|
| 105 |
+
<div className="flex items-center justify-center">
|
| 106 |
+
<img src="/ostris_logo.png" alt="Ostris AI Toolkit" className="w-auto h-24 inline" />
|
| 107 |
+
</div>
|
| 108 |
+
</div>
|
| 109 |
+
|
| 110 |
+
<h2 className="text-3xl text-center mb-2 lg:hidden">AI Toolkit</h2>
|
| 111 |
+
|
| 112 |
+
<form onSubmit={handleSubmit} className="space-y-6">
|
| 113 |
+
<div>
|
| 114 |
+
<label htmlFor="token" className="block text-sm font-medium text-gray-400 mb-2">
|
| 115 |
+
Password
|
| 116 |
+
</label>
|
| 117 |
+
<input
|
| 118 |
+
id="token"
|
| 119 |
+
name="token"
|
| 120 |
+
type="password"
|
| 121 |
+
autoComplete="off"
|
| 122 |
+
required
|
| 123 |
+
value={token}
|
| 124 |
+
ref={inputRef}
|
| 125 |
+
onChange={e => setToken(e.target.value)}
|
| 126 |
+
className="w-full px-4 py-3 rounded-lg bg-gray-800 border border-gray-700 focus:border-blue-500 focus:ring-2 focus:ring-blue-500 focus:ring-opacity-50 text-gray-100 transition duration-200"
|
| 127 |
+
placeholder="Enter your password"
|
| 128 |
+
/>
|
| 129 |
+
<div className='text-gray-500 text-xs mt-2'>
|
| 130 |
+
The password is set with the environment variable AI_TOOLKIT_AUTH, the default is the super secure secret word "password"
|
| 131 |
+
</div>
|
| 132 |
+
</div>
|
| 133 |
+
|
| 134 |
+
{error && (
|
| 135 |
+
<div className="p-3 bg-red-900/50 border border-red-800 rounded-lg text-red-200 text-sm">{error}</div>
|
| 136 |
+
)}
|
| 137 |
+
|
| 138 |
+
<button
|
| 139 |
+
type="submit"
|
| 140 |
+
disabled={isLoading}
|
| 141 |
+
className="w-full py-3 px-4 bg-blue-600 hover:bg-blue-700 rounded-lg text-white font-medium focus:outline-none focus:ring-2 focus:ring-blue-500 focus:ring-opacity-50 transition duration-200 flex items-center justify-center"
|
| 142 |
+
>
|
| 143 |
+
{isLoading ? (
|
| 144 |
+
<svg
|
| 145 |
+
className="animate-spin h-5 w-5 text-white"
|
| 146 |
+
xmlns="http://www.w3.org/2000/svg"
|
| 147 |
+
fill="none"
|
| 148 |
+
viewBox="0 0 24 24"
|
| 149 |
+
>
|
| 150 |
+
<circle className="opacity-25" cx="12" cy="12" r="10" stroke="currentColor" strokeWidth="4"></circle>
|
| 151 |
+
<path
|
| 152 |
+
className="opacity-75"
|
| 153 |
+
fill="currentColor"
|
| 154 |
+
d="M4 12a8 8 0 018-8V0C5.373 0 0 5.373 0 12h4zm2 5.291A7.962 7.962 0 014 12H0c0 3.042 1.135 5.824 3 7.938l3-2.647z"
|
| 155 |
+
></path>
|
| 156 |
+
</svg>
|
| 157 |
+
) : (
|
| 158 |
+
'Check Password'
|
| 159 |
+
)}
|
| 160 |
+
</button>
|
| 161 |
+
</form>
|
| 162 |
+
</div>
|
| 163 |
+
</div>
|
| 164 |
+
</div>
|
| 165 |
+
);
|
| 166 |
+
}
|
src/components/Card.tsx
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
interface CardProps {
|
| 2 |
+
title?: string;
|
| 3 |
+
children?: React.ReactNode;
|
| 4 |
+
}
|
| 5 |
+
|
| 6 |
+
const Card: React.FC<CardProps> = ({ title, children }) => {
|
| 7 |
+
return (
|
| 8 |
+
<section className="space-y-2 px-4 pb-4 pt-2 bg-gray-900 rounded-lg">
|
| 9 |
+
{title && <h2 className="text-lg mb-2 font-semibold uppercase text-gray-500">{title}</h2>}
|
| 10 |
+
{children ? children : null}
|
| 11 |
+
</section>
|
| 12 |
+
);
|
| 13 |
+
};
|
| 14 |
+
|
| 15 |
+
export default Card;
|