| | const express = require('express'); |
| | const ort = require('onnxruntime-node'); |
| | const tiktoken = require('js-tiktoken'); |
| | const cors = require('cors'); |
| | const path = require('path'); |
| |
|
// HTTP application. CORS handling and JSON body parsing are registered
// before any route so every handler sees them.
const app = express();
for (const middleware of [cors(), express.json()]) {
  app.use(middleware);
}

// Tokenizer for the "gpt2" encoding, shared by all requests.
const enc = tiktoken.getEncoding('gpt2');

// ONNX inference session. Stays null until initModel() has assigned a
// loaded session; request handlers use that to detect "not ready".
let session = null;
| |
|
/**
 * Loads the ONNX model file located next to this script and stores the
 * resulting inference session in the module-level `session` variable.
 *
 * A missing file or a failed load is logged and otherwise ignored, so
 * `session` simply stays null in those cases.
 *
 * @returns {Promise<void>}
 */
async function initModel() {
  console.log("--- DEBUG: File-Check ---");
  const { existsSync } = require('fs');

  try {
    const modelFile = path.join(__dirname, 'SmaLLMPro_350M_int8.onnx');
    console.log("Searched model path:", modelFile);

    // Guard clause: nothing to load when the file is absent.
    if (!existsSync(modelFile)) {
      console.error("File not found!");
      return;
    }

    session = await ort.InferenceSession.create(modelFile);
    console.log("Model loaded :D!");
  } catch (loadError) {
    console.error("Error:", loadError.message);
  }
}

// Started at module load without awaiting; load errors are caught and
// logged inside initModel itself.
initModel();
| |
|
// Constants used by the /chat sampling loop.
const VOCAB_SIZE = 50304; // number of trailing output values read as next-token logits
const CONTEXT_LEN = 1024; // fixed token count of the model input tensor
const EOS_TOKEN = 50256; // sampling this id ends generation
const DEFAULT_MAX_LEN = 128; // arbitrary fallback, used only when the request omits maxLen

/**
 * Applies a repetition penalty to `logits` in place.
 *
 * Each distinct token is penalised exactly once, however often it occurred:
 * positive logits are divided by `penalty`, non-positive ones multiplied.
 *
 * @param {Float64Array} logits - Next-token logits, indexed by token id.
 * @param {Set<number>} seenTokens - Distinct token ids seen so far.
 * @param {number} penalty - Penalty factor; 1 disables the penalty.
 * @returns {void}
 */
function applyRepetitionPenalty(logits, seenTokens, penalty) {
  if (penalty === 1) return;

  for (const token of seenTokens) {
    if (token >= logits.length) continue;
    logits[token] = logits[token] > 0
      ? logits[token] / penalty
      : logits[token] * penalty;
  }
}

/**
 * Samples one token id from the `topK` highest logits using a
 * temperature-scaled softmax.
 *
 * The K candidates are selected on the raw logits first. Dividing by a
 * positive temperature and applying softmax are both order-preserving, so
 * this yields the same renormalised distribution as a softmax over the full
 * vocabulary followed by truncation, without exponentiating every entry.
 *
 * @param {Float64Array} logits - Next-token logits, indexed by token id.
 * @param {number} temperature - Softmax temperature; must be > 0.
 * @param {number} topK - Number of candidates to keep; must be >= 1.
 * @returns {number} The sampled token id.
 */
function sampleTopK(logits, temperature, topK) {
  const ranked = Array.from({ length: logits.length }, (_, id) => id);
  ranked.sort((a, b) => logits[b] - logits[a]);
  const candidates = ranked.slice(0, topK);

  // Subtract the best logit before exponentiating for numerical stability.
  const best = logits[candidates[0]];
  const weights = candidates.map((id) => Math.exp((logits[id] - best) / temperature));
  const total = weights.reduce((sum, w) => sum + w, 0);

  let remaining = Math.random() * total;
  for (let k = 0; k < candidates.length; k++) {
    remaining -= weights[k];
    if (remaining <= 0) return candidates[k];
  }
  // Rounding can leave a tiny positive remainder; fall back to the best token.
  return candidates[0];
}

/**
 * POST /chat — generates a completion for `prompt` and streams it as
 * server-sent events.
 *
 * Request body (JSON):
 *   - prompt  {string}  required
 *   - temp    {number}  softmax temperature, > 0 (default 1)
 *   - topK    {number}  positive integer candidate count (default: whole vocabulary)
 *   - maxLen  {number}  non-negative integer, max tokens to generate (default 128)
 *   - penalty {number}  repetition penalty, > 0 (default 1 = disabled)
 * Numeric fields may also be numeric strings; they are coerced with Number().
 *
 * Responses:
 *   - 503 JSON while no model session is available.
 *   - 400 JSON when a parameter is invalid.
 *   - Otherwise an event stream of `data: {"token": "..."}` events, or a
 *     `data: {"error": "..."}` event if generation fails. One event may carry
 *     the text of several tokens (see the UTF-8 buffering below).
 */
