Spaces:
Runtime error
Runtime error
Update selection.py
Browse files- selection.py +19 -5
selection.py
CHANGED
|
@@ -43,15 +43,28 @@ instruction = most_relevant_context[:300] + " " + Question
|
|
| 43 |
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_2b_en")
|
| 44 |
gemma_lm.save('saved_model/gemma_2b_en')
|
| 45 |
|
| 46 |
-
# Convert the saved model to TensorFlow Lite format with quantization
|
| 47 |
saved_model_dir = 'saved_model/gemma_2b_en'
|
| 48 |
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
|
| 49 |
converter.optimizations = [tf.lite.Optimize.DEFAULT]
|
| 50 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 51 |
|
| 52 |
# Save the quantized model
|
| 53 |
with open('gemma_2b_en_quantized.tflite', 'wb') as f:
|
| 54 |
-
f.write(
|
| 55 |
|
| 56 |
# Load the quantized model and run inference
|
| 57 |
interpreter = tf.lite.Interpreter(model_path='gemma_2b_en_quantized.tflite')
|
|
@@ -61,8 +74,9 @@ input_details = interpreter.get_input_details()
|
|
| 61 |
output_details = interpreter.get_output_details()
|
| 62 |
|
| 63 |
def preprocess_input(instruction):
|
| 64 |
-
#
|
| 65 |
-
|
|
|
|
| 66 |
return input_data
|
| 67 |
|
| 68 |
input_data = preprocess_input(instruction)
|
|
|
|
| 43 |
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_2b_en")
|
| 44 |
gemma_lm.save('saved_model/gemma_2b_en')
|
| 45 |
|
| 46 |
+
# Convert the saved model to TensorFlow Lite format with 8-bit full integer quantization
|
| 47 |
saved_model_dir = 'saved_model/gemma_2b_en'
|
| 48 |
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
|
| 49 |
converter.optimizations = [tf.lite.Optimize.DEFAULT]
|
| 50 |
+
|
| 51 |
+
# Representative dataset function for quantization
|
| 52 |
+
def representative_dataset_gen():
|
| 53 |
+
for _ in range(100):
|
| 54 |
+
# Example input array, replace with your actual data
|
| 55 |
+
data = np.random.rand(1, 300).astype(np.float32)
|
| 56 |
+
yield [data]
|
| 57 |
+
|
| 58 |
+
converter.representative_dataset = representative_dataset_gen
|
| 59 |
+
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
|
| 60 |
+
converter.inference_input_type = tf.uint8
|
| 61 |
+
converter.inference_output_type = tf.uint8
|
| 62 |
+
|
| 63 |
+
tflite_model_quant = converter.convert()
|
| 64 |
|
| 65 |
# Save the quantized model
|
| 66 |
with open('gemma_2b_en_quantized.tflite', 'wb') as f:
|
| 67 |
+
f.write(tflite_model_quant)
|
| 68 |
|
| 69 |
# Load the quantized model and run inference
|
| 70 |
interpreter = tf.lite.Interpreter(model_path='gemma_2b_en_quantized.tflite')
|
|
|
|
| 74 |
output_details = interpreter.get_output_details()
|
| 75 |
|
| 76 |
def preprocess_input(instruction):
|
| 77 |
+
# Tokenization and padding to match input shape
|
| 78 |
+
# This is a placeholder; replace it with your actual preprocessing code
|
| 79 |
+
input_data = np.array([[ord(c) for c in instruction]], dtype=np.uint8)
|
| 80 |
return input_data
|
| 81 |
|
| 82 |
input_data = preprocess_input(instruction)
|