// server.js — SmaLLMPro Express backend: ONNX Runtime inference with SSE token streaming.
const express = require('express');
const ort = require('onnxruntime-node');
const tiktoken = require('js-tiktoken');
const cors = require('cors');
const path = require('path');
const app = express();
app.use(cors());
app.use(express.json());
const enc = tiktoken.getEncoding("gpt2");
let session = null;
async function initModel() {
console.log("--- DEBUG: File-Check ---");
const fs = require('fs');
try {
const modelPath = path.join(__dirname, 'SmaLLMPro_350M_int8.onnx'); // full model: SmaLLMPro_350M_Final.onnx
console.log("Searched model path:", modelPath);
if (fs.existsSync(modelPath)) {
session = await ort.InferenceSession.create(modelPath);
console.log("Model loaded :D!");
} else {
console.error("File not found!");
}
} catch (e) {
console.error("Error:", e.message);
}
}
initModel();
app.post('/chat', async (req, res) => {
if (!session) return res.status(503).json({ error: "Model loading ..." });
let clientConnected = true;
res.on('close', () => {
clientConnected = false;
console.log("Connection closed.");
});
const { prompt, temp, topK, maxLen, penalty } = req.body;
res.setHeader('Content-Type', 'text/event-stream');
res.setHeader('Cache-Control', 'no-cache');
const formattedPrompt = `Instruction:\n${prompt}\n\nResponse:\n`;
let tokens = enc.encode(formattedPrompt);
const VOCAB_SIZE = 50304;
try {
for (let i = 0; i < maxLen; i++) {
if (!clientConnected) {
console.log("Inference stopped, because client disconnected.");
break;
}
const ctx = tokens.slice(-1024);
const paddedInput = new BigInt64Array(1024).fill(0n);
for (let i = 0; i < ctx.length; i++) {
paddedInput[1024 - ctx.length + i] = BigInt(ctx[i]);
}
const tensor = new ort.Tensor('int64', paddedInput, [1, 1024]);
const results = await session.run({ input: tensor });
const outputName = session.outputNames[0];
const logits = Array.from(results[outputName].data.slice(-VOCAB_SIZE));
if (penalty !== 1.0) {
for (const token of tokens) {
if (token < VOCAB_SIZE) {
if (logits[token] > 0) {
logits[token] /= penalty;
} else {
logits[token] *= penalty;
}
}
}
}
let scaledLogits = logits.map(l => l / temp);
const maxLogit = Math.max(...scaledLogits);
const exps = scaledLogits.map(l => Math.exp(l - maxLogit));
const sumExps = exps.reduce((a, b) => a + b, 0);
let probs = exps.map(e => e / sumExps);
let indexedProbs = probs.map((p, i) => ({ p, i }));
indexedProbs.sort((a, b) => b.p - a.p);
indexedProbs = indexedProbs.slice(0, topK);
const totalTopKProb = indexedProbs.reduce((a, b) => a + b.p, 0);
let r = Math.random() * totalTopKProb;
let nextToken = indexedProbs[0].i;
for (let pObj of indexedProbs) {
r -= pObj.p;
if (r <= 0) {
nextToken = pObj.i;
break;
}
}
if (nextToken === 50256) break; // EOS Token
tokens.push(nextToken);
const newText = enc.decode([nextToken]);
if (clientConnected) {
res.write(`data: ${JSON.stringify({ token: newText })}\n\n`);
}
await new Promise(r => setTimeout(r, 1));
}
} catch (err) {
console.error("Error:", err);
res.write(`data: ${JSON.stringify({ error: err.message })}\n\n`);
} finally {
res.end();
}
});
app.get('/', (req, res) => res.send("SmaLLMPro Backend is Running"));
app.listen(7860, '0.0.0.0');