Spaces: Runtime error
| import gradio as gr | |
| import torch | |
| from transformers import BertForSequenceClassification, BertTokenizer | |
# Tokenizer: the IndoBERTweet vocabulary hosted on the Hugging Face Hub.
token_model = "indolem/indobertweet-base-uncased"
tokenizer = BertTokenizer.from_pretrained(token_model)

# Inference device: a CUDA GPU when one is visible, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Fine-tuned classifier. `from_pretrained` is given a local directory that is
# expected to hold config.json plus the weights (pytorch_model.bin); if the
# weights were saved under a different name, rename the file or load them
# another way.
model_directory = "pretrained_arief.model"
model = BertForSequenceClassification.from_pretrained(model_directory)
model.to(device)
# This script only runs inference, so keep the model in evaluation mode.
model.eval()
def classify_transaction(notes):
    """Predict the category of free-text transaction notes.

    Tokenizes ``notes`` (padded/truncated to 256 tokens), runs one forward
    pass of the module-level ``model`` on ``device`` without gradient
    tracking, and picks the highest-scoring class.

    Args:
        notes: Transaction notes typed by the user. ``None`` or
            whitespace-only text is treated as "no input".

    Returns:
        ``"Predicted Category: <index>"`` where ``<index>`` is the integer
        class index chosen by the model, or ``""`` when there is no input.
        Indices are not mapped to names here; presumably they follow the
        label encoding used at training time (TODO confirm).
    """
    # The interface is live, so this runs on every edit, including when the
    # box is cleared. Don't run the model and show a bogus category for blank
    # input.
    if notes is None or not notes.strip():
        return ""

    # token_type_ids are deliberately not requested; only ids and mask are fed.
    inputs = tokenizer(
        notes,
        add_special_tokens=True,
        max_length=256,
        padding="max_length",
        truncation=True,
        return_token_type_ids=False,
        return_attention_mask=True,
        return_tensors="pt",
    )
    # Tensors must live on the same device as the model.
    input_ids = inputs["input_ids"].to(device)
    attention_mask = inputs["attention_mask"].to(device)

    with torch.no_grad():
        outputs = model(input_ids, attention_mask=attention_mask)
    # First output element is the classification logits (no labels passed).
    logits = outputs[0]

    # softmax is monotonic, so argmax over raw logits selects the same class.
    # .item() yields a plain int, so the UI shows "3" rather than the array
    # repr "[3]".
    predicted_class = torch.argmax(logits, dim=1).item()
    return f"Predicted Category: {predicted_class}"
# Gradio front end: multi-line transaction notes in, the classifier's text
# result out.
iface = gr.Interface(
    title="Transaction Category Classifier",
    description="Enter transaction notes to get the predicted category.",
    fn=classify_transaction,
    inputs=gr.Textbox(
        label="Transaction Notes",
        placeholder="Enter Transaction Notes Here",
        lines=3,
    ),
    outputs=gr.Text(label="Classification Result"),
    # Re-run the classifier as soon as the input text changes.
    live=True,
)

# Start the web server only when run as a script, not when imported.
if __name__ == "__main__":
    iface.launch()