File size: 4,202 Bytes
6a88a56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5d2b34d
6a88a56
bd00ea6
 
 
 
 
 
 
 
6a88a56
 
bd00ea6
6a88a56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
// SmaLLMPro inference backend: Express server that streams tokens from a
// quantized 350M-parameter ONNX model over Server-Sent Events.
const express = require('express');
const ort = require('onnxruntime-node');
const tiktoken = require('js-tiktoken');
const cors = require('cors');
const path = require('path');

const app = express();
app.use(cors());
app.use(express.json());

// GPT-2 BPE tokenizer; must match the vocabulary the model was trained with.
const enc = tiktoken.getEncoding("gpt2"); 
// ONNX inference session; null until initModel() finishes loading the model.
let session = null;

/**
 * Loads the quantized ONNX model from disk into the module-level `session`.
 * Runs once at startup; until it resolves successfully, /chat answers 503.
 * Never rejects: failures are logged and the server stays up.
 */
async function initModel() {
    const fs = require('fs');
    // Quantized int8 model; swap in 'SmaLLMPro_350M_Final.onnx' for the full-precision version.
    const modelPath = path.join(__dirname, 'SmaLLMPro_350M_int8.onnx');
    console.log("Searched model path:", modelPath);
    try {
        if (fs.existsSync(modelPath)) {
            session = await ort.InferenceSession.create(modelPath);
            console.log("Model loaded :D!");
        } else {
            // Keep serving: /chat will keep returning 503 until a restart.
            console.error("File not found!");
        }
    } catch (e) {
        // Log the full error object so the stack trace is preserved,
        // not just e.message.
        console.error("Error:", e);
    }
}
initModel();

/**
 * SSE chat endpoint: generates up to `maxLen` tokens autoregressively and
 * streams each decoded token to the client as a `data: {token}` event.
 * Body: { prompt, temp, topK, maxLen, penalty } — all sampling fields optional.
 */
app.post('/chat', async (req, res) => {
    // Model still loading (or failed to load): tell the client to retry.
    if (!session) return res.status(503).json({ error: "Model loading ..." });

    let clientConnected = true;

    // Abort generation as soon as the client disconnects.
    res.on('close', () => {
        clientConnected = false;
        console.log("Connection closed.");
    });

    // Sampling parameters with fallbacks so a missing field doesn't propagate
    // NaN through the softmax (?? preserves explicit 0, unlike ||).
    const { prompt } = req.body;
    const temp = req.body.temp ?? 1.0;
    const topK = req.body.topK ?? 50;
    const maxLen = req.body.maxLen ?? 256;
    const penalty = req.body.penalty ?? 1.0;

    res.setHeader('Content-Type', 'text/event-stream');
    res.setHeader('Cache-Control', 'no-cache');

    // Instruction/Response template — presumably the format the model was
    // fine-tuned on; keep in sync with training data.
    const formattedPrompt = `Instruction:\n${prompt}\n\nResponse:\n`;
    const tokens = enc.encode(formattedPrompt);

    const VOCAB_SIZE = 50304;
    const CONTEXT_LEN = 1024; // model's fixed input sequence length

    try {
        for (let step = 0; step < maxLen; step++) {
            if (!clientConnected) {
                console.log("Inference stopped, because client disconnected.");
                break;
            }

            // Keep only the most recent CONTEXT_LEN tokens, left-padded with 0s.
            // (Inner index renamed from the original shadowed `i`.)
            const ctx = tokens.slice(-CONTEXT_LEN);
            const paddedInput = new BigInt64Array(CONTEXT_LEN).fill(0n);
            for (let j = 0; j < ctx.length; j++) {
                paddedInput[CONTEXT_LEN - ctx.length + j] = BigInt(ctx[j]);
            }

            const tensor = new ort.Tensor('int64', paddedInput, [1, CONTEXT_LEN]);
            const results = await session.run({ input: tensor });

            // Logits for the last sequence position only.
            const outputName = session.outputNames[0];
            const logits = Array.from(results[outputName].data.slice(-VOCAB_SIZE));

            // Repetition penalty: dampen logits of tokens already generated.
            if (penalty !== 1.0) {
                for (const token of tokens) {
                    if (token < VOCAB_SIZE) {
                        if (logits[token] > 0) {
                            logits[token] /= penalty;
                        } else {
                            logits[token] *= penalty;
                        }
                    }
                }
            }

            // Temperature-scaled, numerically-stable softmax. Max is found with
            // a loop instead of Math.max(...arr): spreading 50k+ arguments can
            // exceed the engine's argument-count limit and throw a RangeError.
            const scaledLogits = logits.map(l => l / temp);
            let maxLogit = -Infinity;
            for (const l of scaledLogits) {
                if (l > maxLogit) maxLogit = l;
            }
            const exps = scaledLogits.map(l => Math.exp(l - maxLogit));
            const sumExps = exps.reduce((a, b) => a + b, 0);
            const probs = exps.map(e => e / sumExps);

            // Top-k sampling: keep the k most probable tokens, renormalize,
            // then draw one proportionally to its probability.
            let indexedProbs = probs.map((p, i) => ({ p, i }));
            indexedProbs.sort((a, b) => b.p - a.p);
            indexedProbs = indexedProbs.slice(0, topK);

            const totalTopKProb = indexedProbs.reduce((a, b) => a + b.p, 0);
            let r = Math.random() * totalTopKProb;
            let nextToken = indexedProbs[0].i;
            for (const pObj of indexedProbs) {
                r -= pObj.p;
                if (r <= 0) {
                    nextToken = pObj.i;
                    break;
                }
            }

            if (nextToken === 50256) break; // EOS Token

            tokens.push(nextToken);
            const newText = enc.decode([nextToken]);

            if (clientConnected) {
                res.write(`data: ${JSON.stringify({ token: newText })}\n\n`);
            }

            // Yield to the event loop so 'close' events can fire mid-generation.
            await new Promise(resolve => setTimeout(resolve, 1));
        }
    } catch (err) {
        console.error("Error:", err);
        res.write(`data: ${JSON.stringify({ error: err.message })}\n\n`);
    } finally {
        // Always terminate the SSE stream, even on error or disconnect.
        res.end();
    }
});

// Health-check endpoint so deploy platforms can probe liveness.
app.get('/', (req, res) => {
    res.send("SmaLLMPro Backend is Running");
});
// Bind on all interfaces (container-friendly), port 7860.
app.listen(7860, '0.0.0.0');