File size: 1,971 Bytes
8fd4eb2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import os
import tensorflow as tf
from transformers import TFT5ForConditionalGeneration, AutoTokenizer
from pathlib import Path

# Directory two levels above this file's resolved location.
BASE_DIR = Path(__file__).resolve().parent.parent
# Directory the model is loaded from.
MODEL_INPUT_DIR = BASE_DIR.joinpath("summarizer", "models", "flan_t5_custom")
# File the converted TFLite model is written to.
TFLITE_OUTPUT_FILE = BASE_DIR.joinpath("summarizer", "models", "summarizer.tflite")

# Single sequence length shared by all inputs. Per the original author, using
# identical values resolves a "not broadcastable" error.
MAX_LEN = 256


def convert():
    """Convert the model stored in ``MODEL_INPUT_DIR`` to a TFLite file.

    Steps:
      1. Load ``TFT5ForConditionalGeneration`` from ``MODEL_INPUT_DIR`` with
         ``from_pt=True`` (i.e. from PyTorch weights).
      2. Wrap it in a ``tf.Module`` whose traced ``__call__`` accepts
         ``input_ids`` and ``decoder_input_ids``, both int32 tensors of fixed
         shape ``[1, MAX_LEN]``, and returns the model's logits.
      3. Convert that concrete function with ``tf.lite.TFLiteConverter`` and
         write the resulting bytes to ``TFLITE_OUTPUT_FILE``.

    Progress messages are printed to stdout. Returns ``None``; errors raised
    by model loading, conversion or file writing propagate to the caller.
    """
    print(f"🚀 Konwersja z wyrównaniem kształtów do {MAX_LEN}...")

    # Only the model is loaded: conversion never uses a tokenizer, so loading
    # one here would just add I/O and an extra way to fail.
    model = TFT5ForConditionalGeneration.from_pretrained(MODEL_INPUT_DIR, from_pt=True)

    class T5MergedModel(tf.Module):
        """Expose encoder + decoder as one fixed-shape function returning logits."""

        def __init__(self, model):
            super().__init__()
            self.model = model

        @tf.function(input_signature=[
            tf.TensorSpec([1, MAX_LEN], tf.int32, name="input_ids"),
            tf.TensorSpec([1, MAX_LEN], tf.int32, name="decoder_input_ids"),
        ])
        def __call__(self, input_ids, decoder_input_ids):
            # training=False: per the original author, this is key to keeping
            # training-only nodes out of the traced graph.
            output = self.model(
                input_ids=input_ids,
                decoder_input_ids=decoder_input_ids,
                training=False,
            )
            return output.logits

    t5_module = T5MergedModel(model)
    # The module is passed as the second argument alongside the concrete
    # function, exactly as in the original conversion call.
    converter = tf.lite.TFLiteConverter.from_concrete_functions(
        [t5_module.__call__.get_concrete_function()], t5_module
    )

    # Allow both TFLite builtin ops and the "select TF ops" set.
    converter.target_spec.supported_ops = [
        tf.lite.OpsSet.TFLITE_BUILTINS,
        tf.lite.OpsSet.SELECT_TF_OPS,
    ]

    # Optimization for size and stability (original author's note): default
    # optimizations with float32 as the only listed supported type.
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    converter.target_spec.supported_types = [tf.float32]

    tflite_model = converter.convert()
    # Path.write_bytes opens the file in binary mode, writes, and closes it.
    TFLITE_OUTPUT_FILE.write_bytes(tflite_model)

    print(f"✨ Model gotowy: {TFLITE_OUTPUT_FILE}")


# Run the conversion only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    convert()