liuw15 commited on
Commit
b86a7bf
·
1 Parent(s): e6f591e

补全sd接口

Browse files
Files changed (2) hide show
  1. src/routes/sd.js +205 -0
  2. src/server/index.js +3 -71
src/routes/sd.js ADDED
@@ -0,0 +1,205 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import express from 'express';
2
+ import { getAvailableModels, generateImageForSD } from '../api/client.js';
3
+ import { generateRequestBody } from '../utils/utils.js';
4
+ import tokenManager from '../auth/token_manager.js';
5
+ import logger from '../utils/logger.js';
6
+
7
+ const router = express.Router();
8
+
9
+ // 静态数据
10
+ const SD_MOCK_DATA = {
11
+ options: {
12
+ sd_model_checkpoint: 'gemini-3-pro-image',
13
+ sd_vae: 'auto',
14
+ CLIP_stop_at_last_layers: 1
15
+ },
16
+ samplers: [
17
+ { name: 'Euler a', aliases: ['k_euler_a'] },
18
+ { name: 'Euler', aliases: ['k_euler'] },
19
+ { name: 'DPM++ 2M', aliases: ['k_dpmpp_2m'] },
20
+ { name: 'DPM++ SDE', aliases: ['k_dpmpp_sde'] }
21
+ ],
22
+ schedulers: [
23
+ { name: 'Automatic', label: 'Automatic' },
24
+ { name: 'Uniform', label: 'Uniform' },
25
+ { name: 'Karras', label: 'Karras' },
26
+ { name: 'Exponential', label: 'Exponential' }
27
+ ],
28
+ upscalers: [
29
+ { name: 'None', model_name: null, scale: 1 },
30
+ { name: 'Lanczos', model_name: null, scale: 4 },
31
+ { name: 'ESRGAN_4x', model_name: 'ESRGAN_4x', scale: 4 }
32
+ ],
33
+ latentUpscaleModes: [
34
+ { name: 'Latent' },
35
+ { name: 'Latent (antialiased)' },
36
+ { name: 'Latent (bicubic)' },
37
+ { name: 'Latent (nearest)' }
38
+ ],
39
+ vae: [
40
+ { model_name: 'auto', filename: 'auto' },
41
+ { model_name: 'None', filename: 'None' }
42
+ ],
43
+ modules: [
44
+ { name: 'none', path: null },
45
+ { name: 'LoRA', path: 'lora' }
46
+ ],
47
+ loras: [
48
+ { name: 'example_lora_v1', alias: 'example_lora_v1', path: 'example_lora_v1.safetensors' },
49
+ { name: 'style_lora', alias: 'style_lora', path: 'style_lora.safetensors' }
50
+ ],
51
+ embeddings: [
52
+ { name: 'EasyNegative', step: 1, sd_checkpoint: null, sd_checkpoint_name: null },
53
+ { name: 'badhandv4', step: 1, sd_checkpoint: null, sd_checkpoint_name: null }
54
+ ],
55
+ hypernetworks: [
56
+ { name: 'example_hypernetwork', path: 'example_hypernetwork.pt' }
57
+ ],
58
+ scripts: [
59
+ { name: 'None', is_alwayson: false, is_img2img: false },
60
+ { name: 'X/Y/Z plot', is_alwayson: false, is_img2img: false }
61
+ ],
62
+ progress: {
63
+ progress: 0,
64
+ eta_relative: 0,
65
+ state: { skipped: false, interrupted: false, job: '', job_count: 0, job_timestamp: '0', job_no: 0 },
66
+ current_image: null,
67
+ textinfo: null
68
+ }
69
+ };
70
+
71
+ // 构建图片生成请求体
72
+ function buildImageRequestBody(prompt, token) {
73
+ const messages = [{ role: 'user', content: prompt }];
74
+ const requestBody = generateRequestBody(messages, 'gemini-3-pro-image', {}, null, token);
75
+ requestBody.request.generationConfig = { candidateCount: 1 };
76
+ requestBody.requestType = 'image_gen';
77
+ delete requestBody.request.systemInstruction;
78
+ delete requestBody.request.tools;
79
+ delete requestBody.request.toolConfig;
80
+ return requestBody;
81
+ }
82
+
83
+ // GET 路由
84
+ router.get('/sd-models', async (req, res) => {
85
+ try {
86
+ const models = await getAvailableModels();
87
+ const imageModels = models.data
88
+ .filter(m => m.id.includes('-image'))
89
+ .map(m => ({
90
+ title: m.id,
91
+ model_name: m.id,
92
+ hash: null,
93
+ sha256: null,
94
+ filename: m.id,
95
+ config: null
96
+ }));
97
+ res.json(imageModels);
98
+ } catch (error) {
99
+ logger.error('获取SD模型列表失败:', error.message);
100
+ res.status(500).json({ error: error.message });
101
+ }
102
+ });
103
+
104
+ router.get('/options', (req, res) => res.json(SD_MOCK_DATA.options));
105
+ router.get('/samplers', (req, res) => res.json(SD_MOCK_DATA.samplers));
106
+ router.get('/schedulers', (req, res) => res.json(SD_MOCK_DATA.schedulers));
107
+ router.get('/upscalers', (req, res) => res.json(SD_MOCK_DATA.upscalers));
108
+ router.get('/latent-upscale-modes', (req, res) => res.json(SD_MOCK_DATA.latentUpscaleModes));
109
+ router.get('/sd-vae', (req, res) => res.json(SD_MOCK_DATA.vae));
110
+ router.get('/sd-modules', (req, res) => res.json(SD_MOCK_DATA.modules));
111
+ router.get('/loras', (req, res) => res.json(SD_MOCK_DATA.loras));
112
+ router.get('/embeddings', (req, res) => res.json({ loaded: SD_MOCK_DATA.embeddings, skipped: {} }));
113
+ router.get('/hypernetworks', (req, res) => res.json(SD_MOCK_DATA.hypernetworks));
114
+ router.get('/scripts', (req, res) => res.json({ txt2img: SD_MOCK_DATA.scripts, img2img: SD_MOCK_DATA.scripts }));
115
+ router.get('/script-info', (req, res) => res.json([]));
116
+ router.get('/progress', (req, res) => res.json(SD_MOCK_DATA.progress));
117
+ router.get('/cmd-flags', (req, res) => res.json({}));
118
+ router.get('/memory', (req, res) => res.json({ ram: { free: 8589934592, used: 8589934592, total: 17179869184 }, cuda: { system: { free: 0, used: 0, total: 0 } } }));
119
+
120
+ // POST 路由
121
+ router.post('/img2img', async (req, res) => {
122
+ const { prompt, init_images } = req.body;
123
+
124
+ try {
125
+ if (!prompt) {
126
+ return res.status(400).json({ error: 'prompt is required' });
127
+ }
128
+
129
+ const token = await tokenManager.getToken();
130
+ if (!token) {
131
+ throw new Error('没有可用的token');
132
+ }
133
+
134
+ // 构建包含图片的消息
135
+ const content = [{ type: 'text', text: prompt }];
136
+ if (init_images && init_images.length > 0) {
137
+ init_images.forEach(img => {
138
+ const format = img.startsWith('/9j/') ? 'jpeg' : 'png';
139
+ content.push({ type: 'image_url', image_url: { url: `data:image/${format};base64,${img}` } });
140
+ });
141
+ }
142
+
143
+ const messages = [{ role: 'user', content }];
144
+ const requestBody = generateRequestBody(messages, 'gemini-3-pro-image', {}, null, token);
145
+ requestBody.request.generationConfig = { candidateCount: 1 };
146
+ requestBody.requestType = 'image_gen';
147
+ delete requestBody.request.systemInstruction;
148
+ delete requestBody.request.tools;
149
+ delete requestBody.request.toolConfig;
150
+
151
+ const images = await generateImageForSD(requestBody, token);
152
+
153
+ if (images.length === 0) {
154
+ throw new Error('未生成图片');
155
+ }
156
+
157
+ res.json({
158
+ images,
159
+ parameters: req.body,
160
+ info: JSON.stringify({ prompt })
161
+ });
162
+ } catch (error) {
163
+ logger.error('SD图生图失败:', error.message);
164
+ res.status(500).json({ error: error.message });
165
+ }
166
+ });
167
+
168
+ router.post('/txt2img', async (req, res) => {
169
+ const { prompt, negative_prompt, steps, cfg_scale, width, height, seed, sampler_name } = req.body;
170
+
171
+ try {
172
+ if (!prompt) {
173
+ return res.status(400).json({ error: 'prompt is required' });
174
+ }
175
+
176
+ const token = await tokenManager.getToken();
177
+ if (!token) {
178
+ throw new Error('没有可用的token');
179
+ }
180
+
181
+ const requestBody = buildImageRequestBody(prompt, token);
182
+ const images = await generateImageForSD(requestBody, token);
183
+
184
+ if (images.length === 0) {
185
+ throw new Error('未生成图片');
186
+ }
187
+
188
+ res.json({
189
+ images,
190
+ parameters: { prompt, negative_prompt, steps, cfg_scale, width, height, seed, sampler_name },
191
+ info: JSON.stringify({ prompt, seed: seed || -1 })
192
+ });
193
+ } catch (error) {
194
+ logger.error('SD生图失败:', error.message);
195
+ res.status(500).json({ error: error.message });
196
+ }
197
+ });
198
+
199
+ router.post('/options', (req, res) => res.json({}));
200
+ router.post('/refresh-checkpoints', (req, res) => res.json(null));
201
+ router.post('/refresh-loras', (req, res) => res.json(null));
202
+ router.post('/interrupt', (req, res) => res.json(null));
203
+ router.post('/skip', (req, res) => res.json(null));
204
+
205
+ export default router;
src/server/index.js CHANGED
@@ -7,6 +7,7 @@ import logger from '../utils/logger.js';
7
  import config from '../config/config.js';
