File size: 2,104 Bytes
2a99b58 75f5166 fa08c22 8912a27 2a99b58 fa08c22 8f55d75 3a01d1a fa08c22 8f55d75 6a7d6fa fa08c22 8912a27 6a7d6fa 75f5166 8912a27 fa08c22 92f0801 fa08c22 8912a27 1106581 fa08c22 8912a27 2a99b58 ba56c16 dc78b0a 2cf002f 75f5166 fa08c22 8912a27 75f5166 8912a27 dc78b0a 2a99b58 3a01d1a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 |
import gradio as gr
import time
from transformers import T5Tokenizer, T5ForConditionalGeneration
from quanto import quantize, freeze, qint8
# Directory holding the saved FLAN-T5 checkpoint and tokenizer files.
model_dir = "t5flan"
# Load the full-precision baseline model and its tokenizer.
model = T5ForConditionalGeneration.from_pretrained(model_dir)
tokenizer = T5Tokenizer.from_pretrained(model_dir)
################### Modify this to add quantization of the model ##############################
# Load a second, independent copy of the same checkpoint to quantize,
# so the baseline `model` above stays in full precision for comparison.
quantized_model = T5ForConditionalGeneration.from_pretrained(model_dir)
# Quantize weights to int8 in place; activations stay in full precision
# (activations=None).
quantize(quantized_model, weights=qint8, activations=None)
# Freeze materializes the quantized weights so inference uses them directly.
freeze(quantized_model)
# Define the inference function
def generate_text(prompt):
    """Run the prompt through both models and report outputs and latency.

    Args:
        prompt: Input text fed to both the baseline and quantized models.

    Returns:
        Tuple of (baseline text, baseline time, quantized text, quantized time),
        with each time formatted as an "X.XX seconds" string for the Gradio UI.
    """
    # Tokenize once, OUTSIDE the timed sections. The original timed
    # tokenization only for the baseline model, inflating its measurement
    # and making the two times incomparable.
    inputs = tokenizer(prompt, return_tensors='pt')

    def _timed_generate(m):
        # Time a single generate() + decode pass for model `m`.
        start = time.time()
        outputs = m.generate(**inputs, max_length=100, num_return_sequences=1)
        text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        return text, time.time() - start

    generated_text_normal, response_time_normal = _timed_generate(model)
    generated_text_quantized, response_time_quantized = _timed_generate(quantized_model)

    return (generated_text_normal, f"{response_time_normal:.2f} seconds",
            generated_text_quantized, f"{response_time_quantized:.2f} seconds")
# Build the Gradio UI: one prompt box in, four text boxes out comparing the
# baseline model against the int8-quantized one side by side.
iface = gr.Interface(
    fn=generate_text,
    inputs=gr.Textbox(lines=2, placeholder="Enter your prompt here..."),
    outputs=[
        gr.Textbox(label="Generated Text (Normal Model)"),
        gr.Textbox(label="Response Time (Normal Model)"),
        gr.Textbox(label="Generated Text (Quantized Model)"),
        gr.Textbox(label="Response Time (Quantized Model)")
    ],
    # Fixed: the title previously said "TinyLlama", but this app loads a
    # FLAN-T5 checkpoint (see model_dir = "t5flan" above).
    title="FLAN-T5 Text Generation Comparison"
)
# Launch the interface (blocks and serves the app).
iface.launch()
|