package com.minimind.mind2

import android.content.Context
import kotlinx.coroutines.*
import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.flow.*
import java.io.File

/**
 * MiniMind (Mind2) Model Interface
 * Kotlin wrapper for native llama.cpp inference
 */
class Mind2Model private constructor() {

    companion object {
        init {
            System.loadLibrary("mind2")
        }

        @Volatile
        private var instance: Mind2Model? = null

        @JvmStatic
        fun getInstance(): Mind2Model {
            return instance ?: synchronized(this) {
                instance ?: Mind2Model().also { instance = it }
            }
        }
    }

    // Model state (written from Dispatchers.IO in load()/release(), so marked @Volatile
    // for visibility from other threads)
    @Volatile
    private var isLoaded = false
    @Volatile
    private var modelPath: String? = null

    // Generation parameters
    data class GenerationConfig(
        val maxTokens: Int = 256,
        val temperature: Float = 0.7f,
        val topP: Float = 0.9f,
        val topK: Int = 40,
        // Note: repeatPenalty and stopTokens are not yet forwarded to the native calls below.
        val repeatPenalty: Float = 1.1f,
        val stopTokens: List<String> = listOf("<|endoftext|>", "<|im_end|>")
    )

    /**
     * Load the model, copying it from the app's assets ("models/") into internal
     * storage on first use
     */
    suspend fun load(
        context: Context,
        modelName: String = "mind2-lite.gguf",
        contextLength: Int = 2048,
        threads: Int = 0  // 0 = auto
    ): Result<Unit> = withContext(Dispatchers.IO) {
        try {
            // Check if model is in assets
            val assetPath = "models/$modelName"
            val modelFile = File(context.filesDir, modelName)

            if (!modelFile.exists()) {
                // Copy from assets
                context.assets.open(assetPath).use { input ->
                    modelFile.outputStream().use { output ->
                        input.copyTo(output)
                    }
                }
            }

            modelPath = modelFile.absolutePath
            val success = nativeInit(modelPath!!, contextLength, threads)

            if (success) {
                isLoaded = true
                Result.success(Unit)
            } else {
                Result.failure(RuntimeException("Failed to load model"))
            }
        } catch (e: Exception) {
            Result.failure(e)
        }
    }
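
    // Illustrative usage of load() (assumption: called from some CoroutineScope such as a
    // viewModelScope; the log tag is made up for this sketch):
    //
    //   scope.launch {
    //       Mind2Model.getInstance().load(applicationContext)
    //           .onSuccess { Log.d("Mind2", "Loaded: ${Mind2Model.getInstance().getInfo()}") }
    //           .onFailure { e -> Log.e("Mind2", "Model load failed", e) }
    //   }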

    /**
     * Generate text (non-streaming)
     */
    suspend fun generate(
        prompt: String,
        config: GenerationConfig = GenerationConfig()
    ): Result<String> = withContext(Dispatchers.IO) {
        if (!isLoaded) {
            return@withContext Result.failure(IllegalStateException("Model not loaded"))
        }

        try {
            val result = nativeGenerate(
                prompt,
                config.maxTokens,
                config.temperature,
                config.topP,
                config.topK
            )
            Result.success(result)
        } catch (e: Exception) {
            Result.failure(e)
        }
    }
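
    // Illustrative usage of generate() (prompt and sampling values are arbitrary examples):
    //
    //   val reply = model.generate(
    //       prompt = "Summarize the following text: ...",
    //       config = GenerationConfig(maxTokens = 128, temperature = 0.3f)
    //   ).getOrElse { e -> "generation failed: ${e.message}" }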

    /**
     * Generate text with streaming
     */
    fun generateStream(
        prompt: String,
        config: GenerationConfig = GenerationConfig()
    ): Flow<String> = callbackFlow {
        if (!isLoaded) {
            throw IllegalStateException("Model not loaded")
        }

        val callback = object : TokenCallback {
            override fun onToken(token: String) {
                // With the UNLIMITED buffer applied below, this only fails once the channel is closed.
                trySend(token)
            }

            override fun onComplete() {
                channel.close()
            }
        }

        nativeGenerateStream(
            prompt,
            config.maxTokens,
            config.temperature,
            config.topP,
            config.topK,
            callback
        )

        awaitClose { stop() }
    }
        .buffer(Channel.UNLIMITED)  // fuses with the callbackFlow channel so trySend never drops tokens
        .flowOn(Dispatchers.IO)
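
    // Illustrative usage of generateStream() (assumes collection from a coroutine; UI wiring omitted):
    //
    //   model.generateStream("Write a haiku about rain")
    //       .onEach { token -> print(token) }
    //       .catch { e -> Log.e("Mind2", "Streaming failed", e) }
    //       .collect()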

    /**
     * Chat with conversation history
     */
    suspend fun chat(
        message: String,
        history: List<ChatMessage> = emptyList(),
        config: GenerationConfig = GenerationConfig()
    ): Result<String> {
        val prompt = buildChatPrompt(message, history)
        return generate(prompt, config)
    }

    /**
     * Chat with streaming
     */
    fun chatStream(
        message: String,
        history: List<ChatMessage> = emptyList(),
        config: GenerationConfig = GenerationConfig()
    ): Flow<String> {
        val prompt = buildChatPrompt(message, history)
        return generateStream(prompt, config)
    }
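
    // Illustrative multi-turn usage of chatStream() (the history contents are made up):
    //
    //   val history = listOf(
    //       ChatMessage("user", "What file format does the model use?"),
    //       ChatMessage("assistant", "It ships as a GGUF file loaded by llama.cpp.")
    //   )
    //   model.chatStream("Can it run fully offline?", history)
    //       .collect { token -> print(token) }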

    private fun buildChatPrompt(message: String, history: List<ChatMessage>): String {
        val sb = StringBuilder()

        // System prompt
        sb.append("<|im_start|>system\n")
        sb.append("You are Mind2, a helpful AI assistant running locally on this device.\n")
        sb.append("<|im_end|>\n")

        // History
        for (msg in history) {
            sb.append("<|im_start|>${msg.role}\n")
            sb.append("${msg.content}\n")
            sb.append("<|im_end|>\n")
        }

        // Current message
        sb.append("<|im_start|>user\n")
        sb.append("$message\n")
        sb.append("<|im_end|>\n")
        sb.append("<|im_start|>assistant\n")

        return sb.toString()
    }
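
    // For reference, a single-turn call to buildChatPrompt("Hello!", emptyList()) yields a
    // ChatML-style prompt of this shape:
    //
    //   <|im_start|>system
    //   You are Mind2, a helpful AI assistant running locally on this device.
    //   <|im_end|>
    //   <|im_start|>user
    //   Hello!
    //   <|im_end|>
    //   <|im_start|>assistant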

    /**
     * Stop ongoing generation
     */
    fun stop() {
        nativeStop()
    }

    /**
     * Release resources
     */
    fun release() {
        nativeRelease()
        isLoaded = false
        modelPath = null
    }

    /**
     * Get model info
     */
    fun getInfo(): String = nativeGetInfo()

    /**
     * Benchmark inference speed
     */
    suspend fun benchmark(tokens: Int = 100): Float = withContext(Dispatchers.IO) {
        nativeBenchmark(tokens)
    }

    // Native methods
    private external fun nativeInit(modelPath: String, nCtx: Int, nThreads: Int): Boolean
    private external fun nativeGenerate(
        prompt: String,
        maxTokens: Int,
        temperature: Float,
        topP: Float,
        topK: Int
    ): String
    private external fun nativeGenerateStream(
        prompt: String,
        maxTokens: Int,
        temperature: Float,
        topP: Float,
        topK: Int,
        callback: TokenCallback
    )
    private external fun nativeStop()
    private external fun nativeRelease()
    private external fun nativeGetInfo(): String
    private external fun nativeBenchmark(nTokens: Int): Float

    interface TokenCallback {
        fun onToken(token: String)
        fun onComplete()
    }

    data class ChatMessage(
        val role: String,  // "user" or "assistant"
        val content: String
    )
}

/**
 * Extension function for easy initialization
 */
suspend fun Context.loadMind2Model(
    modelName: String = "mind2-lite.gguf",
    contextLength: Int = 2048
): Result<Mind2Model> {
    val model = Mind2Model.getInstance()
    return model.load(this, modelName, contextLength).map { model }
}
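
/**
 * End-to-end usage sketch (illustrative only, not part of the original API): loads the
 * bundled model, runs one streaming chat turn, then releases the native resources. The
 * function name and print-based output are assumptions for demonstration.
 */
suspend fun runMind2Demo(context: Context) {
    context.loadMind2Model()
        .onSuccess { model ->
            model.chatStream("Introduce yourself in one sentence.")
                .collect { token -> print(token) }
            model.release()
        }
        .onFailure { e -> println("Mind2 failed to load: ${e.message}") }
}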