const express = require('express');
const ort = require('onnxruntime-node');
const tiktoken = require('js-tiktoken');
const cors = require('cors');
const path = require('path');
const fs = require('fs');

const app = express();
app.use(cors());
app.use(express.json());

// GPT-2 BPE tokenizer; must match the tokenizer the model was trained with.
const enc = tiktoken.getEncoding("gpt2");

let session = null;

async function initModel() {
    console.log("--- DEBUG: File check ---");
    try {
        const modelPath = path.join(__dirname, 'SmaLLMPro_350M_int8.onnx'); // full model: SmaLLMPro_350M_Final.onnx
        console.log("Looking for model at:", modelPath);
        if (fs.existsSync(modelPath)) {
            session = await ort.InferenceSession.create(modelPath);
            console.log("Model loaded :D!");
        } else {
            console.error("Model file not found!");
        }
    } catch (e) {
        console.error("Error loading model:", e.message);
    }
}
initModel();

app.post('/chat', async (req, res) => {
    if (!session) return res.status(503).json({ error: "Model is still loading..." });

    let clientConnected = true;
    res.on('close', () => {
        clientConnected = false;
        console.log("Connection closed.");
    });

    // Fallback defaults guard against missing parameters, which would
    // otherwise turn the sampling arithmetic below into NaN.
    const { prompt, temp = 0.8, topK = 40, maxLen = 256, penalty = 1.0 } = req.body;

    // Server-sent events: one `data: {...}\n\n` frame per generated token.
    res.setHeader('Content-Type', 'text/event-stream');
    res.setHeader('Cache-Control', 'no-cache');
    res.flushHeaders();

    const formattedPrompt = `Instruction:\n${prompt}\n\nResponse:\n`;
    const tokens = enc.encode(formattedPrompt);
    const VOCAB_SIZE = 50304;
    const CTX_LEN = 1024;
    const EOS_TOKEN = 50256;

    try {
        for (let step = 0; step < maxLen; step++) {
            if (!clientConnected) {
                console.log("Inference stopped: client disconnected.");
                break;
            }

            // Keep the most recent CTX_LEN tokens and left-pad with zeros,
            // since the ONNX export expects a fixed [1, 1024] int64 input.
            const ctx = tokens.slice(-CTX_LEN);
            const paddedInput = new BigInt64Array(CTX_LEN).fill(0n);
            for (let j = 0; j < ctx.length; j++) {
                paddedInput[CTX_LEN - ctx.length + j] = BigInt(ctx[j]);
            }
            const tensor = new ort.Tensor('int64', paddedInput, [1, CTX_LEN]);

            const results = await session.run({ input: tensor });
            const outputName = session.outputNames[0];
            // The output has shape [1, CTX_LEN, VOCAB_SIZE]; the last
            // VOCAB_SIZE values are the logits for the next-token position.
            const logits = Array.from(results[outputName].data.slice(-VOCAB_SIZE));

            // Repetition penalty: dampen logits of tokens seen so far. The
            // Set ensures each distinct token is penalized exactly once,
            // rather than compounding per occurrence.
            if (penalty !== 1.0) {
                for (const token of new Set(tokens)) {
                    if (token < VOCAB_SIZE) {
                        if (logits[token] > 0) {
                            logits[token] /= penalty;
                        } else {
                            logits[token] *= penalty;
                        }
                    }
                }
            }

            // Temperature-scaled softmax, shifted by the max logit for
            // numerical stability.
            const scaledLogits = logits.map(l => l / temp);
            const maxLogit = Math.max(...scaledLogits);
            const exps = scaledLogits.map(l => Math.exp(l - maxLogit));
            const sumExps = exps.reduce((a, b) => a + b, 0);
            const probs = exps.map(e => e / sumExps);

            // Top-k sampling: keep the k most likely tokens, renormalize
            // over them, and draw one at random.
            let indexedProbs = probs.map((p, i) => ({ p, i }));
            indexedProbs.sort((a, b) => b.p - a.p);
            indexedProbs = indexedProbs.slice(0, topK);
            const totalTopKProb = indexedProbs.reduce((a, b) => a + b.p, 0);
            let r = Math.random() * totalTopKProb;
            let nextToken = indexedProbs[0].i;
            for (const pObj of indexedProbs) {
                r -= pObj.p;
                if (r <= 0) {
                    nextToken = pObj.i;
                    break;
                }
            }

            if (nextToken === EOS_TOKEN) break;
            tokens.push(nextToken);

            const newText = enc.decode([nextToken]);
            if (clientConnected) {
                res.write(`data: ${JSON.stringify({ token: newText })}\n\n`);
            }
            // Yield to the event loop so the chunk gets flushed to the client.
            await new Promise(r => setTimeout(r, 1));
        }
    } catch (err) {
        console.error("Error during inference:", err);
        if (clientConnected) {
            res.write(`data: ${JSON.stringify({ error: err.message })}\n\n`);
        }
    } finally {
        res.end();
    }
});

app.get('/', (req, res) => res.send("SmaLLMPro Backend is Running"));
app.listen(7860, '0.0.0.0', () => console.log("SmaLLMPro backend listening on port 7860"));
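
/*
 * Example client (a minimal sketch, not part of the server): consumes the
 * token stream emitted by the /chat route above. Assumes Node 18+ (global
 * fetch) running inside an async context, and a server on localhost:7860;
 * the sampling parameters shown are illustrative values, not required
 * defaults.
 *
 *   const res = await fetch('http://localhost:7860/chat', {
 *     method: 'POST',
 *     headers: { 'Content-Type': 'application/json' },
 *     body: JSON.stringify({
 *       prompt: 'Write a haiku about autumn.',
 *       temp: 0.8, topK: 40, maxLen: 128, penalty: 1.1,
 *     }),
 *   });
 *   const reader = res.body.getReader();
 *   const decoder = new TextDecoder();
 *   while (true) {
 *     const { done, value } = await reader.read();
 *     if (done) break;
 *     // Each SSE frame arrives as: data: {"token":"..."}\n\n
 *     process.stdout.write(decoder.decode(value));
 *   }
 */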