| import torch |
| from transformers import AutoTokenizer, AutoModelForSeq2SeqLM |
| import gradio as gr |
|
|
| |
# Checkpoint identifier handed to both ``from_pretrained`` calls below
# (presumably a Hugging Face Hub repo id -- confirm if a local path is expected).
model_checkpoint = "chikki2004/incident-bart-model"

# Tokenizer and seq2seq model are built once, at import time, from the same
# checkpoint so that every request handled later reuses the same objects.
tokenizer, model = (
    AutoTokenizer.from_pretrained(model_checkpoint),
    AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint),
)
|
|
def predict(input_text: str) -> str:
    """Generate the model's text output for one log entry or log description.

    Args:
        input_text: Raw text to feed to the model. ``None``, empty, or
            whitespace-only input is answered with an empty string and the
            model is not run.

    Returns:
        The first generated sequence, decoded with special tokens removed,
        or ``""`` when there was no usable input.
    """
    # Guard first: the tokenizer cannot encode ``None``, and generating from
    # blank text only burns compute to produce meaningless output.
    if not input_text or not input_text.strip():
        return ""

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Kept per call so the function is self-contained; calling ``.to()`` with
    # the device the model already lives on does not copy anything.
    model.to(device)

    # A single sequence only needs truncation. Padding it out to 256 tokens
    # makes the encoder process positions the attention mask then ignores.
    encoded = tokenizer(
        input_text,
        return_tensors="pt",
        truncation=True,
        max_length=256,
    )
    encoded = {name: tensor.to(device) for name, tensor in encoded.items()}

    generated = model.generate(**encoded, max_length=256)
    return tokenizer.decode(generated[0], skip_special_tokens=True)
|
|
# Multi-line text area the user types or pastes the log text into.
_log_input = gr.Textbox(
    lines=5,
    placeholder="Enter log or log description...",
)

# Wire ``predict`` to a one-input / one-output web UI.
# NOTE(review): ``api_name`` is given with a leading slash; Gradio's own
# examples name endpoints without one -- confirm the published route is the
# one API clients are expected to call.
iface = gr.Interface(
    fn=predict,
    inputs=_log_input,
    outputs="text",
    title="GK's Incident Attribute Predictor",
    api_name="/predict",
)


# Start serving as soon as the module is executed.
iface.launch()
|
|