kadabengaran commited on
Commit
5ca6171
·
1 Parent(s): 2d1aa85

Create main.py

Browse files
Files changed (1) hide show
  1. main.py +133 -0
main.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ try:
2
+ import torch
3
+ import pandas as pd
4
+ import streamlit as st
5
+ import re
6
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
7
+ from stqdm import stqdm
8
+ from peft import PeftModel, PeftConfig, get_peft_model, LoraConfig
9
+ except Exception as e:
10
+ print(e)
11
+
12
+ # Config
13
+ MODELS_PATH = "kadabengaran/distilbert-base-uncased-lora-text-classification"
14
+
15
+ id2label= {0: 'Other', 1: 'Problem Discovery', 2: 'Information Seeking', 3: 'Feature Request'}
16
+ label2id= {'Other': 0, 'Problem Discovery': 1, 'Information Seeking': 2, 'Feature Request': 3}
17
+ numLabels= 4
18
+
19
+ def get_device():
20
+ if torch.cuda.is_available():
21
+ return torch.device('cuda')
22
+ else:
23
+ return torch.device('cpu')
24
+
25
+ USE_CUDA = False
26
+ device = get_device()
27
+ if device.type == 'cuda':
28
+ USE_CUDA = True
29
+
30
+ # Get the Keys
31
+ def get_key(val, my_dict):
32
+ for key, value in my_dict.items():
33
+ if val == value:
34
+ return key
35
+
36
+ def load_tokenizer(model_path):
37
+ # create tokenizer
38
+ tokenizer = AutoTokenizer.from_pretrained(model_path, add_prefix_space=True)
39
+ return tokenizer
40
+
41
+ def remove_special_characters(text):
42
+ # case folding
43
+ text = text.lower()
44
+
45
+ # menghapus karakter khusus
46
+ text = re.sub(r'[^a-zA-Z0-9\s]', ' ', text)
47
+ text = re.sub(r'[0-9]', ' ', text)
48
+
49
+ # replace multiple whitespace characters with a single space
50
+ text = re.sub(r"\s+", " ", text)
51
+
52
+ return text
53
+
54
+ def load_model():
55
+ config = PeftConfig.from_pretrained(MODELS_PATH)
56
+ inference_model = AutoModelForSequenceClassification.from_pretrained(
57
+ config.base_model_name_or_path, num_labels=numLabels, id2label=id2label, label2id=label2id
58
+ )
59
+ tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
60
+ model = PeftModel.from_pretrained(inference_model, MODELS_PATH)
61
+ return model, tokenizer
62
+
63
+ def classify_single(text, model, tokenizer, device):
64
+
65
+ if device.type == 'cuda':
66
+ model.cuda()
67
+
68
+ # tokenize text
69
+ inputs = tokenizer.encode(text, return_tensors="pt").to(device)
70
+
71
+ # compute logits
72
+ logits = model(inputs).logits
73
+ # convert logits to label
74
+ predictions = torch.argmax(logits)
75
+ return id2label[predictions.tolist()]
76
+
77
+
78
+ tab_labels = ["Single Input", "Multiple Input"]
79
+ class App:
80
+ def __init__(self):
81
+ self.fileTypes = ["csv"]
82
+ self.default_tab_selected = tab_labels[0]
83
+ self.input_text = None
84
+ self.csv_input = None
85
+ self.csv_process = None
86
+
87
+ def run(self):
88
+ model, tokenizer = load_model()
89
+ html_temp = """
90
+ <div style="padding:10px">
91
+ <h1 style="color:white;text-align:center;">User Question Classification</h1>
92
+ </div>
93
+ """
94
+ st.markdown(html_temp, unsafe_allow_html=True)
95
+ st.markdown("")
96
+ if USE_CUDA:
97
+ st.sidebar.markdown(footer,unsafe_allow_html=True)
98
+ self.render_single_input()
99
+ st.divider()
100
+ self.render_process_button(model, tokenizer, device)
101
+
102
+
103
+ def render_single_input(self):
104
+ self.input_text = st.text_area("Enter Text Here", placeholder="Type Here")
105
+
106
+
107
+ def render_process_button(self, model, tokenizer, device):
108
+ if st.button("Process"):
109
+ input_text = self.input_text
110
+ if input_text:
111
+ classification_result = classify_single(input_text, model, tokenizer, device)
112
+ st.write("Classification result:", classification_result)
113
+ else:
114
+ st.warning('Please enter text to process', icon="⚠️")
115
+
116
+
117
+ footer="""<style>
118
+ .footer {
119
+ position: fixed;
120
+ left: 10;
121
+ bottom: 0;
122
+ width: 100%;
123
+ color: #ffa9365e;
124
+ }
125
+ </style>
126
+ <div class="footer">
127
+ <p>CUDA enabled</p>
128
+ </div>
129
+ """
130
+
131
+ if __name__ == "__main__":
132
+ app = App()
133
+ app.run()