|
|
const fs = require('fs'); |
|
|
const path = require('path'); |
|
|
const { |
|
|
withAppBuildGradle, |
|
|
withDangerousMod, |
|
|
withMainApplication, |
|
|
} = require('@expo/config-plugins'); |
|
|
|
|
|
const ORT_DEP = "implementation(\"com.microsoft.onnxruntime:onnxruntime-android:1.18.0\")"; |
|
|
|
|
|
|
|
|
const IDENTITY_ONNX_B64 = |
|
|
"CAcSFXN0YXRpY3BsYXktc2QxLW1vYmlsZTpWChoKAXgSAXkaCElkZW50aXR5IghJZGVudGl0eRIOaWRlbnRpdHlfZ3JhcGhaEwoBeBIOCgwIARIICgIIAQoCCANiEwoBeRIOCgwIARIICgIIAQoCCANCBAoAEAs="; |
|
|
|
|
|
function upsertOnce(text, needle, insertAfterRegex, insertion) { |
|
|
if (text.includes(needle)) return text; |
|
|
const match = text.match(insertAfterRegex); |
|
|
if (!match || match.index == null) return text + '\n' + insertion + '\n'; |
|
|
const idx = match.index + match[0].length; |
|
|
return text.slice(0, idx) + '\n' + insertion + text.slice(idx); |
|
|
} |
|
|
|
|
|
function withOrtNative(config) { |
|
|
config = withAppBuildGradle(config, (cfg) => { |
|
|
if (!cfg.modResults.contents.includes(ORT_DEP)) { |
|
|
cfg.modResults.contents = upsertOnce( |
|
|
cfg.modResults.contents, |
|
|
ORT_DEP, |
|
|
/dependencies\s*\{\s*/m, |
|
|
` ${ORT_DEP}` |
|
|
); |
|
|
} |
|
|
return cfg; |
|
|
}); |
|
|
|
|
|
config = withMainApplication(config, (cfg) => { |
|
|
let contents = cfg.modResults.contents; |
|
|
|
|
|
if (!contents.includes('import com.anonymous.staticplaysd1mobile.ort.OrtPackage')) { |
|
|
contents = upsertOnce( |
|
|
contents, |
|
|
'import com.anonymous.staticplaysd1mobile.ort.OrtPackage', |
|
|
/^package\s+.+\n/m, |
|
|
'import com.anonymous.staticplaysd1mobile.ort.OrtPackage' |
|
|
); |
|
|
} |
|
|
|
|
|
|
|
|
if (!contents.includes('add(OrtPackage())')) { |
|
|
contents = contents.replace( |
|
|
/\/\/ add\(MyReactNativePackage\(\)\)\s*\n/, |
|
|
`// add(MyReactNativePackage())\n add(OrtPackage())\n` |
|
|
); |
|
|
} |
|
|
|
|
|
cfg.modResults.contents = contents; |
|
|
return cfg; |
|
|
}); |
|
|
|
|
|
config = withDangerousMod(config, [ |
|
|
'android', |
|
|
async (cfg) => { |
|
|
const projectRoot = cfg.modRequest.projectRoot; |
|
|
const androidDir = path.join(projectRoot, 'android'); |
|
|
|
|
|
|
|
|
const pkgDir = path.join( |
|
|
androidDir, |
|
|
'app', |
|
|
'src', |
|
|
'main', |
|
|
'java', |
|
|
'com', |
|
|
'anonymous', |
|
|
'staticplaysd1mobile', |
|
|
'ort' |
|
|
); |
|
|
fs.mkdirSync(pkgDir, { recursive: true }); |
|
|
|
|
|
const modulePath = path.join(pkgDir, 'OrtModule.kt'); |
|
|
const packagePath = path.join(pkgDir, 'OrtPackage.kt'); |
|
|
|
|
|
fs.writeFileSync( |
|
|
modulePath, |
|
|
`package com.anonymous.staticplaysd1mobile.ort |
|
|
|
|
|
import ai.onnxruntime.OnnxTensor |
|
|
import ai.onnxruntime.OrtEnvironment |
|
|
import ai.onnxruntime.OrtSession |
|
|
import ai.onnxruntime.OrtSession.SessionOptions |
|
|
import ai.onnxruntime.TensorInfo |
|
|
import com.facebook.react.bridge.Arguments |
|
|
import com.facebook.react.bridge.Promise |
|
|
import com.facebook.react.bridge.ReadableArray |
|
|
import com.facebook.react.bridge.ReadableMap |
|
|
import com.facebook.react.bridge.ReactApplicationContext |
|
|
import com.facebook.react.bridge.ReactContextBaseJavaModule |
|
|
import com.facebook.react.bridge.ReactMethod |
|
|
import android.net.Uri |
|
|
import android.util.Base64 |
|
|
import android.graphics.Bitmap |
|
|
import java.io.File |
|
|
import java.io.FileOutputStream |
|
|
import java.nio.LongBuffer |
|
|
import java.nio.FloatBuffer |
|
|
import java.security.MessageDigest |
|
|
import java.util.Random |
|
|
|
|
|
class OrtModule(reactContext: ReactApplicationContext) : ReactContextBaseJavaModule(reactContext) { |
|
|
override fun getName(): String = "OrtNative" |
|
|
|
|
|
private var env: OrtEnvironment? = null |
|
|
private var sessionNnapi: OrtSession? = null |
|
|
private var sessionCpu: OrtSession? = null |
|
|
private var packRoot: String? = null |
|
|
private var textCpu: OrtSession? = null |
|
|
private var vaeCpu: OrtSession? = null |
|
|
private var unetNnapi: OrtSession? = null |
|
|
private var unetCpu: OrtSession? = null |
|
|
|
|
|
private val identityModelB64: String = |
|
|
"${IDENTITY_ONNX_B64}" |
|
|
|
|
|
private fun flattenToFloatArray(value: Any?): FloatArray { |
|
|
if (value == null) return FloatArray(0) |
|
|
return when (value) { |
|
|
is FloatArray -> value |
|
|
is FloatBuffer -> { |
|
|
val buf = value.duplicate() |
|
|
val out = FloatArray(buf.remaining()) |
|
|
buf.get(out) |
|
|
out |
|
|
} |
|
|
is Array<*> -> { |
|
|
val out = ArrayList<Float>() |
|
|
fun addAny(v: Any?) { |
|
|
when (v) { |
|
|
null -> {} |
|
|
is Number -> out.add(v.toFloat()) |
|
|
is FloatArray -> for (x in v) out.add(x) |
|
|
is FloatBuffer -> { |
|
|
val b = v.duplicate() |
|
|
while (b.hasRemaining()) out.add(b.get()) |
|
|
} |
|
|
is Array<*> -> for (e in v) addAny(e) |
|
|
else -> throw IllegalStateException("Unsupported output element type: \${v.javaClass.name}") |
|
|
} |
|
|
} |
|
|
for (e in value) addAny(e) |
|
|
out.toFloatArray() |
|
|
} |
|
|
else -> throw IllegalStateException("Unsupported output type: \${value.javaClass.name}") |
|
|
} |
|
|
} |
|
|
|
|
|
private fun ensureModelFile(): File { |
|
|
val outFile = File(reactApplicationContext.cacheDir, "identity.onnx") |
|
|
// Always overwrite to avoid stale/cached models after app updates. |
|
|
if (outFile.exists()) outFile.delete() |
|
|
val bytes = Base64.decode(identityModelB64, Base64.DEFAULT) |
|
|
FileOutputStream(outFile).use { it.write(bytes) } |
|
|
return outFile |
|
|
} |
|
|
|
|
|
private fun sha256Hex(file: File): String { |
|
|
val md = MessageDigest.getInstance("SHA-256") |
|
|
file.inputStream().use { input -> |
|
|
val buf = ByteArray(1024 * 64) |
|
|
while (true) { |
|
|
val r = input.read(buf) |
|
|
if (r <= 0) break |
|
|
md.update(buf, 0, r) |
|
|
} |
|
|
} |
|
|
return md.digest().joinToString("") { "%02x".format(it) } |
|
|
} |
|
|
|
|
|
private fun uriToPath(uri: String): String { |
|
|
return if (uri.startsWith("file://")) uri.removePrefix("file://") else uri |
|
|
} |
|
|
|
|
|
private fun ensureDir(dir: File) { |
|
|
if (!dir.exists()) { |
|
|
dir.mkdirs() |
|
|
} |
|
|
} |
|
|
|
|
|
private fun isSubPath(base: File, child: File): Boolean { |
|
|
val basePath = base.canonicalFile.toPath() |
|
|
val childPath = child.canonicalFile.toPath() |
|
|
return childPath.startsWith(basePath) |
|
|
} |
|
|
|
|
|
private fun getEnv(): OrtEnvironment { |
|
|
val existing = env |
|
|
if (existing != null) return existing |
|
|
val created = OrtEnvironment.getEnvironment() |
|
|
env = created |
|
|
return created |
|
|
} |
|
|
|
|
|
private fun getSession(preferNnapi: Boolean): Pair<OrtSession, String> { |
|
|
val modelFile = ensureModelFile() |
|
|
val env = getEnv() |
|
|
|
|
|
if (preferNnapi) { |
|
|
val existing = sessionNnapi |
|
|
if (existing != null) return existing to "nnapi" |
|
|
try { |
|
|
val opts = SessionOptions() |
|
|
try { |
|
|
opts.addNnapi() |
|
|
} catch (_: Throwable) { |
|
|
// If this ORT build doesn't ship NNAPI, session creation will fall back below. |
|
|
} |
|
|
val created = env.createSession(modelFile.absolutePath, opts) |
|
|
sessionNnapi = created |
|
|
return created to "nnapi" |
|
|
} catch (_e: Throwable) { |
|
|
// Fall back to CPU |
|
|
} |
|
|
} |
|
|
|
|
|
val existingCpu = sessionCpu |
|
|
if (existingCpu != null) return existingCpu to "cpu" |
|
|
try { |
|
|
val cpu = env.createSession(modelFile.absolutePath, SessionOptions()) |
|
|
sessionCpu = cpu |
|
|
return cpu to "cpu" |
|
|
} catch (e: Throwable) { |
|
|
val size = try { modelFile.length() } catch (_: Throwable) { -1L } |
|
|
val sha = try { sha256Hex(modelFile).take(16) } catch (_: Throwable) { "?" } |
|
|
throw IllegalStateException("Failed to load identity.onnx (size=$size sha256=$sha) at \${modelFile.absolutePath}: \${e.message}", e) |
|
|
} |
|
|
} |
|
|
|
|
|
private fun resetPackIfNeeded(packDir: String) { |
|
|
if (packRoot == packDir) return |
|
|
packRoot = packDir |
|
|
try { textCpu?.close() } catch (_: Throwable) {} |
|
|
try { vaeCpu?.close() } catch (_: Throwable) {} |
|
|
try { unetNnapi?.close() } catch (_: Throwable) {} |
|
|
try { unetCpu?.close() } catch (_: Throwable) {} |
|
|
textCpu = null |
|
|
vaeCpu = null |
|
|
unetNnapi = null |
|
|
unetCpu = null |
|
|
} |
|
|
|
|
|
private fun getTextEncoder(packDir: String): OrtSession { |
|
|
resetPackIfNeeded(packDir) |
|
|
val existing = textCpu |
|
|
if (existing != null) return existing |
|
|
val env = getEnv() |
|
|
val path = File(packDir, "text_encoder.onnx").absolutePath |
|
|
val s = env.createSession(path, SessionOptions()) |
|
|
textCpu = s |
|
|
return s |
|
|
} |
|
|
|
|
|
private fun getVaeDecoder(packDir: String): OrtSession { |
|
|
resetPackIfNeeded(packDir) |
|
|
val existing = vaeCpu |
|
|
if (existing != null) return existing |
|
|
val env = getEnv() |
|
|
val path = File(packDir, "vae_decoder.onnx").absolutePath |
|
|
val s = env.createSession(path, SessionOptions()) |
|
|
vaeCpu = s |
|
|
return s |
|
|
} |
|
|
|
|
|
private fun getUnet(packDir: String, preferNnapi: Boolean): Pair<OrtSession, String> { |
|
|
resetPackIfNeeded(packDir) |
|
|
val env = getEnv() |
|
|
val modelPath = File(packDir, "unet.onnx").absolutePath |
|
|
|
|
|
val existingNnapi = unetNnapi |
|
|
val existingCpu = unetCpu |
|
|
if (!preferNnapi) { |
|
|
if (existingCpu != null) return existingCpu to "cpu" |
|
|
} else { |
|
|
if (existingNnapi != null) return existingNnapi to "nnapi" |
|
|
if (existingCpu != null) return existingCpu to "cpu" |
|
|
} |
|
|
|
|
|
if (!preferNnapi) { |
|
|
val s = env.createSession(modelPath, SessionOptions()) |
|
|
unetCpu = s |
|
|
return s to "cpu" |
|
|
} |
|
|
|
|
|
try { |
|
|
val opts = SessionOptions() |
|
|
try { opts.addNnapi() } catch (_: Throwable) {} |
|
|
val s = env.createSession(modelPath, opts) |
|
|
unetNnapi = s |
|
|
return s to "nnapi" |
|
|
} catch (_: Throwable) { |
|
|
val s = env.createSession(modelPath, SessionOptions()) |
|
|
unetCpu = s |
|
|
return s to "cpu" |
|
|
} |
|
|
} |
|
|
|
|
|
private fun readInt64Ids(ids: ReadableArray, expected: Int): LongArray { |
|
|
if (ids.size() != expected) throw IllegalArgumentException("Expected $expected ids, got \${ids.size()}") |
|
|
return LongArray(expected) { i -> ids.getDouble(i).toLong() } |
|
|
} |
|
|
|
|
|
private fun gaussianLatents(rng: Random, count: Int): FloatArray { |
|
|
val out = FloatArray(count) |
|
|
var i = 0 |
|
|
while (i < count) { |
|
|
// Box-Muller via nextGaussian |
|
|
out[i] = rng.nextGaussian().toFloat() |
|
|
i += 1 |
|
|
} |
|
|
return out |
|
|
} |
|
|
|
|
|
private fun alphasCumprod1000(): DoubleArray { |
|
|
val betas = DoubleArray(1000) |
|
|
val betaStart = 0.00085 |
|
|
val betaEnd = 0.012 |
|
|
val s0 = kotlin.math.sqrt(betaStart) |
|
|
val s1 = kotlin.math.sqrt(betaEnd) |
|
|
for (i in 0 until 1000) { |
|
|
val t = i.toDouble() / 999.0 |
|
|
val s = s0 + (s1 - s0) * t |
|
|
betas[i] = s * s |
|
|
} |
|
|
val alphasCum = DoubleArray(1000) |
|
|
var prod = 1.0 |
|
|
for (i in 0 until 1000) { |
|
|
val alpha = 1.0 - betas[i] |
|
|
prod *= alpha |
|
|
alphasCum[i] = prod |
|
|
} |
|
|
return alphasCum |
|
|
} |
|
|
|
|
|
private fun ddimTimesteps(steps: Int): IntArray { |
|
|
val stride = 1000 / steps |
|
|
return IntArray(steps) { i -> 999 - i * stride } |
|
|
} |
|
|
|
|
|
private fun floatArraySub(a: FloatArray, b: FloatArray, out: FloatArray) { |
|
|
for (i in out.indices) out[i] = a[i] - b[i] |
|
|
} |
|
|
private fun floatArrayAddScaled(a: FloatArray, b: FloatArray, scale: Float, out: FloatArray) { |
|
|
for (i in out.indices) out[i] = a[i] + (b[i] * scale) |
|
|
} |
|
|
private fun floatArrayMulScalar(a: FloatArray, s: Float, out: FloatArray) { |
|
|
for (i in out.indices) out[i] = a[i] * s |
|
|
} |
|
|
|
|
|
private fun decodeToPng(packDir: String, latents: FloatArray, outPath: String) { |
|
|
val vae = getVaeDecoder(packDir) |
|
|
val env = getEnv() |
|
|
val scale = 1.0f / 0.18215f |
|
|
val scaled = FloatArray(latents.size) |
|
|
floatArrayMulScalar(latents, scale, scaled) |
|
|
|
|
|
val latentTensor = OnnxTensor.createTensor(env, FloatBuffer.wrap(scaled), longArrayOf(1, 4, 64, 64)) |
|
|
latentTensor.use { lt -> |
|
|
val results = vae.run(mapOf("latent_sample" to lt)) |
|
|
results.use { |
|
|
val out = flattenToFloatArray(results[0].value) |
|
|
val w = 512 |
|
|
val h = 512 |
|
|
val hw = w * h |
|
|
val pixels = IntArray(hw) |
|
|
for (i in 0 until hw) { |
|
|
val r = out[i] |
|
|
val g = out[hw + i] |
|
|
val b = out[2 * hw + i] |
|
|
fun toU8(x: Float): Int { |
|
|
val v = ((x / 2.0f) + 0.5f) * 255.0f |
|
|
val c = if (v < 0f) 0 else if (v > 255f) 255 else v.toInt() |
|
|
return c |
|
|
} |
|
|
val rr = toU8(r) |
|
|
val gg = toU8(g) |
|
|
val bb = toU8(b) |
|
|
pixels[i] = (0xFF shl 24) or (rr shl 16) or (gg shl 8) or bb |
|
|
} |
|
|
val bmp = Bitmap.createBitmap(w, h, Bitmap.Config.ARGB_8888) |
|
|
bmp.setPixels(pixels, 0, w, 0, 0, w, h) |
|
|
|
|
|
val outFile = File(uriToPath(outPath)) |
|
|
outFile.parentFile?.let { ensureDir(it) } |
|
|
FileOutputStream(outFile).use { fos -> |
|
|
bmp.compress(Bitmap.CompressFormat.PNG, 100, fos) |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
@ReactMethod |
|
|
fun generateSd15FromIds(options: ReadableMap, promise: Promise) { |
|
|
try { |
|
|
val packDir = options.getString("packDir") ?: throw IllegalArgumentException("packDir is required") |
|
|
val condIdsArr = options.getArray("condIds") ?: throw IllegalArgumentException("condIds is required") |
|
|
val uncondIdsArr = options.getArray("uncondIds") ?: throw IllegalArgumentException("uncondIds is required") |
|
|
val steps = if (options.hasKey("steps")) options.getInt("steps") else 20 |
|
|
val guidance = if (options.hasKey("guidance")) options.getDouble("guidance").toFloat() else 7.5f |
|
|
val seed = if (options.hasKey("seed")) options.getDouble("seed").toLong() else -1L |
|
|
val outPath = options.getString("outPath") ?: throw IllegalArgumentException("outPath is required") |
|
|
|
|
|
val condIds = readInt64Ids(condIdsArr, 77) |
|
|
val uncondIds = readInt64Ids(uncondIdsArr, 77) |
|
|
|
|
|
val rng = if (seed == -1L) Random() else Random(seed) |
|
|
|
|
|
val preferNnapi = if (options.hasKey("preferNnapi")) options.getBoolean("preferNnapi") else false |
|
|
|
|
|
val text = getTextEncoder(packDir) |
|
|
val (unet, unetProvider) = getUnet(packDir, preferNnapi) |
|
|
val env = getEnv() |
|
|
|
|
|
fun encode(ids: LongArray): FloatArray { |
|
|
val t = OnnxTensor.createTensor(env, LongBuffer.wrap(ids), longArrayOf(1, 77)) |
|
|
t.use { it2 -> |
|
|
val r = text.run(mapOf("input_ids" to it2)) |
|
|
r.use { |
|
|
return flattenToFloatArray(r[0].value) |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
val condEmb = encode(condIds) // [1,77,768] |
|
|
val uncondEmb = encode(uncondIds) |
|
|
|
|
|
val latentCount = 4 * 64 * 64 |
|
|
var latents = gaussianLatents(rng, latentCount) |
|
|
|
|
|
val alphasCum = alphasCumprod1000() |
|
|
val timesteps = ddimTimesteps(steps) |
|
|
|
|
|
val tmp = FloatArray(latentCount) |
|
|
val tmp2 = FloatArray(latentCount) |
|
|
|
|
|
fun unetEps(emb: FloatArray, t: Long, sample: FloatArray): FloatArray { |
|
|
val sampleTensor = OnnxTensor.createTensor(env, FloatBuffer.wrap(sample), longArrayOf(1, 4, 64, 64)) |
|
|
val tTensor = OnnxTensor.createTensor(env, LongBuffer.wrap(longArrayOf(t)), longArrayOf()) |
|
|
val embTensor = OnnxTensor.createTensor(env, FloatBuffer.wrap(emb), longArrayOf(1, 77, 768)) |
|
|
sampleTensor.use { st -> |
|
|
tTensor.use { tt -> |
|
|
embTensor.use { et -> |
|
|
val r = unet.run(mapOf("sample" to st, "timestep" to tt, "encoder_hidden_states" to et)) |
|
|
r.use { rr -> |
|
|
return flattenToFloatArray(rr[0].value) |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
for (i in 0 until timesteps.size) { |
|
|
val t = timesteps[i] |
|
|
val tPrev = if (i == timesteps.size - 1) -1 else timesteps[i + 1] |
|
|
|
|
|
val epsUncond = unetEps(uncondEmb, t.toLong(), latents) |
|
|
val epsText = unetEps(condEmb, t.toLong(), latents) |
|
|
// eps = epsUncond + guidance * (epsText - epsUncond) |
|
|
floatArraySub(epsText, epsUncond, tmp) |
|
|
floatArrayAddScaled(epsUncond, tmp, guidance, tmp2) |
|
|
|
|
|
val alphaT = alphasCum[t] |
|
|
val alphaPrev = if (tPrev == -1) 1.0 else alphasCum[tPrev] |
|
|
val sqrtAlphaT = kotlin.math.sqrt(alphaT).toFloat() |
|
|
val sqrtOneMinusAlphaT = kotlin.math.sqrt(1.0 - alphaT).toFloat() |
|
|
val sqrtAlphaPrev = kotlin.math.sqrt(alphaPrev).toFloat() |
|
|
val sqrtOneMinusAlphaPrev = kotlin.math.sqrt(1.0 - alphaPrev).toFloat() |
|
|
|
|
|
// pred_x0 = (x_t - sqrt(1-a_t)*eps)/sqrt(a_t) |
|
|
for (k in 0 until latentCount) { |
|
|
val predX0 = (latents[k] - sqrtOneMinusAlphaT * tmp2[k]) / sqrtAlphaT |
|
|
val dir = sqrtOneMinusAlphaPrev * tmp2[k] |
|
|
latents[k] = sqrtAlphaPrev * predX0 + dir |
|
|
} |
|
|
} |
|
|
|
|
|
decodeToPng(packDir, latents, outPath) |
|
|
|
|
|
val out = Arguments.createMap() |
|
|
out.putString("provider", unetProvider) |
|
|
out.putString("path", uriToPath(outPath)) |
|
|
out.putInt("steps", steps) |
|
|
promise.resolve(out) |
|
|
} catch (e: Throwable) { |
|
|
promise.reject("ORT_NATIVE_ERROR", e.message, e) |
|
|
} |
|
|
} |
|
|
|
|
|
@ReactMethod |
|
|
fun runIdentity(input: ReadableArray, promise: Promise) { |
|
|
try { |
|
|
val floats = FloatArray(input.size()) { i -> (input.getDouble(i)).toFloat() } |
|
|
val (session, provider) = getSession(true) |
|
|
|
|
|
val tensor = OnnxTensor.createTensor(getEnv(), FloatBuffer.wrap(floats), longArrayOf(1, floats.size.toLong())) |
|
|
tensor.use { |
|
|
val results = session.run(mapOf("x" to it)) |
|
|
results.use { |
|
|
val out = flattenToFloatArray(results[0].value) |
|
|
val obj = Arguments.createMap() |
|
|
obj.putString("provider", provider) |
|
|
val arr = Arguments.createArray() |
|
|
for (v in out) arr.pushDouble(v.toDouble()) |
|
|
obj.putArray("output", arr) |
|
|
promise.resolve(obj) |
|
|
} |
|
|
} |
|
|
} catch (e: Throwable) { |
|
|
promise.reject("ORT_NATIVE_ERROR", e.message, e) |
|
|
} |
|
|
} |
|
|
|
|
|
@ReactMethod |
|
|
fun getExternalFilesDirPath(promise: Promise) { |
|
|
try { |
|
|
val dir = reactApplicationContext.getExternalFilesDir(null) |
|
|
if (dir == null) { |
|
|
promise.reject("ORT_NATIVE_ERROR", "External files dir is not available") |
|
|
return |
|
|
} |
|
|
promise.resolve(dir.absolutePath) |
|
|
} catch (e: Throwable) { |
|
|
promise.reject("ORT_NATIVE_ERROR", e.message, e) |
|
|
} |
|
|
} |
|
|
|
|
|
@ReactMethod |
|
|
fun statPath(fileUri: String, promise: Promise) { |
|
|
try { |
|
|
val res = Arguments.createMap() |
|
|
if (fileUri.startsWith("content://")) { |
|
|
// Best-effort: scoped storage URI (size is optional and may be unknown). |
|
|
val uri = Uri.parse(fileUri) |
|
|
val stream = reactApplicationContext.contentResolver.openInputStream(uri) |
|
|
stream?.close() |
|
|
res.putBoolean("exists", stream != null) |
|
|
res.putDouble("size", -1.0) |
|
|
promise.resolve(res) |
|
|
return |
|
|
} |
|
|
|
|
|
val path = uriToPath(fileUri) |
|
|
val f = File(path) |
|
|
res.putBoolean("exists", f.exists()) |
|
|
res.putDouble("size", if (f.exists()) f.length().toDouble() else 0.0) |
|
|
res.putString("path", f.absolutePath) |
|
|
promise.resolve(res) |
|
|
} catch (e: Throwable) { |
|
|
promise.reject("ORT_NATIVE_ERROR", e.message, e) |
|
|
} |
|
|
} |
|
|
|
|
|
@ReactMethod |
|
|
fun readTextFile(fileUri: String, promise: Promise) { |
|
|
try { |
|
|
val res = if (fileUri.startsWith("content://")) { |
|
|
val uri = Uri.parse(fileUri) |
|
|
reactApplicationContext.contentResolver.openInputStream(uri) |
|
|
?: throw IllegalStateException("Unable to open content URI: $fileUri") |
|
|
} else { |
|
|
val path = uriToPath(fileUri) |
|
|
java.io.FileInputStream(path) |
|
|
} |
|
|
res.use { input -> |
|
|
val text = input.bufferedReader(Charsets.UTF_8).readText() |
|
|
promise.resolve(text) |
|
|
} |
|
|
} catch (e: Throwable) { |
|
|
promise.reject("ORT_NATIVE_ERROR", e.message, e) |
|
|
} |
|
|
} |
|
|
|
|
|
private fun typeInfoToMap(typeInfo: Any?): com.facebook.react.bridge.WritableMap { |
|
|
val m = Arguments.createMap() |
|
|
if (typeInfo == null) return m |
|
|
if (typeInfo is TensorInfo) { |
|
|
m.putString("type", typeInfo.type.toString()) |
|
|
val shapeArr = Arguments.createArray() |
|
|
for (d in typeInfo.shape) shapeArr.pushDouble(d.toDouble()) |
|
|
m.putArray("shape", shapeArr) |
|
|
} else { |
|
|
m.putString("type", (typeInfo as Any).javaClass.name) |
|
|
} |
|
|
return m |
|
|
} |
|
|
|
|
|
@ReactMethod |
|
|
fun inspectModel(fileUri: String, promise: Promise) { |
|
|
try { |
|
|
val path = uriToPath(fileUri) |
|
|
val env = getEnv() |
|
|
val (session, provider) = try { |
|
|
val opts = SessionOptions() |
|
|
try { opts.addNnapi() } catch (_: Throwable) {} |
|
|
env.createSession(path, opts) to "nnapi" |
|
|
} catch (_: Throwable) { |
|
|
env.createSession(path, SessionOptions()) to "cpu" |
|
|
} |
|
|
|
|
|
session.use { |
|
|
val out = Arguments.createMap() |
|
|
out.putString("provider", provider) |
|
|
|
|
|
val inputsArr = Arguments.createArray() |
|
|
for ((name, nodeInfo) in session.inputInfo) { |
|
|
val mi = Arguments.createMap() |
|
|
mi.putString("name", name) |
|
|
mi.putMap("info", typeInfoToMap(nodeInfo.info)) |
|
|
inputsArr.pushMap(mi) |
|
|
} |
|
|
out.putArray("inputs", inputsArr) |
|
|
|
|
|
val outputsArr = Arguments.createArray() |
|
|
for ((name, nodeInfo) in session.outputInfo) { |
|
|
val mo = Arguments.createMap() |
|
|
mo.putString("name", name) |
|
|
mo.putMap("info", typeInfoToMap(nodeInfo.info)) |
|
|
outputsArr.pushMap(mo) |
|
|
} |
|
|
out.putArray("outputs", outputsArr) |
|
|
|
|
|
promise.resolve(out) |
|
|
} |
|
|
} catch (e: Throwable) { |
|
|
promise.reject("ORT_NATIVE_ERROR", e.message, e) |
|
|
} |
|
|
} |
|
|
|
|
|
@ReactMethod |
|
|
fun inspectModelCpu(fileUri: String, promise: Promise) { |
|
|
try { |
|
|
val path = uriToPath(fileUri) |
|
|
val env = getEnv() |
|
|
val session = env.createSession(path, SessionOptions()) |
|
|
session.use { |
|
|
val out = Arguments.createMap() |
|
|
out.putString("provider", "cpu") |
|
|
|
|
|
val inputsArr = Arguments.createArray() |
|
|
for ((name, nodeInfo) in session.inputInfo) { |
|
|
val mi = Arguments.createMap() |
|
|
mi.putString("name", name) |
|
|
mi.putMap("info", typeInfoToMap(nodeInfo.info)) |
|
|
inputsArr.pushMap(mi) |
|
|
} |
|
|
out.putArray("inputs", inputsArr) |
|
|
|
|
|
val outputsArr = Arguments.createArray() |
|
|
for ((name, nodeInfo) in session.outputInfo) { |
|
|
val mo = Arguments.createMap() |
|
|
mo.putString("name", name) |
|
|
mo.putMap("info", typeInfoToMap(nodeInfo.info)) |
|
|
outputsArr.pushMap(mo) |
|
|
} |
|
|
out.putArray("outputs", outputsArr) |
|
|
|
|
|
promise.resolve(out) |
|
|
} |
|
|
} catch (e: Throwable) { |
|
|
promise.reject("ORT_NATIVE_ERROR", e.message, e) |
|
|
} |
|
|
} |
|
|
|
|
|
@ReactMethod |
|
|
fun unzip(zipFileUri: String, destDirUri: String, promise: Promise) { |
|
|
try { |
|
|
val destDir = File(uriToPath(destDirUri)) |
|
|
ensureDir(destDir) |
|
|
|
|
|
val inputStream = if (zipFileUri.startsWith("content://")) { |
|
|
val uri = Uri.parse(zipFileUri) |
|
|
reactApplicationContext.contentResolver.openInputStream(uri) |
|
|
?: throw IllegalStateException("Unable to open content URI: $zipFileUri") |
|
|
} else { |
|
|
val zipPath = uriToPath(zipFileUri) |
|
|
java.io.FileInputStream(zipPath) |
|
|
} |
|
|
|
|
|
val zis = java.util.zip.ZipInputStream(java.io.BufferedInputStream(inputStream)) |
|
|
zis.use { stream -> |
|
|
var entry = stream.nextEntry |
|
|
var count = 0 |
|
|
val buf = ByteArray(1024 * 64) |
|
|
while (entry != null) { |
|
|
val outFile = File(destDir, entry.name) |
|
|
if (!isSubPath(destDir, outFile)) { |
|
|
throw IllegalStateException("Zip entry escapes destination: \${entry.name}") |
|
|
} |
|
|
if (entry.isDirectory) { |
|
|
ensureDir(outFile) |
|
|
} else { |
|
|
outFile.parentFile?.let { ensureDir(it) } |
|
|
java.io.FileOutputStream(outFile).use { out -> |
|
|
while (true) { |
|
|
val read = stream.read(buf) |
|
|
if (read <= 0) break |
|
|
out.write(buf, 0, read) |
|
|
} |
|
|
} |
|
|
count += 1 |
|
|
} |
|
|
stream.closeEntry() |
|
|
entry = stream.nextEntry |
|
|
} |
|
|
val res = Arguments.createMap() |
|
|
res.putString("destPath", destDir.absolutePath) |
|
|
res.putInt("files", count) |
|
|
promise.resolve(res) |
|
|
} |
|
|
} catch (e: Throwable) { |
|
|
promise.reject("ORT_NATIVE_ERROR", e.message, e) |
|
|
} |
|
|
} |
|
|
} |
|
|
` |
|
|
); |
|
|
|
|
|
fs.writeFileSync( |
|
|
packagePath, |
|
|
`package com.anonymous.staticplaysd1mobile.ort |
|
|
|
|
|
import com.facebook.react.ReactPackage |
|
|
import com.facebook.react.bridge.NativeModule |
|
|
import com.facebook.react.bridge.ReactApplicationContext |
|
|
import com.facebook.react.uimanager.ViewManager |
|
|
|
|
|
class OrtPackage : ReactPackage { |
|
|
override fun createNativeModules(reactContext: ReactApplicationContext): List<NativeModule> { |
|
|
return listOf(OrtModule(reactContext)) |
|
|
} |
|
|
|
|
|
override fun createViewManagers(reactContext: ReactApplicationContext): List<ViewManager<*, *>> { |
|
|
return emptyList() |
|
|
} |
|
|
} |
|
|
` |
|
|
); |
|
|
|
|
|
|
|
|
const assetDstDir = path.join(androidDir, 'app', 'src', 'main', 'assets', 'models'); |
|
|
const assetDst = path.join(assetDstDir, 'identity.onnx'); |
|
|
fs.mkdirSync(assetDstDir, { recursive: true }); |
|
|
const bytes = Buffer.from(IDENTITY_ONNX_B64, 'base64'); |
|
|
fs.writeFileSync(assetDst, bytes); |
|
|
|
|
|
return cfg; |
|
|
}, |
|
|
]); |
|
|
|
|
|
return config; |
|
|
} |
|
|
|
|
|
module.exports = withOrtNative; |
|
|
|