package com.example.open.diffusion; import androidx.appcompat.app.AppCompatActivity; import androidx.appcompat.widget.AppCompatSpinner; import android.app.ProgressDialog; import android.graphics.Bitmap; import android.os.Bundle; import android.text.TextUtils; import android.view.View; import android.widget.EditText; import android.widget.ImageView; import android.widget.TextView; import android.widget.Toast; import com.example.open.diffusion.core.UNet; import com.example.open.diffusion.core.tokenizer.EngTokenizer; import com.example.open.diffusion.core.tokenizer.TextTokenizer; import java.io.File; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import ai.onnxruntime.OnnxTensor; public class MainActivity extends AppCompatActivity { private final ExecutorService exec = Executors.newCachedThreadPool(); private final int[] resolution = {192, 256, 320, 384, 448, 512}; private ImageView mImageView; private TextView mMsgView; private EditText mGuidanceView; private EditText mStepView; private EditText mPromptView; private EditText mNetPromptView; private AppCompatSpinner mWidthSpinner; private AppCompatSpinner mHeightSpinner; private ProgressDialog progressDialog; private EditText mSeedView; private UNet uNet; private TextTokenizer tokenizer; private boolean isWorking = false; @Override protected void onDestroy() { super.onDestroy(); try { uNet.close(); tokenizer.close(); }catch (Exception e){ e.printStackTrace(); } } @Override protected void onCreate(Bundle savedInstanceState) { super.onCreate(savedInstanceState); setContentView(R.layout.activity_main); mImageView = findViewById(R.id.image); mMsgView = findViewById(R.id.msg); mGuidanceView = findViewById(R.id.guidance); mStepView = findViewById(R.id.step); mPromptView = findViewById(R.id.prompt); mWidthSpinner = findViewById(R.id.width); mHeightSpinner = findViewById(R.id.height); mNetPromptView = findViewById(R.id.neg_prompt); mSeedView = findViewById(R.id.seed); mWidthSpinner.setSelection(3); mHeightSpinner.setSelection(3); progressDialog = new ProgressDialog(MainActivity.this); uNet = new UNet(this, Device.CPU); tokenizer = new EngTokenizer(this); uNet.setCallback(new UNet.Callback() { @Override public void onStep(int maxStep, int step) { runOnUiThread(new MyRunnable() { @Override public void run() { mMsgView.setText(String.format("%d / %d", step + 1, maxStep)); } }); } @Override public void onBuildImage(int status, Bitmap bitmap) { runOnUiThread(new MyRunnable() { @Override public void run() { if (bitmap != null) mImageView.setImageBitmap(bitmap); } }); } @Override public void onComplete() { runOnUiThread(new MyRunnable() { @Override public void run() { mMsgView.setText("已完成"); } }); } @Override public void onStop() { } }); findViewById(R.id.copy).setOnClickListener(new View.OnClickListener() { @Override public void onClick(View v) { progressDialog.show(); exec.execute(new MyRunnable() { @Override public void run() { try { FileUtils.copyAssets(getAssets(), "model", new File(PathManager.getAsssetOutputPath(MainActivity.this))); }catch (Exception e){ e.printStackTrace(); }finally { runOnUiThread(new Runnable() { @Override public void run() { progressDialog.dismiss(); } }); } } }); } }); findViewById(R.id.generate).setOnClickListener(new View.OnClickListener() { @Override public void onClick(View v) { try { if (isWorking) return; isWorking = true; mMsgView.setText("初始化. . ."); exec.execute(createRunnable()); }catch (Exception e){ e.printStackTrace(); } } }); } private MyRunnable createRunnable(){ final String guidanceText = mGuidanceView.getText().toString(); final String stepText = mStepView.getText().toString(); final String prompt = mPromptView.getText().toString(); final String negPrompt = mNetPromptView.getText().toString(); final String seedText = mSeedView.getText().toString(); final int num_inference_steps = TextUtils.isEmpty(stepText) ? 8 : Integer.parseInt(stepText); final double guidance_scale = TextUtils.isEmpty(guidanceText) ? 7.5f : Float.valueOf(guidanceText); final long seed = TextUtils.isEmpty(seedText) ? 0 : Long.parseLong(seedText); UNet.WIDTH = resolution[mWidthSpinner.getSelectedItemPosition()]; UNet.HEIGHT = resolution[mHeightSpinner.getSelectedItemPosition()]; return new MyRunnable() { @Override public void run() { try { tokenizer.init(); int batch_size = 1; int[] textTokenized = tokenizer.encoder(prompt); int[] negTokenized = tokenizer.createUncondInput(negPrompt); OnnxTensor textPromptEmbeddings = tokenizer.tensor(textTokenized); OnnxTensor uncondEmbedding = tokenizer.tensor(negTokenized); float[][][] textEmbeddingArray = new float[2][tokenizer.getMaxLength()][768]; float[] textPromptEmbeddingArray = textPromptEmbeddings.getFloatBuffer().array(); float[] uncondEmbeddingArray = uncondEmbedding.getFloatBuffer().array(); for (int i = 0; i < textPromptEmbeddingArray.length; i++) { textEmbeddingArray[0][i / 768][i % 768] = uncondEmbeddingArray[i]; textEmbeddingArray[1][i / 768][i % 768] = textPromptEmbeddingArray[i]; } OnnxTensor textEmbeddings = OnnxTensor.createTensor(App.ENVIRONMENT, textEmbeddingArray); tokenizer.close(); uNet.init(); uNet.inference(seed, num_inference_steps, textEmbeddings, guidance_scale, batch_size, UNet.WIDTH, UNet.HEIGHT); }catch (Exception e){ runOnUiThread(new Runnable() { @Override public void run() { mMsgView.setText("Error"); } }); e.printStackTrace(); }finally { isWorking = false; } } }; } }