# studfaceval / HF_inference.py
# Revision 7f80a20 (MENG21)
# import requests
# import time
# import streamlit as st
# import os
# # SECRET_TOKEN
# SECRET_TOKEN = os.getenv("HF_IBOA")
# DISTILIBERT = "https://api-inference.huggingface.co/models/MENG21/stud-fac-eval-distilbert-base-uncased"
# BERTLARGE = "https://api-inference.huggingface.co/models/MENG21/stud-fac-eval-bert-large-uncased"
# BERTBASE = "https://api-inference.huggingface.co/models/MENG21/stud-fac-eval-bert-base-uncased"
# headers = {"Authorization": SECRET_TOKEN}
# # @st.cache_resource
# @st.cache_resource(experimental_allow_widgets=True, show_spinner=False)
# def query(payload, selected_model):
# if selected_model == "DISTILIBERT MODEL":
# API_URL = DISTILIBERT
# elif selected_model == "BERT-LARGE MODEL":
# API_URL = BERTLARGE
# elif selected_model == "BERT-BASE MODEL":
# API_URL = BERTBASE
# else:
# API_URL = DISTILIBERT
# start_time = time.time()
# counter = 0
# with st.spinner("Processing..."):
# while True:
# response = requests.post(API_URL, headers=headers, json=payload)
# # st.write(response)
# if response.status_code == 200:
# return response.json()
# else:
# time.sleep(1) # Wait for 1 second before retrying
# def analyze_sintement(text, selected_model):
# output = query({"inputs": text}, selected_model)
# if output:
# # st.success(f"Translation complete!")
# return output[0][0]['label'], output[0][0]['score']
# else:
# st.warning("Error! Please try again.")
import requests
import time
import streamlit as st
import os
# Hugging Face Inference API endpoint for each model label offered to the user,
# in display order.
MODEL_URLS = dict(
    [
        ("DISTILIBERT MODEL", "https://api-inference.huggingface.co/models/MENG21/stud-fac-eval-distilbert-base-uncased"),
        ("BERT-LARGE MODEL", "https://api-inference.huggingface.co/models/MENG21/stud-fac-eval-bert-large-uncased"),
        ("BERT-BASE MODEL", "https://api-inference.huggingface.co/models/MENG21/stud-fac-eval-bert-base-uncased"),
    ]
)

# Access token taken from the HF_IBOA environment variable (None when unset).
SECRET_TOKEN = os.getenv("HF_IBOA")

# The token is used verbatim as the Authorization header value.
# NOTE(review): the HF Inference API expects "Bearer <token>"; confirm that
# HF_IBOA already carries the "Bearer " prefix.
headers = dict(Authorization=SECRET_TOKEN)

# Retry policy for inference requests: how many POST attempts to make and how
# long (in seconds) to wait between them.
MAX_RETRIES, RETRY_INTERVAL = 3, 1
# @st.cache_resource(experimental_allow_widgets=True, show_spinner=False)
def query(payload, selected_model):
    """POST ``payload`` to the Hugging Face Inference API for ``selected_model``.

    Up to ``MAX_RETRIES`` attempts are made, waiting ``RETRY_INTERVAL`` seconds
    between consecutive attempts.

    Args:
        payload: JSON-serialisable request body, e.g. ``{"inputs": text}``.
        selected_model: A key of ``MODEL_URLS``. Unknown labels fall back to
            the "DISTILIBERT MODEL" endpoint.

    Returns:
        The decoded JSON body of the first HTTP 200 response, or ``None`` if
        every attempt failed.
    """
    # The default is looked up with a fixed key. The previous
    # ``MODEL_URLS.get(key, MODEL_URLS[key])`` evaluated its default eagerly,
    # so any unknown label raised KeyError and the fallback never applied.
    api_url = MODEL_URLS.get(selected_model, MODEL_URLS["DISTILIBERT MODEL"])

    for attempt in range(1, MAX_RETRIES + 1):
        try:
            response = requests.post(api_url, headers=headers, json=payload)
        except requests.RequestException:
            # Treat a network failure like a non-200 reply: retry, and finally
            # return None instead of letting the exception escape the app.
            response = None

        if response is not None and response.status_code == 200:
            return response.json()

        # Presumably the hosted model is still warming up; tell the user.
        st.info("loading..")
        # No point sleeping after the last attempt.
        if attempt < MAX_RETRIES:
            time.sleep(RETRY_INTERVAL)

    return None
def analyze_sintement(text, selected_model):
    """Classify ``text`` with ``selected_model`` and return its top prediction.

    Args:
        text: Input text, sent to the API as ``{"inputs": text}``.
        selected_model: Model label forwarded to ``query`` (a ``MODEL_URLS`` key).

    Returns:
        A ``(label, score)`` tuple for the highest-scoring class, or ``None``
        after showing a Streamlit warning when no result was obtained.
    """
    output = query({"inputs": text}, selected_model)
    if not output:
        st.warning("Error! Please try again.")
        return None

    # ``output[0]`` holds the candidate {"label", "score"} dicts for the input.
    # Take the maximum explicitly instead of trusting that the list is sorted;
    # when it is sorted descending this picks the same first entry.
    best = max(output[0], key=lambda candidate: candidate["score"])
    return best["label"], best["score"]