|
|
package com.minimind.mind2 |
|
|
|
|
|
import android.content.Context |
|
|
import kotlinx.coroutines.* |
|
|
import kotlinx.coroutines.flow.* |
|
|
import java.io.File |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/**
 * Kotlin-side wrapper around the native "mind2" inference library (loaded via JNI).
 *
 * Singleton: obtain via [getInstance]. Call [load] once before any generation call;
 * call [release] to free native resources when done.
 *
 * Thread-safety: generation methods dispatch native work onto [Dispatchers.IO].
 * [isLoaded] is marked @Volatile because it is written on an IO dispatcher thread
 * and read from arbitrary caller threads.
 */
class Mind2Model private constructor() {

    // Set true only after nativeInit succeeds; cleared by release().
    @Volatile
    private var isLoaded = false

    // Absolute path of the on-disk model file last passed to nativeInit.
    private var modelPath: String? = null

    /**
     * Sampling parameters for text generation.
     *
     * NOTE(review): [repeatPenalty] and [stopTokens] are accepted here but are never
     * forwarded to the native layer — nativeGenerate/nativeGenerateStream do not take
     * them. Extend the JNI signatures if these fields should have an effect.
     */
    data class GenerationConfig(
        val maxTokens: Int = 256,
        val temperature: Float = 0.7f,
        val topP: Float = 0.9f,
        val topK: Int = 40,
        val repeatPenalty: Float = 1.1f,
        val stopTokens: List<String> = listOf("<|endoftext|>", "<|im_end|>")
    )

    /**
     * Copies the bundled model out of assets (first launch only) and initializes
     * the native runtime.
     *
     * @param context used to resolve the asset and the app-private files dir.
     * @param modelName file name under `assets/models/` and in [Context.getFilesDir].
     * @param contextLength native context window size in tokens.
     * @param threads worker thread count; 0 lets the native side choose.
     * @return [Result.success] on success; [Result.failure] wrapping the error otherwise.
     */
    suspend fun load(
        context: Context,
        modelName: String = "mind2-lite.gguf",
        contextLength: Int = 2048,
        threads: Int = 0
    ): Result<Unit> = withContext(Dispatchers.IO) {
        try {
            val assetPath = "models/$modelName"
            val modelFile = File(context.filesDir, modelName)

            // Materialize the asset on disk once: the native loader needs a real
            // file path, not an AssetInputStream.
            if (!modelFile.exists()) {
                context.assets.open(assetPath).use { input ->
                    modelFile.outputStream().use { output ->
                        input.copyTo(output)
                    }
                }
            }

            // Use a local val instead of `modelPath!!` — avoids a race-prone
            // non-null assertion on the mutable property.
            val path = modelFile.absolutePath
            modelPath = path

            if (nativeInit(path, contextLength, threads)) {
                isLoaded = true
                Result.success(Unit)
            } else {
                Result.failure(RuntimeException("Failed to load model"))
            }
        } catch (e: CancellationException) {
            // Never swallow cancellation — rethrow so structured concurrency works.
            throw e
        } catch (e: Exception) {
            Result.failure(e)
        }
    }

    /**
     * Runs one blocking generation pass on [Dispatchers.IO].
     *
     * @return the full completion text, or a failure if the model is not loaded
     *         or the native call throws.
     */
    suspend fun generate(
        prompt: String,
        config: GenerationConfig = GenerationConfig()
    ): Result<String> = withContext(Dispatchers.IO) {
        if (!isLoaded) {
            return@withContext Result.failure(IllegalStateException("Model not loaded"))
        }
        try {
            Result.success(
                nativeGenerate(
                    prompt,
                    config.maxTokens,
                    config.temperature,
                    config.topP,
                    config.topK
                )
            )
        } catch (e: CancellationException) {
            // Never swallow cancellation — rethrow so structured concurrency works.
            throw e
        } catch (e: Exception) {
            Result.failure(e)
        }
    }

    /**
     * Streams generated tokens as a cold [Flow]. The native call is synchronous and
     * runs on [Dispatchers.IO] (via [flowOn]); tokens are emitted from the JNI
     * callback. Cancelling the collector triggers [stop] in `awaitClose`.
     *
     * Throws [IllegalStateException] (as a flow failure) if the model is not loaded.
     */
    fun generateStream(
        prompt: String,
        config: GenerationConfig = GenerationConfig()
    ): Flow<String> = callbackFlow {
        if (!isLoaded) {
            throw IllegalStateException("Model not loaded")
        }

        val callback = object : TokenCallback {
            override fun onToken(token: String) {
                // trySend: the JNI callback thread must not suspend. Tokens could be
                // dropped if the collector falls behind the default channel buffer.
                trySend(token)
            }

            override fun onComplete() {
                channel.close()
            }
        }

        // Blocks this producer until generation finishes or stop() is called.
        nativeGenerateStream(
            prompt,
            config.maxTokens,
            config.temperature,
            config.topP,
            config.topK,
            callback
        )

        awaitClose { stop() }
    }.flowOn(Dispatchers.IO)