app.post('/chat', async (req, res) => {
  if (!session) return res.status(503).json({ error: "Model loading ..." });

  const { prompt, temp, topK, maxLen, penalty } = req.body ?? {};
  const temperature = Number(temp ?? 1);
  const candidateCount = Number(topK ?? VOCAB_SIZE);
  const maxNewTokens = Number(maxLen ?? DEFAULT_MAX_LEN);
  const repetitionPenalty = Number(penalty ?? 1);

  // Validate before the SSE headers go out so a bad request can still be
  // answered with a regular 400 JSON response.
  const paramsValid =
    typeof prompt === 'string' &&
    Number.isFinite(temperature) && temperature > 0 &&
    Number.isInteger(candidateCount) && candidateCount > 0 &&
    Number.isInteger(maxNewTokens) && maxNewTokens >= 0 &&
    Number.isFinite(repetitionPenalty) && repetitionPenalty > 0;
  if (!paramsValid) {
    return res.status(400).json({ error: "Invalid request parameters." });
  }

  let clientConnected = true;
  res.on('close', () => {
    clientConnected = false;
    console.log("Connection closed.");
  });

  // Writes one SSE event unless the client is gone or the stream has ended.
  const send = (payload) => {
    if (clientConnected && !res.writableEnded) {
      res.write(`data: ${JSON.stringify(payload)}\n\n`);
    }
  };

  res.setHeader('Content-Type', 'text/event-stream');
  res.setHeader('Cache-Control', 'no-cache');

  try {
    // Tokenisation happens inside the try so an encoder error ends the
    // response with an error event instead of leaving the request hanging.
    const formattedPrompt = `Instruction:\n${prompt}\n\nResponse:\n`;
    const tokens = enc.encode(formattedPrompt);
    const seenTokens = new Set(tokens);
    const outputName = session.outputNames[0];

    // Tokens sampled but not yet streamed because their bytes do not yet
    // form complete UTF-8 characters.
    let pending = [];

    for (let step = 0; step < maxNewTokens; step++) {
      if (!clientConnected) {
        console.log("Inference stopped, because client disconnected.");
        break;
      }

      // Left-pad the most recent tokens with zeros up to the fixed input length.
      const ctx = tokens.slice(-CONTEXT_LEN);
      const paddedInput = new BigInt64Array(CONTEXT_LEN);
      paddedInput.set(ctx.map((t) => BigInt(t)), CONTEXT_LEN - ctx.length);

      const tensor = new ort.Tensor('int64', paddedInput, [1, CONTEXT_LEN]);
      const results = await session.run({ input: tensor });

      // The last VOCAB_SIZE output values are used as the next-token logits.
      const logits = Float64Array.from(results[outputName].data.slice(-VOCAB_SIZE));

      applyRepetitionPenalty(logits, seenTokens, repetitionPenalty);
      const nextToken = sampleTopK(logits, temperature, candidateCount);
      if (nextToken === EOS_TOKEN) break;

      tokens.push(nextToken);
      seenTokens.add(nextToken);

      // A character can span several tokens. Decoding a partial character
      // yields U+FFFD, so hold text back until it decodes cleanly.
      pending.push(nextToken);
      const text = enc.decode(pending);
      if (!text.endsWith('\uFFFD')) {
        send({ token: text });
        pending = [];
      }

      // Short pause between steps; presumably here so pending I/O events
      // such as 'close' get a turn on the event loop.
      await new Promise((resolve) => setTimeout(resolve, 1));
    }

    // Flush whatever is still buffered when generation stops.
    if (pending.length > 0) {
      send({ token: enc.decode(pending) });
    }
  } catch (err) {
    console.error("Error:", err);
    send({ error: err.message });
  } finally {
    res.end();
  }
});
| |
|
// Root route: replies with a fixed plain-text status line.
app.get('/', (_req, res) => {
  res.send("SmaLLMPro Backend is Running");
});

// Accept connections on port 7860, bound to the 0.0.0.0 address.
const LISTEN_PORT = 7860;
const LISTEN_HOST = '0.0.0.0';
app.listen(LISTEN_PORT, LISTEN_HOST);
| |
|