khushalcodiste committed
Commit ea92e9f · 1 Parent(s): debd1ce

fix: added

Files changed (1):
  1. server.js +97 -0
server.js CHANGED
@@ -81,6 +81,38 @@ async function runInference(imageBuffer, prompt, maxTokens) {
   return decoded[0];
 }
 
+async function runTextInference(prompt, maxTokens) {
+  const conversation = [
+    {
+      role: "user",
+      content: [{ type: "text", text: prompt }],
+    },
+  ];
+
+  const text = processor.apply_chat_template(conversation, {
+    add_generation_prompt: true,
+  });
+
+  const inputs = await processor(text);
+  const output = await model.generate({
+    ...inputs,
+    max_new_tokens: maxTokens,
+  });
+
+  const promptLength = inputs.input_ids.dims.at(-1);
+  const decoded = processor.batch_decode(
+    output.slice(null, [promptLength, null]),
+    { skip_special_tokens: true },
+  );
+  return decoded[0];
+}
+
+function queueTextInference(prompt, maxTokens) {
+  const task = inferenceQueue.then(() => runTextInference(prompt, maxTokens));
+  inferenceQueue = task.catch(() => {});
+  return task;
+}
+
 function queueInference(imageBuffer, prompt, maxTokens) {
   const task = inferenceQueue.then(() => runInference(imageBuffer, prompt, maxTokens));
   inferenceQueue = task.catch(() => {});
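
Reviewer note: queueTextInference shares the single inferenceQueue promise chain with the existing queueInference, so text-only and image requests are serialized through one gate. The reassignment inferenceQueue = task.catch(() => {}) is what keeps that gate alive after a failure; without it, one rejected task would reject every task chained afterwards. A minimal, self-contained sketch of the idiom (the names and jobs are illustrative, not from server.js):

    // Illustrative stand-in for the shared gate in server.js.
    let queue = Promise.resolve();

    function enqueue(job) {
      // Chain the job behind whatever is already queued...
      const task = queue.then(job);
      // ...and swallow its rejection in the stored chain, so one failed
      // job cannot poison later jobs. Callers still observe the failure
      // through the returned task.
      queue = task.catch(() => {});
      return task;
    }

    // Jobs run strictly one at a time, in submission order.
    enqueue(() => Promise.reject(new Error("boom"))).catch((e) => console.error(e.message));
    enqueue(async () => console.log("still runs after the failure"));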
 
@@ -133,6 +165,31 @@ const swaggerDoc = {
       },
     },
   },
+  "/prompt": {
+    post: {
+      summary: "Text prompt inference (no image)",
+      requestBody: {
+        required: true,
+        content: {
+          "application/json": {
+            schema: {
+              type: "object",
+              required: ["prompt"],
+              properties: {
+                prompt: { type: "string", description: "Text prompt to send to the model" },
+                max_tokens: { type: "integer", default: 256 },
+              },
+            },
+          },
+        },
+      },
+      responses: {
+        200: { description: "Inference result" },
+        400: { description: "Invalid input" },
+        503: { description: "Model not loaded" },
+      },
+    },
+  },
   "/inference/base64": {
     post: {
       summary: "Image inference (base64)",
 
@@ -199,6 +256,46 @@ app.get("/health", (req, res) => {
   res.json({ status: "healthy", model_loaded: model !== null });
 });
 
+app.post("/prompt", express.json(), async (req, res) => {
+  const prompt = req.body.prompt;
+  const maxTokens = parseInt(req.body.max_tokens) || 256;
+  log("info", "prompt_request_received", {
+    request_id: req.requestId,
+    prompt_chars: prompt?.length ?? 0,
+    max_tokens: maxTokens,
+  });
+
+  if (!model || !processor) {
+    log("error", "prompt_model_unavailable", { request_id: req.requestId });
+    return res.status(503).json({ detail: "Model not loaded yet." });
+  }
+  if (!prompt) {
+    log("error", "prompt_validation_failed", {
+      request_id: req.requestId,
+      reason: "missing_prompt",
+    });
+    return res.status(400).json({ detail: "No prompt provided." });
+  }
+
+  try {
+    const start = Date.now();
+    const response = await queueTextInference(prompt, maxTokens);
+    log("info", "prompt_completed", {
+      request_id: req.requestId,
+      duration_ms: Date.now() - start,
+      response_chars: response?.length ?? 0,
+    });
+    res.json({ response });
+  } catch (err) {
+    log("error", "prompt_failed", {
+      request_id: req.requestId,
+      error: err.message,
+      stack: err.stack,
+    });
+    res.status(500).json({ detail: "Inference failed.", error: err.message });
+  }
+});
+
 app.post("/inference", upload.single("file"), async (req, res) => {
   const prompt = req.body.prompt || "Describe this image in detail.";
   const maxTokens = parseInt(req.body.max_tokens) || 256;
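
Reviewer note: a sketch of how a client might exercise the new endpoint. Node 18+ ships a global fetch; the base URL is an assumption, since the listen address is outside this diff:

    // Hypothetical client; point the base URL at wherever server.js listens.
    async function askModel(prompt, maxTokens = 64) {
      const res = await fetch("http://localhost:3000/prompt", {
        method: "POST",
        headers: { "Content-Type": "application/json" },
        body: JSON.stringify({ prompt, max_tokens: maxTokens }),
      });
      if (!res.ok) {
        // Mirrors the handler: 400 missing prompt, 503 model not loaded, 500 inference error.
        const { detail } = await res.json();
        throw new Error(`/prompt failed (${res.status}): ${detail}`);
      }
      const { response } = await res.json();
      return response;
    }

    askModel("What is the capital of France?").then(console.log).catch(console.error);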