VenujaDeSilva commited on
Commit
9f0e60b
ยท
verified ยท
1 Parent(s): a49e3af

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +122 -0
app.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ import joblib
4
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
5
+
6
+ # -------------------------------------
7
+ # ๐ŸŽจ STREAMLIT PAGE CONFIG
8
+ # -------------------------------------
9
+ st.set_page_config(
10
+ page_title="StackOverflow Tag Predictor",
11
+ page_icon="๐ŸŽฏ",
12
+ layout="centered",
13
+ )
14
+
15
+ # -------------------------------------
16
+ # ๐ŸŒˆ CUSTOM CSS FOR BEAUTIFUL UI
17
+ # -------------------------------------
18
+ st.markdown("""
19
+ <style>
20
+ body {
21
+ background-color: #F2F2F7;
22
+ }
23
+ .big-title {
24
+ font-size: 40px;
25
+ font-weight: 900;
26
+ text-align: center;
27
+ margin-bottom: -10px;
28
+ color: #4A4AFC;
29
+ }
30
+ .subtitle {
31
+ text-align: center;
32
+ color: #666;
33
+ font-size: 18px;
34
+ }
35
+ .result-tag {
36
+ background-color: #4A4AFC;
37
+ padding: 10px 18px;
38
+ border-radius: 12px;
39
+ color: white;
40
+ display: inline-block;
41
+ font-size: 20px;
42
+ margin: 5px;
43
+ animation: fadeIn 0.6s ease-out;
44
+ }
45
+ @keyframes fadeIn {
46
+ from {opacity: 0; transform: translateY(10px);}
47
+ to {opacity: 1; transform: translateY(0);}
48
+ }
49
+ </style>
50
+ """, unsafe_allow_html=True)
51
+
52
+ # -------------------------------------
53
+ # ๐Ÿ“ฆ LOAD MODEL + TOKENIZER
54
+ # -------------------------------------
55
+ @st.cache_resource
56
+ def load_model():
57
+ model = AutoModelForSequenceClassification.from_pretrained(".")
58
+ tokenizer = AutoTokenizer.from_pretrained(".")
59
+ return model, tokenizer
60
+
61
+ model, tokenizer = load_model()
62
+
63
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
64
+ model = model.to(device)
65
+
66
+ # -------------------------------------
67
+ # ๐Ÿ”ค LOAD LABEL ENCODER
68
+ # -------------------------------------
69
+ label_encoder = joblib.load("label_encoder.joblib")
70
+ id2label = {i: label for i, label in enumerate(label_encoder.classes_)}
71
+
72
+ # -------------------------------------
73
+ # ๐Ÿ”ฎ PREDICTION FUNCTION
74
+ # -------------------------------------
75
+ def predict_tag(text):
76
+ encoding = tokenizer(
77
+ text,
78
+ truncation=True,
79
+ padding=True,
80
+ max_length=128,
81
+ return_tensors="pt"
82
+ )
83
+ encoding = {k: v.to(device) for k, v in encoding.items()}
84
+
85
+ with torch.no_grad():
86
+ outputs = model(**encoding)
87
+
88
+ pred_id = torch.argmax(outputs.logits, dim=-1).item()
89
+ tag = id2label[pred_id]
90
+ confidence = torch.softmax(outputs.logits, dim=-1).max().item()
91
+
92
+ return tag, confidence
93
+
94
+ # -------------------------------------
95
+ # ๐Ÿ–ฅ๏ธ UI LAYOUT
96
+ # -------------------------------------
97
+ st.markdown("<p class='big-title'>๐ŸŽฏ StackOverflow Tag Predictor</p>", unsafe_allow_html=True)
98
+ st.markdown("<p class='subtitle'>Powered by DistilBERT โ€ข Predicts the most likely tag from a question title</p>", unsafe_allow_html=True)
99
+
100
+ st.write("")
101
+
102
+ user_input = st.text_area(
103
+ "๐Ÿ’ฌ Enter a StackOverflow question title:",
104
+ height=120,
105
+ placeholder="Example: \"How to fix NullPointerException in Java?\""
106
+ )
107
+
108
+ if st.button("๐Ÿ” Predict Tag", use_container_width=True):
109
+ if user_input.strip() == "":
110
+ st.warning("Please enter a question title.")
111
+ else:
112
+ with st.spinner("Analyzing text using AI magic... โœจ"):
113
+ tag, confidence = predict_tag(user_input)
114
+
115
+ st.success("Prediction complete!")
116
+
117
+ st.markdown(f"<div class='result-tag'>{tag}</div>", unsafe_allow_html=True)
118
+ st.markdown(
119
+ f"### ๐Ÿ”ฅ Confidence: **{confidence*100:.2f}%**"
120
+ )
121
+
122
+ st.info("Try another title!")