    /**
     * One-shot chat turn: renders [history] + [message] into a ChatML prompt and
     * delegates to [generate].
     */
    suspend fun chat(
        message: String,
        history: List<ChatMessage> = emptyList(),
        config: GenerationConfig = GenerationConfig()
    ): Result<String> = generate(buildChatPrompt(message, history), config)

    /**
     * Streaming variant of [chat]; see [generateStream] for threading/cancellation.
     */
    fun chatStream(
        message: String,
        history: List<ChatMessage> = emptyList(),
        config: GenerationConfig = GenerationConfig()
    ): Flow<String> = generateStream(buildChatPrompt(message, history), config)

    /**
     * Renders a ChatML-style prompt: system preamble, prior turns, the new user
     * turn, then an open `assistant` tag for the model to complete.
     */
    private fun buildChatPrompt(message: String, history: List<ChatMessage>): String = buildString {
        append("<|im_start|>system\n")
        append("You are Mind2, a helpful AI assistant running locally on this device.\n")
        append("<|im_end|>\n")

        for (msg in history) {
            append("<|im_start|>${msg.role}\n")
            append("${msg.content}\n")
            append("<|im_end|>\n")
        }

        append("<|im_start|>user\n")
        append("$message\n")
        append("<|im_end|>\n")
        append("<|im_start|>assistant\n")
    }

    /** Requests the native side to abort an in-flight generation. */
    fun stop() {
        nativeStop()
    }

    /** Frees native resources and resets load state; [load] may be called again. */
    fun release() {
        nativeRelease()
        isLoaded = false
        modelPath = null
    }

    /** Returns a native-side description of the loaded model. */
    fun getInfo(): String = nativeGetInfo()

    /**
     * Measures generation throughput over [tokens] tokens on [Dispatchers.IO].
     * The unit of the returned value is defined by the native implementation
     * (presumably tokens/sec — TODO confirm against the JNI source).
     */
    suspend fun benchmark(tokens: Int = 100): Float = withContext(Dispatchers.IO) {
        nativeBenchmark(tokens)
    }

    // --- JNI surface (implemented in libmind2) ---

    private external fun nativeInit(modelPath: String, nCtx: Int, nThreads: Int): Boolean
    private external fun nativeGenerate(
        prompt: String,
        maxTokens: Int,
        temperature: Float,
        topP: Float,
        topK: Int
    ): String
    private external fun nativeGenerateStream(
        prompt: String,
        maxTokens: Int,
        temperature: Float,
        topP: Float,
        topK: Int,
        callback: TokenCallback
    )
    private external fun nativeStop()
    private external fun nativeRelease()
    private external fun nativeGetInfo(): String
    private external fun nativeBenchmark(nTokens: Int): Float

    /** Callback invoked from the JNI layer during streaming generation. */
    interface TokenCallback {
        fun onToken(token: String)
        fun onComplete()
    }

    /** One chat turn; [role] follows ChatML conventions (e.g. "user", "assistant"). */
    data class ChatMessage(
        val role: String,
        val content: String
    )

    companion object {

        init {
            // Load the JNI bridge once per process, before any instance exists.
            System.loadLibrary("mind2")
        }

        // @Volatile is REQUIRED for double-checked locking: without it another
        // thread may observe a partially-published Mind2Model reference.
        @Volatile
        private var instance: Mind2Model? = null

        /** Returns the process-wide singleton, creating it on first call. */
        @JvmStatic
        fun getInstance(): Mind2Model =
            instance ?: synchronized(this) {
                instance ?: Mind2Model().also { instance = it }
            }
    }
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/**
 * Convenience extension: obtains the [Mind2Model] singleton and loads [modelName]
 * from this context's bundled assets, yielding the model itself on success.
 *
 * @return [Result] holding the ready-to-use [Mind2Model], or the load failure.
 */
suspend fun Context.loadMind2Model(
    modelName: String = "mind2-lite.gguf",
    contextLength: Int = 2048
): Result<Mind2Model> =
    Mind2Model.getInstance().let { model ->
        model
            .load(context = this, modelName = modelName, contextLength = contextLength)
            .map { model }
    }
|
|
|