8
  import tokenManager from '../auth/token_manager.js';
9
  import adminRouter from '../routes/admin.js';
 
10
 
11
  const __filename = fileURLToPath(import.meta.url);
12
  const __dirname = path.dirname(__filename);
@@ -63,7 +64,7 @@ app.use((err, req, res, next) => {
63
  });
64
 
65
  app.use((req, res, next) => {
66
- const ignorePaths = ['/images', '/favicon.ico', '/.well-known'];
67
  if (!ignorePaths.some(path => req.path.startsWith(path))) {
68
  const start = Date.now();
69
  res.on('finish', () => {
@@ -72,6 +73,7 @@ app.use((req, res, next) => {
72
  }
73
  next();
74
  });
 
75
 
76
  app.use((req, res, next) => {
77
  if (req.path.startsWith('/v1/')) {
@@ -98,76 +100,6 @@ app.get('/v1/models', async (req, res) => {
98
  }
99
  });
100
 
101
- // ==================== Stable Diffusion API ====================
102
-
103
- app.get('/sdapi/v1/sd-models', async (req, res) => {
104
- try {
105
- const models = await getAvailableModels();
106
- const imageModels = models.data
107
- .filter(m => m.id.includes('-image'))
108
- .map(m => ({
109
- title: m.id,
110
- model_name: m.id,
111
- hash: null,
112
- sha256: null,
113
- filename: m.id,
114
- config: null
115
- }));
116
- res.json(imageModels);
117
- } catch (error) {
118
- logger.error('获取SD模型列表失败:', error.message);
119
- res.status(500).json({ error: error.message });
120
- }
121
- });
122
-
123
- app.post('/sdapi/v1/txt2img', async (req, res) => {
124
- const { prompt, negative_prompt, steps, cfg_scale, width, height, seed, sampler_name } = req.body;
125
-
126
- try {
127
- if (!prompt) {
128
- return res.status(400).json({ error: 'prompt is required' });
129
- }
130
-
131
- const token = await tokenManager.getToken();
132
- if (!token) {
133
- throw new Error('没有可用的token');
134
- }
135
-
136
- const model = 'gemini-3-pro-image';
137
- const messages = [{ role: 'user', content: prompt }];
138
- const requestBody = generateRequestBody(messages, model, {}, null, token);
139
-
140
- requestBody.request.generationConfig = { candidateCount: 1 };
141
- requestBody.requestType = 'image_gen';
142
- delete requestBody.request.systemInstruction;
143
- delete requestBody.request.tools;
144
- delete requestBody.request.toolConfig;
145
-
146
- const images = await generateImageForSD(requestBody, token);
147
-
148
- if (images.length === 0) {
149
- throw new Error('未生成图片');
150
- }
151
-
152
- res.json({
153
- images,
154
- parameters: { prompt, negative_prompt, steps, cfg_scale, width, height, seed, sampler_name },
155
- info: JSON.stringify({ prompt, seed: seed || -1 })
156
- });
157
- } catch (error) {
158
- logger.error('SD生图失败:', error.message);
159
- res.status(500).json({ error: error.message });
160
- }
161
- });
162
-
163
- app.get('/sdapi/v1/options', (req, res) => {
164
- res.json({
165
- sd_model_checkpoint: 'gemini-3-pro-image',
166
- sd_vae: 'auto',
167
- CLIP_stop_at_last_layers: 1
168
- });
169
- });
170
-
171
 
172
 
173
  app.post('/v1/chat/completions', async (req, res) => {
 
7
  import config from '../config/config.js';
8
  import tokenManager from '../auth/token_manager.js';
9
  import adminRouter from '../routes/admin.js';
10
+ import sdRouter from '../routes/sd.js';
11
 
12
  const __filename = fileURLToPath(import.meta.url);
13
  const __dirname = path.dirname(__filename);
 
64
  });
65
 
66
  app.use((req, res, next) => {
67
+ const ignorePaths = ['/images', '/favicon.ico', '/.well-known', '/sdapi/v1/options', '/sdapi/v1/samplers', '/sdapi/v1/schedulers', '/sdapi/v1/upscalers', '/sdapi/v1/latent-upscale-modes', '/sdapi/v1/sd-vae', '/sdapi/v1/sd-modules'];
68
  if (!ignorePaths.some(path => req.path.startsWith(path))) {
69
  const start = Date.now();
70
  res.on('finish', () => {
 
73
  }
74
  next();
75
  });
76
+ app.use('/sdapi/v1', sdRouter);
77
 
78
  app.use((req, res, next) => {
79
  if (req.path.startsWith('/v1/')) {
 
100
  }
101
  });
102
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
 
104
 
105
  app.post('/v1/chat/completions', async (req, res) => {