Commit 6f9ac18 · verified · 1 Parent(s): 0456c8a
navid72m committed

Update selection.py

Files changed (1)
  1. selection.py +19 -5
selection.py CHANGED
@@ -43,15 +43,28 @@ instruction = most_relevant_context[:300] + " " + Question
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_2b_en")
gemma_lm.save('saved_model/gemma_2b_en')

- # Convert the saved model to TensorFlow Lite format with quantization
+ # Convert the saved model to TensorFlow Lite format with 8-bit full integer quantization
saved_model_dir = 'saved_model/gemma_2b_en'
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
- tflite_model = converter.convert()
+
+ # Representative dataset function for quantization
+ def representative_dataset_gen():
+     for _ in range(100):
+         # Example input array, replace with your actual data
+         data = np.random.rand(1, 300).astype(np.float32)
+         yield [data]
+
+ converter.representative_dataset = representative_dataset_gen
+ converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
+ converter.inference_input_type = tf.uint8
+ converter.inference_output_type = tf.uint8
+
+ tflite_model_quant = converter.convert()

# Save the quantized model
with open('gemma_2b_en_quantized.tflite', 'wb') as f:
-     f.write(tflite_model)
+     f.write(tflite_model_quant)

# Load the quantized model and run inference
interpreter = tf.lite.Interpreter(model_path='gemma_2b_en_quantized.tflite')

@@ -61,8 +74,9 @@ input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

def preprocess_input(instruction):
-     # Convert the input to the required format and shape
-     input_data = np.array([instruction], dtype=np.float32)
+     # Tokenization and padding to match input shape
+     # This is a placeholder; replace it with your actual preprocessing code
+     input_data = np.array([[ord(c) for c in instruction]], dtype=np.uint8)
    return input_data

input_data = preprocess_input(instruction)
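
Note: the new preprocess_input is explicitly a placeholder, and ord() byte codes are not Gemma tokens (real token IDs also exceed the uint8 range the converter bakes in). A minimal sketch of what real preprocessing might look like with the KerasNLP tokenizer matching the exported preset; the fixed length of 300 and the int32 padding are assumptions chosen to line up with the (1, 300) representative-dataset shape in the diff:

import numpy as np
import keras_nlp

# Assumption: the tokenizer preset should match the model preset used above
tokenizer = keras_nlp.models.GemmaTokenizer.from_preset("gemma_2b_en")

def preprocess_input(instruction, max_len=300):
    # Tokenize the string, then pad/truncate to a fixed length so the
    # array matches the interpreter's expected input shape
    token_ids = np.asarray(tokenizer(instruction))
    padded = np.zeros((1, max_len), dtype=np.int32)
    n = min(token_ids.shape[0], max_len)
    padded[0, :n] = token_ids[:n]
    return padded

Whatever preprocessing is used, its output dtype and shape have to agree with input_details[0]['dtype'] and input_details[0]['shape'] reported by the interpreter, which the diff currently does not check.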
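
Calibrating int8 ranges from np.random.rand noise tends to produce poor quantization scales; ideally the representative dataset mirrors real inference inputs. A hedged sketch under that assumption, reusing a preprocess_input like the one above (calibration_texts is a hypothetical placeholder list, not part of the commit):

# Hypothetical calibration prompts; replace with samples from your own data
calibration_texts = [
    "Summarize the following passage.",
    "What is the capital of France?",
]

def representative_dataset_gen():
    for text in calibration_texts:
        # Feed the converter the same kind of tensors the model will see
        # at inference time, cast to float32 as calibration expects
        data = preprocess_input(text).astype(np.float32)
        yield [data]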
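
Finally, the diff builds input_data but stops before actually running the interpreter. A minimal sketch of the remaining step with the standard tf.lite.Interpreter API, assuming a single input and a single output tensor; decoding the output back to text is left out:

import numpy as np
import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path='gemma_2b_en_quantized.tflite')
interpreter.allocate_tensors()  # required before set_tensor/invoke

input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# Dummy input with the exact dtype/shape the converter baked in; in the
# real script this would come from preprocess_input(instruction)
input_data = np.zeros(input_details[0]['shape'], dtype=input_details[0]['dtype'])

interpreter.set_tensor(input_details[0]['index'], input_data)
interpreter.invoke()
output_data = interpreter.get_tensor(output_details[0]['index'])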