File size: 2,104 Bytes
2a99b58
75f5166
fa08c22
8912a27
2a99b58
fa08c22
8f55d75
3a01d1a
fa08c22
 
8f55d75
6a7d6fa
fa08c22
8912a27
 
6a7d6fa
75f5166
 
8912a27
fa08c22
92f0801
fa08c22
 
 
 
8912a27
 
 
 
 
 
 
1106581
fa08c22
8912a27
2a99b58
 
ba56c16
dc78b0a
2cf002f
75f5166
fa08c22
 
8912a27
 
75f5166
8912a27
dc78b0a
2a99b58
 
3a01d1a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import gradio as gr
import time
from transformers import T5Tokenizer, T5ForConditionalGeneration
from quanto import quantize, freeze, qint8

# Directory both from_pretrained() calls below read the model and tokenizer from.
model_dir = "t5flan"

# One tokenizer is shared by both model variants.
tokenizer = T5Tokenizer.from_pretrained(model_dir)

# Baseline: the model loaded as-is, with no quantization applied.
model = T5ForConditionalGeneration.from_pretrained(model_dir)

# Quantized variant: a second, independent instance loaded from the same
# directory so that quantizing it leaves the baseline `model` untouched.
# Weights are quantized to int8 and activations are left unquantized;
# freeze() is then called to finalize the quantized weights.
quantized_model = T5ForConditionalGeneration.from_pretrained(model_dir)
quantize(quantized_model, weights=qint8, activations=None)
freeze(quantized_model)

def _timed_generate(net, inputs):
    """Run generation + decoding on `net` and time it.

    Args:
        net: Object exposing ``generate(**inputs, ...)`` (here, either the
            baseline or the quantized model).
        inputs: Already-tokenized inputs, as returned by the module-level
            ``tokenizer``.

    Returns:
        A ``(text, seconds)`` tuple: the decoded first generated sequence and
        the elapsed time of the generate + decode step as a float.
    """
    # perf_counter() is a monotonic, high-resolution clock meant for measuring
    # short intervals; time.time() can jump if the system clock is adjusted.
    started = time.perf_counter()
    sequences = net.generate(**inputs, max_length=100, num_return_sequences=1)
    text = tokenizer.decode(sequences[0], skip_special_tokens=True)
    elapsed = time.perf_counter() - started
    return text, elapsed


# Define the inference function
def generate_text(prompt):
    """Generate text for `prompt` with both models and report timings.

    Args:
        prompt: The input text to feed to both models.

    Returns:
        A 4-tuple of strings, in this order:
        ``(normal_text, normal_time, quantized_text, quantized_time)``,
        where each time is formatted as ``"<seconds with 2 decimals> seconds"``.
    """
    # Tokenize once, *outside* both timed regions. The same tensors feed both
    # models, so including tokenization in only one measurement would skew
    # the comparison against that model.
    inputs = tokenizer(prompt, return_tensors='pt')

    # Both models go through the exact same timed code path.
    generated_text_normal, response_time_normal = _timed_generate(model, inputs)
    generated_text_quantized, response_time_quantized = _timed_generate(
        quantized_model, inputs
    )

    return (generated_text_normal, f"{response_time_normal:.2f} seconds",
            generated_text_quantized, f"{response_time_quantized:.2f} seconds")

# Create a Gradio interface
# One prompt textbox in; four textboxes out. The order of `outputs` must match
# the order of the tuple returned by generate_text:
# (normal text, normal time, quantized text, quantized time).
iface = gr.Interface(
    fn=generate_text,
    inputs=gr.Textbox(lines=2, placeholder="Enter your prompt here..."),
    outputs=[
        gr.Textbox(label="Generated Text (Normal Model)"),
        gr.Textbox(label="Response Time (Normal Model)"),
        gr.Textbox(label="Generated Text (Quantized Model)"),
        gr.Textbox(label="Response Time (Quantized Model)")
    ],
    # The models loaded by this script are T5ForConditionalGeneration, so the
    # title names T5 rather than an unrelated model family.
    title="T5 Text Generation Comparison"
)

# Launch the interface
iface.launch()