TRACES commited on
Commit
9f940ae
·
1 Parent(s): a17012f

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +21 -1
main.py CHANGED
@@ -24,6 +24,10 @@ def load_models():
24
  model=BertForSequenceClassification.from_pretrained("usmiva/bert-deepfake-bg", num_labels=2),
25
  tokenizer=AutoTokenizer.from_pretrained("usmiva/bert-deepfake-bg"))
26
 
 
 
 
 
27
 
28
  def load_content():
29
  with open('resource/page_content.json', encoding='utf8') as json_file:
@@ -44,7 +48,8 @@ if 'lang' not in st.session_state:
44
  if all([
45
  'bert_gpt_result' not in st.session_state,
46
  'untrue_detector_result' not in st.session_state,
47
- 'bert_disinfo_result' not in st.session_state
 
48
  ]):
49
  st.session_state.bert_gpt_result = [{'label': '', 'score': 1}]
50
 
@@ -97,6 +102,7 @@ if st.session_state.agree:
97
  st.session_state.untrue_detector_probability = max(st.session_state.untrue_detector_probability[0], st.session_state.untrue_detector_probability[1])
98
 
99
  st.session_state.bert_disinfo_result = st.session_state.bert_disinfo(user_input)
 
100
 
101
 
102
 
@@ -127,6 +133,20 @@ if st.session_state.agree:
127
  str(round(st.session_state.bert_disinfo_result[0]['score'] * 100, 2)) +
128
  content['bert_no_2'][st.session_state.lang], icon="✅")
129
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130
 
131
  st.info(content['disinformation_definition'][st.session_state.lang], icon="ℹ️")
132
 
 
24
  model=BertForSequenceClassification.from_pretrained("usmiva/bert-deepfake-bg", num_labels=2),
25
  tokenizer=AutoTokenizer.from_pretrained("usmiva/bert-deepfake-bg"))
26
 
27
+ st.session_state.emotions = pipeline(task="text-classification",
28
+ model=BertForSequenceClassification.from_pretrained("TRACES/emotions", use_auth_token=os.environ['ACCESS_TOKEN'], num_labels=2),
29
+ tokenizer=AutoTokenizer.from_pretrained("TRACES/emotions", use_auth_token=os.environ['ACCESS_TOKEN']))
30
+
31
 
32
  def load_content():
33
  with open('resource/page_content.json', encoding='utf8') as json_file:
 
48
  if all([
49
  'bert_gpt_result' not in st.session_state,
50
  'untrue_detector_result' not in st.session_state,
51
+ 'bert_disinfo_result' not in st.session_state,
52
+ 'emotions_result' not in st.session_state
53
  ]):
54
  st.session_state.bert_gpt_result = [{'label': '', 'score': 1}]
55
 
 
102
  st.session_state.untrue_detector_probability = max(st.session_state.untrue_detector_probability[0], st.session_state.untrue_detector_probability[1])
103
 
104
  st.session_state.bert_disinfo_result = st.session_state.bert_disinfo(user_input)
105
+ st.session_state.bert_disinfo_result = st.session_state.emotions(user_input)
106
 
107
 
108
 
 
133
  str(round(st.session_state.bert_disinfo_result[0]['score'] * 100, 2)) +
134
  content['bert_no_2'][st.session_state.lang], icon="✅")
135
 
136
+ if st.session_state.emotions[0]['score'] < 0.97:
137
+ st.warning(content['emotions_label_1'][st.session_state.lang] +
138
+ str(st.session_state.emotions[0]['label']) +
139
+ content['emotions_label_2'][st.session_state.lang]
140
+ str(round(st.session_state.emotions[0]['score'] * 100, 2)) +
141
+ content['emotions_label_3'][st.session_state.lang] +
142
+ content['emotions_label_4'][st.session_state.lang], icon = "⚠️")
143
+ else:
144
+ st.success(content['emotions_label_1'][st.session_state.lang] +
145
+ str(st.session_state.emotions[0]['label']) +
146
+ content['emotions_label_2'][st.session_state.lang]
147
+ str(round(st.session_state.emotions[0]['score'] * 100, 2)) +
148
+ content['emotions_label_3'][st.session_state.lang], icon="✅")
149
+
150
 
151
  st.info(content['disinformation_definition'][st.session_state.lang], icon="ℹ️")
152