const express = require('express');
const ort = require('onnxruntime-node');
const tiktoken = require('js-tiktoken');
const cors = require('cors');
const path = require('path');
const app = express();
app.use(cors()); // allow cross-origin requests (web UI is served from a different origin)
app.use(express.json()); // parse JSON request bodies (used by the /chat route)
const enc = tiktoken.getEncoding("gpt2"); // GPT-2 BPE tokenizer for encode/decode
let session = null; // ONNX inference session; populated asynchronously by initModel()
/**
 * Loads the quantized ONNX model from disk into the module-level `session`.
 * On any failure `session` stays null, so /chat keeps answering 503.
 * Logs diagnostics either way; never throws.
 */
async function initModel() {
  console.log("--- DEBUG: File-Check ---");
  const fs = require('fs');
  try {
    const modelPath = path.join(__dirname, 'SmaLLMPro_350M_int8.onnx'); // full model: SmaLLMPro_350M_Final.onnx
    console.log("Searched model path:", modelPath);
    if (!fs.existsSync(modelPath)) {
      console.error("File not found!");
      return;
    }
    session = await ort.InferenceSession.create(modelPath);
    console.log("Model loaded :D!");
  } catch (e) {
    console.error("Error:", e.message);
  }
}
// Fire-and-forget: initModel catches all of its own errors.
initModel();
/**
 * POST /chat — autoregressively samples tokens from the ONNX model and
 * streams them to the client as Server-Sent Events.
 * Body: { prompt, temp?, topK?, maxLen?, penalty? } — sampling fields are
 * optional and default to sane values (previously a missing `penalty` or
 * `temp` turned the logits into NaN and silently broke sampling).
 * Responds 503 until initModel() has populated `session`.
 */
app.post('/chat', async (req, res) => {
  if (!session) return res.status(503).json({ error: "Model loading ..." });

  // Stop inference as soon as the client goes away.
  let clientConnected = true;
  res.on('close', () => {
    clientConnected = false;
    console.log("Connection closed.");
  });

  // Destructuring defaults guard against missing fields; temp is also
  // clamped so temp <= 0 cannot divide logits by zero.
  const { prompt, temp = 1.0, topK = 50, maxLen = 200, penalty = 1.0 } = req.body;
  const safeTemp = temp > 0 ? temp : 1.0;

  res.setHeader('Content-Type', 'text/event-stream');
  res.setHeader('Cache-Control', 'no-cache');

  const formattedPrompt = `Instruction:\n${prompt}\n\nResponse:\n`;
  let tokens = enc.encode(formattedPrompt);
  const VOCAB_SIZE = 50304;
  const CONTEXT_LEN = 1024; // fixed model context window; input is left-padded with 0s
  try {
    for (let step = 0; step < maxLen; step++) {
      if (!clientConnected) {
        console.log("Inference stopped, because client disconnected.");
        break;
      }

      // Left-pad the last CONTEXT_LEN tokens into a fixed-size int64 tensor.
      const ctx = tokens.slice(-CONTEXT_LEN);
      const paddedInput = new BigInt64Array(CONTEXT_LEN).fill(0n);
      for (let j = 0; j < ctx.length; j++) { // `j`: no longer shadows the outer loop index
        paddedInput[CONTEXT_LEN - ctx.length + j] = BigInt(ctx[j]);
      }
      const tensor = new ort.Tensor('int64', paddedInput, [1, CONTEXT_LEN]);
      const results = await session.run({ input: tensor });
      const outputName = session.outputNames[0];
      // Keep only the logits of the final sequence position.
      const logits = Array.from(results[outputName].data.slice(-VOCAB_SIZE));

      // Repetition penalty: apply once per *unique* generated token.
      // The original iterated every occurrence, compounding the penalty
      // for tokens that appear multiple times in the history.
      if (penalty !== 1.0) {
        for (const token of new Set(tokens)) {
          if (token < VOCAB_SIZE) {
            if (logits[token] > 0) {
              logits[token] /= penalty;
            } else {
              logits[token] *= penalty;
            }
          }
        }
      }

      // Temperature-scaled softmax, max-subtracted for numeric stability.
      // reduce() instead of Math.max(...arr): spreading ~50k elements can
      // exceed the engine's argument-count limit.
      const scaledLogits = logits.map(l => l / safeTemp);
      const maxLogit = scaledLogits.reduce((a, b) => (b > a ? b : a), -Infinity);
      const exps = scaledLogits.map(l => Math.exp(l - maxLogit));
      const sumExps = exps.reduce((a, b) => a + b, 0);
      const probs = exps.map(e => e / sumExps);

      // Top-k sampling: keep the k most probable tokens, renormalize over
      // their mass, and draw one proportionally.
      let indexedProbs = probs.map((p, i) => ({ p, i }));
      indexedProbs.sort((a, b) => b.p - a.p);
      indexedProbs = indexedProbs.slice(0, topK);
      const totalTopKProb = indexedProbs.reduce((a, b) => a + b.p, 0);
      let r = Math.random() * totalTopKProb;
      let nextToken = indexedProbs[0].i;
      for (const pObj of indexedProbs) {
        r -= pObj.p;
        if (r <= 0) {
          nextToken = pObj.i;
          break;
        }
      }

      if (nextToken === 50256) break; // EOS Token

      tokens.push(nextToken);
      const newText = enc.decode([nextToken]);
      if (clientConnected) {
        res.write(`data: ${JSON.stringify({ token: newText })}\n\n`);
      }
      // Yield to the event loop so the 'close' handler gets a chance to run.
      await new Promise(resolve => setTimeout(resolve, 1));
    }
  } catch (err) {
    console.error("Error:", err);
    res.write(`data: ${JSON.stringify({ error: err.message })}\n\n`);
  } finally {
    res.end();
  }
});
// Simple health-check endpoint.
app.get('/', (req, res) => {
  return res.send("SmaLLMPro Backend is Running");
});
// Bind on all interfaces.
const PORT = 7860;
app.listen(PORT, '0.0.0.0');