Spaces:
Runtime error
Runtime error
import logging

import gradio as gr
import torch
from transformers import T5ForConditionalGeneration, T5TokenizerFast
tokenizer = T5TokenizerFast.from_pretrained("t5-base")

logger = logging.getLogger(__name__)


def _load_checkpoint(path):
    """Load the weights stored at *path* and return them as a state dict.

    ``map_location="cpu"`` lets a checkpoint that was saved from a GPU load on
    a CPU-only host instead of failing at startup. If the file holds a whole
    pickled module rather than a state dict, that module's ``state_dict()`` is
    returned instead.
    """
    checkpoint = torch.load(path, map_location="cpu")
    # Only reachable when torch.load is permitted to unpickle full modules.
    if isinstance(checkpoint, torch.nn.Module):
        checkpoint = checkpoint.state_dict()
    return checkpoint


def _load_quantized_model(path="quantized_model.pt"):
    """Build a ``t5-base`` model and overlay the compatible weights from *path*.

    An entry from the checkpoint is loaded only when its key exists in the
    ``t5-base`` state dict and its value is a tensor of the same shape.
    ``load_state_dict`` raises on shape mismatches even with ``strict=False``,
    so those entries are filtered out up front. The numbers of skipped
    checkpoint entries and of model parameters left at their pretrained
    values are logged, so a checkpoint that barely matches the architecture
    does not go unnoticed.

    NOTE(review): if quantized_model.pt came from dynamic quantization of the
    ``nn.Linear`` layers, those weights are stored under different keys
    (``..._packed_params...``) and will be skipped here, leaving the layers
    with stock t5-base weights. Confirm how the checkpoint was produced.

    Args:
        path: Location of the saved checkpoint.

    Returns:
        The model, switched to evaluation mode.
    """
    model = T5ForConditionalGeneration.from_pretrained("t5-base")
    # Build the reference state dict once, not once per checkpoint key.
    expected = model.state_dict()
    checkpoint = _load_checkpoint(path)

    compatible = {
        key: value
        for key, value in checkpoint.items()
        if key in expected
        and isinstance(value, torch.Tensor)
        and value.shape == expected[key].shape
    }
    skipped = len(checkpoint) - len(compatible)
    if skipped:
        logger.warning(
            "Skipped %d of %d entries in %s that do not match t5-base",
            skipped,
            len(checkpoint),
            path,
        )

    result = model.load_state_dict(compatible, strict=False)
    if result.missing_keys:
        logger.warning(
            "%d t5-base parameters were not found in %s and keep their "
            "pretrained values",
            len(result.missing_keys),
            path,
        )
    model.eval()  # make sure dropout is disabled for inference
    return model


quantized_model = _load_quantized_model()
def encode_text(text, max_length=512, padding="max_length"):
    """Tokenize *text* into PyTorch tensors for the T5 model.

    Args:
        text: The string to encode.
        max_length: Token limit. Longer input is truncated to this length.
        padding: Padding strategy forwarded to the tokenizer. The default,
            ``"max_length"``, pads every input to ``max_length`` tokens.
            Pass ``False`` to keep only the real tokens, which avoids
            attention work over padding positions during generation.

    Returns:
        An ``(input_ids, attention_mask)`` pair of tensors. For a single
        string each has a leading batch dimension of 1.
    """
    # Calling the tokenizer directly is the documented replacement for the
    # deprecated encode_plus; it accepts the same arguments.
    encoding = tokenizer(
        text,
        max_length=max_length,
        padding=padding,
        truncation=True,
        return_attention_mask=True,
        return_tensors="pt",
    )
    return encoding["input_ids"], encoding["attention_mask"]
def generate_summary(
    input_ids,
    attention_mask,
    model,
    *,
    max_length=150,
    num_beams=2,
    repetition_penalty=2.5,
    length_penalty=1.0,
    early_stopping=True,
):
    """Run beam-search generation with *model* on already-tokenized input.

    The inputs are moved to the model's device; the model is not moved to the
    inputs' device. Relocating a shared, module-level model on every request
    is wasteful, and it would pull a GPU-resident model back to the CPU
    whenever CPU tensors are passed in.

    Args:
        input_ids: Token id tensor to condition generation on.
        attention_mask: Mask tensor matching ``input_ids``.
        model: A Hugging Face model exposing ``.device`` and ``.generate``.
        max_length: Maximum length of the generated sequence.
        num_beams: Beam width for beam search.
        repetition_penalty: Penalty applied to repeated tokens.
        length_penalty: Exponential length penalty used in beam scoring.
        early_stopping: Forwarded to ``generate`` to control when beam
            search stops.

    Returns:
        The token-id tensor returned by ``model.generate``.
    """
    device = model.device
    return model.generate(
        input_ids=input_ids.to(device),
        attention_mask=attention_mask.to(device),
        max_length=max_length,
        num_beams=num_beams,
        repetition_penalty=repetition_penalty,
        length_penalty=length_penalty,
        early_stopping=early_stopping,
    )
def decode_summary(generated_ids):
    """Convert generated token ids back into text.

    Each sequence is decoded with special tokens removed and tokenization
    spaces cleaned up. When several sequences are present they are joined with
    a newline so that separate summaries do not run into each other. A single
    sequence is returned exactly as decoded.

    Args:
        generated_ids: Batch of token-id sequences, e.g. the output of
            ``model.generate``.

    Returns:
        The decoded text.
    """
    texts = tokenizer.batch_decode(
        generated_ids,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )
    return "\n".join(texts)
def summarize(text):
    """Return a summary of *text* produced by the module-level model.

    Blank input (``None``, empty, or whitespace only) returns an empty string
    without invoking the model. There is nothing to summarize, and running a
    full beam search on an empty prompt only wastes time.

    Args:
        text: The text (a poem, in this app) to summarize.

    Returns:
        The generated summary, or ``""`` for blank input.
    """
    if text is None or not text.strip():
        return ""
    input_ids, attention_mask = encode_text(text)
    generated_ids = generate_summary(input_ids, attention_mask, quantized_model)
    return decode_summary(generated_ids)
# Create the Gradio interface. It is bound to `demo`, the name the `gradio`
# CLI's reload mode looks for, and it is launched only when this file is run
# as a script, so importing the module does not start a server.
input_text = gr.Textbox(lines=10, label="Input Text")
output_text = gr.Textbox(label="Summary")
demo = gr.Interface(
    fn=summarize,
    inputs=input_text,
    outputs=output_text,
    title="Poem Pulse",
    description="Enter a poem and get its gist.",
)

if __name__ == "__main__":
    demo.launch()