danielritchie commited on
Commit
583fc37
·
1 Parent(s): 5b3e7f7

tiny model, tiny runtime

Browse files
Files changed (2) hide show
  1. requirements.txt +1 -3
  2. utils/color_model.py +22 -16
requirements.txt CHANGED
@@ -1,7 +1,5 @@
1
  gradio
2
- torch
3
- transformers
4
  matplotlib
5
  numpy
 
6
  huggingface_hub
7
-
 
1
  gradio
 
 
2
  matplotlib
3
  numpy
4
+ tflite-runtime
5
  huggingface_hub
 
utils/color_model.py CHANGED
@@ -1,37 +1,43 @@
1
- import torch
2
  from huggingface_hub import hf_hub_download
 
3
 
4
- # Download model file from HF
 
5
  model_path = hf_hub_download(
6
  repo_id="danielritchie/vibe-color-model",
7
- filename="pytorch_model.bin"
8
  )
9
 
10
- # Load full torch model
11
- model = torch.load(model_path, map_location=torch.device("cpu"))
12
- model.eval()
 
 
 
13
 
14
 
15
  def infer_color(vad):
16
- input_tensor = torch.tensor([[
17
  vad["V"],
18
  vad["A"],
19
  vad["D"],
20
  vad["Cx"],
21
  vad["Co"]
22
- ]], dtype=torch.float32)
23
 
24
- with torch.no_grad():
25
- output = model(input_tensor)
 
26
 
27
- r, g, b, e, i = output[0].tolist()
28
 
29
  return {
30
- "R": r,
31
- "G": g,
32
- "B": b,
33
- "E": e,
34
- "I": i
35
  }
36
 
37
 
 
1
+ import numpy as np
2
  from huggingface_hub import hf_hub_download
3
+ import tflite_runtime.interpreter as tflite
4
 
5
+
6
+ # Download TFLite model file
7
  model_path = hf_hub_download(
8
  repo_id="danielritchie/vibe-color-model",
9
+ filename="vibe_model.tflite"
10
  )
11
 
12
+ # Load interpreter
13
+ interpreter = tflite.Interpreter(model_path=model_path)
14
+ interpreter.allocate_tensors()
15
+
16
+ input_details = interpreter.get_input_details()
17
+ output_details = interpreter.get_output_details()
18
 
19
 
20
  def infer_color(vad):
21
+ input_data = np.array([[
22
  vad["V"],
23
  vad["A"],
24
  vad["D"],
25
  vad["Cx"],
26
  vad["Co"]
27
+ ]], dtype=np.float32)
28
 
29
+ interpreter.set_tensor(input_details[0]["index"], input_data)
30
+ interpreter.invoke()
31
+ output_data = interpreter.get_tensor(output_details[0]["index"])
32
 
33
+ r, g, b, e, i = output_data[0]
34
 
35
  return {
36
+ "R": float(r),
37
+ "G": float(g),
38
+ "B": float(b),
39
+ "E": float(e),
40
+ "I": float(i)
41
  }
42
 
43