File size: 3,061 Bytes
e4fe207
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7f80a20
e4fe207
 
aec2e12
e4fe207
 
 
aec2e12
e4fe207
 
 
 
7f80a20
e4fe207
 
 
 
 
aec2e12
e4fe207
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
# import requests
# import time
# import streamlit as st
# import os


# # SECRET_TOKEN
# SECRET_TOKEN = os.getenv("HF_IBOA")

# DISTILIBERT = "https://api-inference.huggingface.co/models/MENG21/stud-fac-eval-distilbert-base-uncased"
# BERTLARGE = "https://api-inference.huggingface.co/models/MENG21/stud-fac-eval-bert-large-uncased"
# BERTBASE = "https://api-inference.huggingface.co/models/MENG21/stud-fac-eval-bert-base-uncased"

# headers = {"Authorization": SECRET_TOKEN}

# # @st.cache_resource
# @st.cache_resource(experimental_allow_widgets=True, show_spinner=False)
# def query(payload, selected_model):
#     if selected_model == "DISTILIBERT MODEL":
#         API_URL = DISTILIBERT
#     elif selected_model == "BERT-LARGE MODEL":
#         API_URL = BERTLARGE
#     elif selected_model == "BERT-BASE MODEL":
#         API_URL = BERTBASE
#     else:
#         API_URL = DISTILIBERT

#     start_time = time.time()
#     counter = 0
#     with st.spinner("Processing..."):
#         while True:
#             response = requests.post(API_URL, headers=headers, json=payload)
#             # st.write(response)
#             if response.status_code == 200:
                
#                 return response.json()
#             else:
#                 time.sleep(1)  # Wait for 1 second before retrying

# def analyze_sintement(text, selected_model):
#     output  = query({"inputs": text}, selected_model)
#     if output:
#         # st.success(f"Translation complete!")
#         return output[0][0]['label'], output[0][0]['score']
#     else:
#         st.warning("Error! Please try again.")



import requests
import time
import streamlit as st
import os

# Hugging Face Inference API endpoint for every model label a caller may select.
MODEL_URLS = {
    "DISTILIBERT MODEL": "https://api-inference.huggingface.co/models/MENG21/stud-fac-eval-distilbert-base-uncased",
    "BERT-LARGE MODEL": "https://api-inference.huggingface.co/models/MENG21/stud-fac-eval-bert-large-uncased",
    "BERT-BASE MODEL": "https://api-inference.huggingface.co/models/MENG21/stud-fac-eval-bert-base-uncased",
}

# Credential read from the HF_IBOA environment variable (None when it is unset).
# NOTE(review): the value is sent verbatim as the Authorization header, so it
# presumably already carries the "Bearer " prefix -- confirm against deployment.
SECRET_TOKEN = os.environ.get("HF_IBOA")

# Headers attached to every inference request.
headers = dict(Authorization=SECRET_TOKEN)

# Retry policy for query(): pause between attempts and total number of attempts.
RETRY_INTERVAL = 1  # seconds
MAX_RETRIES = 3

# @st.cache_resource(experimental_allow_widgets=True, show_spinner=False)
def query(payload, selected_model, *, timeout=30):
    """POST *payload* to the inference endpoint of *selected_model*.

    Args:
        payload: JSON-serialisable request body.
        selected_model: A key of ``MODEL_URLS``. Unknown labels fall back to
            "DISTILIBERT MODEL" instead of raising ``KeyError``.
        timeout: Seconds to wait for each HTTP request before treating the
            attempt as failed.

    Returns:
        The decoded JSON body of the first HTTP 200 response, or ``None``
        when all ``MAX_RETRIES`` attempts fail.
    """
    # dict.get evaluates its default eagerly, so the default must be a key
    # that is guaranteed to exist rather than the (possibly unknown) input.
    api_url = MODEL_URLS.get(selected_model, MODEL_URLS["DISTILIBERT MODEL"])

    for attempt in range(MAX_RETRIES):
        try:
            response = requests.post(
                api_url, headers=headers, json=payload, timeout=timeout
            )
            if response.status_code == 200:
                return response.json()
        except (requests.RequestException, ValueError):
            # Connection error, timeout or undecodable body: count it as a
            # failed attempt so the retry loop below handles it uniformly.
            pass

        # Only announce and wait when another attempt will actually follow.
        if attempt < MAX_RETRIES - 1:
            st.info("loading..")
            time.sleep(RETRY_INTERVAL)

    return None

def analyze_sintement(text, selected_model):
    """Classify *text* with *selected_model* and return ``(label, score)``.

    The text is wrapped as ``{"inputs": text}`` and sent through ``query``.
    On success the ``'label'`` and ``'score'`` of the first entry of the first
    result list are returned (presumably the top-ranked prediction -- confirm
    against the API response format). When ``query`` yields nothing, a
    Streamlit warning is shown and ``None`` is returned.
    """
    result = query({"inputs": text}, selected_model)

    # Guard clause: no usable response, so tell the user and bail out.
    if not result:
        st.warning("Error! Please try again.")
        return None

    first_entry = result[0][0]
    return first_entry['label'], first_entry['score']