"""Smoke-test the exported TFLite summarizer: load the model and the tokenizer,
then greedily decode a title and a summary for a sample document."""

from pathlib import Path

import numpy as np
import tensorflow as tf
from transformers import AutoTokenizer

BASE_DIR = Path(__file__).resolve().parent.parent
MODEL_PATH = BASE_DIR / "summarizer" / "models" / "summarizer.tflite"
TOKENIZER_DIR = BASE_DIR / "summarizer" / "models" / "flan_t5_custom"

# Fixed sequence length the TFLite graph was exported with (encoder and decoder).
MAX_LEN = 256


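# Greedy autoregressive decoding against a statically shaped TFLite graph: the
# full decoder is re-run on every step with all tokens generated so far, and
# the argmax at the last filled position becomes the next token.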
def generate_tflite(prompt, interpreter, tokenizer):
    # Encode the prompt to a fixed (1, MAX_LEN) int32 tensor so it matches the
    # static input shape of the converted graph.
    input_ids = tokenizer.encode(prompt, max_length=MAX_LEN, truncation=True, padding="max_length")
    input_ids = np.array([input_ids], dtype=np.int32)

    # The decoder starts from T5's decoder_start_token_id (0, the pad token);
    # unused positions stay zero-padded.
    decoder_input_ids = np.zeros((1, MAX_LEN), dtype=np.int32)
    output_tokens = [0]
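
    # Look up graph inputs by name rather than relying on tensor order, which
    # is not guaranteed to be stable across conversions.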
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()

    print(f"⏳ Generating for prompt: '{prompt[:30]}...'")

    for i in range(MAX_LEN - 1):
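        # Copy the tokens generated so far into the fixed-size decoder buffer;
        # positions past len(output_tokens) stay zero (pad).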
        for j, token in enumerate(output_tokens):
            decoder_input_ids[0, j] = token

        for detail in input_details:
            if "input_ids" in detail['name'] and "decoder" not in detail['name']:
                interpreter.set_tensor(detail['index'], input_ids)
            elif "decoder_input_ids" in detail['name']:
                interpreter.set_tensor(detail['index'], decoder_input_ids)

        interpreter.invoke()
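
        # The decoder output is (1, MAX_LEN, vocab_size); the logits for the
        # next token sit at the position of the last token fed in.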
        output_data = interpreter.get_tensor(output_details[0]['index'])
        next_token_logits = output_data[0, len(output_tokens) - 1, :]
        next_token = int(np.argmax(next_token_logits))

        # T5 uses token id 1 as </s>.
        if next_token == 1:
            print("LOG: Received EOS token (1)")
            break

        output_tokens.append(next_token)
        print(f"  Step {i}: {next_token} -> '{tokenizer.decode([next_token])}'")

        if len(output_tokens) >= MAX_LEN:
            break

    # Decode the whole sequence at once (skipping the start token): concatenating
    # per-token decodes would drop SentencePiece word-boundary markers and mangle
    # the spacing between words.
    return tokenizer.decode(output_tokens[1:], skip_special_tokens=True).strip()


def main():
    if not MODEL_PATH.exists():
        print(f"❌ Model file not found at: {MODEL_PATH}")
        return

    print(f"🚀 Loading TFLite model: {MODEL_PATH}")
    interpreter = tf.lite.Interpreter(model_path=str(MODEL_PATH))
    interpreter.allocate_tensors()

    print(f"🚀 Loading tokenizer from: {TOKENIZER_DIR}")
    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_DIR)

    # Polish sample input (header of a 2005 matura certificate), kept verbatim
    # as smoke-test data.
    sample_text = "Matura 2005 przykład RZECZPOSPOLITA POLSKA ŚWIADECTWO DOJRZAŁOŚCI Janina Kosińska-Iksińska"
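
    # The same model serves both tasks; the "headline:"/"summarize:" prefix
    # selects the behavior.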
    title = generate_tflite(f"headline: {sample_text}", interpreter, tokenizer)
    print(f"\n📌 FINAL TFLITE TITLE: {title}")

    summary = generate_tflite(f"summarize: {sample_text}", interpreter, tokenizer)
    print(f"\n📝 FINAL TFLITE SUMMARY: {summary}")


if __name__ == "__main__":
    main()