Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| from typing import Dict, List | |
| import torch | |
| torch.backends.cudnn.enabled = False | |
| import json | |
| import pickle | |
| from pathlib import Path | |
| from utils import Vocab | |
| from model import SeqClassifier | |
| import re | |
# --- Inference configuration ------------------------------------------------
# Classifier architecture; these presumably have to match the configuration
# the checkpoint was trained with (TODO confirm against the training script).
hidden_size: int = 256
num_layers: int = 2
dropout: float = 0.1
bidirectional: bool = True

# Token length passed as `to_len` when encoding input text.
max_len: int = 128
device: str = "cpu"

# Artifact directories, relative to the current working directory.
ckpt_dir = Path("./ckpt/intent/")
cache_dir = Path("./cache/intent/")
# Load the vocabulary and the intent -> index mapping from the cache directory.
# NOTE(review): pickle.load can execute arbitrary code; only load cache files
# you trust.
with open(cache_dir / "vocab.pkl", "rb") as f:
    vocab: Vocab = pickle.load(f)

intent_idx_path = cache_dir / "intent2idx.json"
# Read as UTF-8 explicitly instead of depending on the platform's default
# locale encoding (JSON interchange is UTF-8).
intent2idx: Dict[str, int] = json.loads(intent_idx_path.read_text(encoding="utf-8"))
# Reverse mapping: class index -> intent name.
__idx2label = {idx: intent for intent, idx in intent2idx.items()}


def idx2label(idx: int) -> str:
    """Return the intent name for class index *idx*.

    Raises:
        KeyError: if *idx* is not one of the indices in ``intent2idx.json``.
    """
    return __idx2label[idx]
# Placeholder embedding matrix used only to construct SeqClassifier with the
# right shape; its uninitialised contents are presumably overwritten when the
# checkpoint's state dict is loaded (TODO confirm the checkpoint includes the
# embedding weights).
# NOTE(review): (5621, 300) is hard-coded; presumably (vocab size, embedding
# dim) of the training run -- it must agree with the checkpoint.
embeddings_size = (5621, 300)
# Tensor.to() returns a new tensor rather than moving in place, so allocate
# directly on the target device instead of discarding the result of .to().
embeddings = torch.empty(embeddings_size, device=device)
# Build the classifier and restore its trained weights for inference.
best_model = SeqClassifier(
    embeddings=embeddings,
    hidden_size=hidden_size,
    num_layers=num_layers,
    dropout=dropout,
    bidirectional=bidirectional,
    num_class=len(intent2idx),
).to(device)

ckpt_path = ckpt_dir / "intent_checkpoint.pth"
# Map stored tensors onto the configured `device` rather than a hard-coded
# 'cpu', so `device` is the single switch for where inference runs.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
checkpoint = torch.load(ckpt_path, map_location=device)
# The checkpoint is a dict whose weights live under 'model_state_dict'.
best_model.load_state_dict(checkpoint["model_state_dict"])
# Switch dropout (and any other train-only behaviour) to inference mode.
best_model.eval()
# Tokenizer pattern: a run of word characters, or any single character that is
# neither a word character nor whitespace (i.e. punctuation).
_TOKEN_PATTERN = re.compile(r"\w+|[^\w\s]")


def collate_fn(texts: str) -> torch.Tensor:
    """Tokenise a raw string and encode it as a tensor of vocabulary ids.

    Args:
        texts: a single input sentence (despite the plural name).

    Returns:
        The first row of ``vocab.encode_batch(..., to_len=max_len)`` for the
        sentence, as a tensor -- presumably padded/truncated to ``max_len``
        ids (TODO confirm against utils.Vocab). No batch dimension is added.
    """
    tokens = _TOKEN_PATTERN.findall(texts)
    # encode_batch takes a batch (list of token lists): wrap the single
    # sentence, then take its (only) encoded row back out.
    encoded = vocab.encode_batch([tokens], to_len=max_len)[0]
    return torch.tensor(encoded)
# Classification function used as the Gradio click handler.
def classifier(text):
    """Predict the intent label for a single input string.

    Args:
        text: raw user text, as delivered by the input textbox.

    Returns:
        The intent name whose class index has the highest model score.
    """
    encoded_text = collate_fn(text).to(device)
    # Inference only: skip autograd bookkeeping to save time and memory.
    with torch.no_grad():
        output = best_model(encoded_text)
    # argmax without `dim` works over the flattened output; presumably the
    # output holds scores for one example (TODO confirm the output shape).
    predicted_class = torch.argmax(output).item()
    return idx2label(predicted_class)
| import gradio as gr | |
| from gradio.components import Textbox | |
| import random | |
# Sample utterances offered by the "Random" button.
examples = [
    "what are some fun things i can partake in in atlanta",
    "how do i make pumpkin pie",
    "what's the currency conversion between rubles and pounds",
    "please set an alarm for mid day",
    "how many hours will it take to get to my destination",
    "so i made a fraudulent transaction",
    "tell lydia and laura where i am located",
    "i want you to talk more quickly",
    "what's the deal with my health care",
    "What's the exchange rate for rubles to pounds",
    "How long will it take to reach my destination",
    "I suspect a fraudulent transaction on my account",
    "Inform Lydia and Laura of my current location",
    "I'd like you to speak faster",
    "Can you provide information about my health care",
    "Give me the details on my health insurance",
    "What's the local time now",
    "Find a recipe for chocolate chip cookies",
    "Check my credit card balance",
    "Translate 'Hello' to French",
    "Recommend a good restaurant nearby",
]


def random_sample():
    """Return one entry of ``examples`` chosen uniformly at random."""
    return random.choice(examples)
# Demo heading and intro text, rendered as Markdown at the top of the page.
title = "Text Intent Classification"
description = (
    "\n"
    f"# {title}\n"
    "This demo uses a model to classify text into different intents or "
    "categories. Enter a text and see the classification result.\n"
)
# Gradio UI: an input textbox, an output textbox for the predicted intent, and
# two buttons (fill the input with a random example / run the classifier).
# The tab title and input label previously carried leftovers from a
# question-answering demo ("Question Answering", "Context paragraph").
with gr.Blocks(theme=gr.themes.Soft(), title=title) as demo:
    gr.Markdown(description)
    with gr.Row():
        C_input = gr.Textbox(lines=3, label="Input text", placeholder="Please enter text")
        A_output = gr.Textbox(lines=3, label="Category")
    with gr.Row():
        random_button = gr.Button("Random")
        classifier_button = gr.Button("classifier")
    # Event listeners are registered inside the Blocks context.
    random_button.click(random_sample, inputs=None, outputs=C_input)
    classifier_button.click(classifier, inputs=C_input, outputs=A_output)

demo.launch(share=True)