|
|
import streamlit as st |
|
|
import torch |
|
|
from transformers import AutoTokenizer, AutoModelForCausalLM |
|
|
import os |
|
|
from huggingface_hub import login |
|
|
|
|
|
|
|
|
# Read the Hugging Face access token from the environment and authenticate.
# The app cannot proceed without it: the LLaMA-2 weights are a gated model.
hf_token = os.environ.get("HUGGINGFACE_TOKEN", "").strip()

if hf_token == "":
    # Surface a clear message in the UI and halt the script run entirely.
    st.error("HUGGINGFACE_TOKEN not found. Please set your Hugging Face token.")
    st.stop()

login(token=hf_token)
|
|
|
|
|
|
|
|
# Gated chat checkpoint; access requires an approved Hugging Face token.
model_name = "meta-llama/Llama-2-7b-chat-hf"


@st.cache_resource
def load_model():
    """Download and cache the tokenizer/model pair.

    st.cache_resource ensures the weights are loaded once per server
    process rather than on every Streamlit rerun.

    Returns:
        A (tokenizer, model) tuple for ``model_name``.
    """
    tok = AutoTokenizer.from_pretrained(model_name, token=hf_token)
    # NOTE(review): half precision on CPU halves memory, but many CPU
    # kernels are slow (or unsupported) in fp16 — confirm this is intended.
    net = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map="cpu",
        torch_dtype=torch.float16,
        token=hf_token,
    )
    return tok, net


tokenizer, model = load_model()
|
|
|
|
|
|
|
|
def classify_text(text, classes):
    """Classify ``text`` into one of ``classes`` via a LLaMA-2 prompt.

    Args:
        text: User-supplied text to classify.
        classes: Candidate category labels joined into the prompt.

    Returns:
        The model's predicted category label as a string (whatever follows
        the last "Category:" marker in the decoded output, first line only).
    """
    prompt = f"""
Classify the following text into one of these categories: {", ".join(classes)}.
Text: {text}
Category:
"""

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    with torch.no_grad():
        # Bug fix: the original used max_length=100, which counts the PROMPT
        # tokens too — any prompt near/over 100 tokens left no room for the
        # answer (and could error out). max_new_tokens bounds only the
        # generated continuation.
        outputs = model.generate(
            **inputs,
            max_new_tokens=20,
            num_return_sequences=1,
            # LLaMA tokenizers define no pad token; pad with EOS to avoid
            # the generate() warning and undefined padding behavior.
            pad_token_id=tokenizer.eos_token_id,
        )

    decoded_output = tokenizer.decode(outputs[0], skip_special_tokens=True)

    # The decoded text echoes the prompt, so keep only what follows the
    # final "Category:" marker.
    predicted_class = decoded_output.split("Category:")[-1].strip()
    # Trim to the first line so any follow-on chatter from the model does
    # not leak into the returned label.
    if predicted_class:
        predicted_class = predicted_class.splitlines()[0].strip()
    return predicted_class
|
|
|
|
|
|
|
|
# Global stylesheet forcing every text element to flame orange (#E25822).
# Raw <style> tags require unsafe_allow_html in Streamlit.
_THEME_CSS = """
<style>
/* Target all text elements */
body, h1, h2, h3, h4, h5, h6, p, div, span, input, textarea, button, label {
color: #E25822 !important;
}
</style>
"""

st.markdown(_THEME_CSS, unsafe_allow_html=True)
|
|
|
|
|
|
|
|
# --- Page header -------------------------------------------------------
st.title("π Text Classification with LLaMA 2 Chat (FP16)")
st.write("Powered by LLaMA 2 Chat & Hugging Face")

# --- Inputs ------------------------------------------------------------
user_input = st.text_area("Enter the text to classify:")

# Fixed sentiment label set fed into the classification prompt.
classes = ["Positive", "Negative", "Neutral"]

# --- Classification ----------------------------------------------------
if st.button("Classify"):
    if not user_input:
        # Guard clause: nothing to classify, prompt the user instead.
        st.warning("Please enter some text to classify.")
    else:
        result = classify_text(user_input, classes)
        st.subheader("Predicted Class:")
        st.write(result)

# --- Footer ------------------------------------------------------------
st.markdown("---")
st.write("π This app classifies text using the LLaMA 2 Chat model with FP16.")