# saikiranmansa — "Update app.py" (commit ed9b623, verified)
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import os
from huggingface_hub import login
# --- Hugging Face authentication ---------------------------------------------
# Read the access token from the environment. Surrounding whitespace is removed
# so a stray newline (e.g. from a secrets file or shell export) does not end up
# inside the token.
hf_token = os.environ.get("HUGGINGFACE_TOKEN", "").strip()
if not hf_token:
    # Nothing useful can happen without a token: report it on the page and
    # halt this script run.
    st.error("HUGGINGFACE_TOKEN not found. Please set your Hugging Face token.")
    st.stop()
login(token=hf_token)

# --- Model selection ---------------------------------------------------------
# Hub repository id of the chat-tuned Llama 2 7B checkpoint; read by load_model().
model_name = "meta-llama/Llama-2-7b-chat-hf"
@st.cache_resource
def load_model():
    """Load the tokenizer and causal-LM model named by ``model_name``.

    Decorated with ``st.cache_resource`` so Streamlit reuses the loaded objects
    across script reruns instead of reloading the weights every time.

    Returns:
        tuple: ``(tokenizer, model)``, both fetched with ``hf_token``; the model
        is placed on the CPU in half precision (``torch.float16``).

    NOTE(review): float16 on CPU is typically slow, and some older PyTorch
    releases lack CPU kernels for half-precision matmul; ``torch.bfloat16`` is
    the usual CPU choice. Confirm against the deployed torch version before
    changing it — the UI text advertises FP16.
    """
    hub_kwargs = {"token": hf_token}

    tok = AutoTokenizer.from_pretrained(model_name, **hub_kwargs)
    llm = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16,  # half precision to cut weight memory
        device_map="cpu",           # force CPU placement
        **hub_kwargs,
    )
    return tok, llm


# Bind the cached resources to the module-level names classify_text() reads.
(tokenizer, model) = load_model()
def _extract_label(answer, classes):
    """Map the model's free-text ``answer`` onto one of ``classes``.

    Returns the class whose name occurs earliest in ``answer``
    (case-insensitive; on a tie at the same position the longer name wins).
    If no class name occurs, returns the first non-empty line of ``answer``
    stripped, or ``""`` when ``answer`` is blank.
    """
    lowered = answer.casefold()
    hits = []
    for label in classes:
        if not label:
            continue  # an empty name would "match" at position 0
        pos = lowered.find(label.casefold())
        if pos != -1:
            hits.append((pos, -len(label), label))
    if hits:
        return min(hits)[2]
    return next((line.strip() for line in answer.splitlines() if line.strip()), "")


def classify_text(text, classes, *, max_new_tokens=10):
    """Classify ``text`` into one of ``classes`` by prompting the loaded LLM.

    Args:
        text: The user-supplied text to classify.
        classes: Candidate category names listed in the prompt.
        max_new_tokens: Upper bound on tokens generated for the answer. Only
            newly generated tokens count, so long inputs still leave room for
            an answer.

    Returns:
        str: The entry of ``classes`` that the model's answer names first
        (case-insensitive). If the answer names none of them, the first
        non-empty line of the answer is returned as-is, or ``""`` when the
        model produced no text.
    """
    prompt = (
        "Classify the following text into one of these categories: "
        f"{', '.join(classes)}.\n"
        "Respond with only the category name.\n"
        f"Text: {text}\n"
        "Category:"
    )
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    prompt_len = inputs["input_ids"].shape[-1]

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            # The previous max_length=100 counted the prompt too, so inputs
            # near or over 100 tokens left no room for an answer. Bound only
            # the generated continuation instead.
            max_new_tokens=max_new_tokens,
            # Greedy decoding so the same text always gets the same label,
            # regardless of the checkpoint's default generation settings.
            do_sample=False,
        )

    # Decode only the continuation. Splitting the full decode on "Category:"
    # picked up rambling continuations or follow-on examples the model
    # invented after its answer.
    answer = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
    return _extract_label(answer, classes)
# Custom CSS: render all text in orange-red (#E25822).
st.markdown(
    """
<style>
/* Target all text elements */
body, h1, h2, h3, h4, h5, h6, p, div, span, input, textarea, button, label {
color: #E25822 !important;
}
</style>
""",
    unsafe_allow_html=True,
)

# Streamlit UI
st.title("📝 Text Classification with LLaMA 2 Chat (FP16)")
st.write("Powered by LLaMA 2 Chat & Hugging Face")

# User input
user_input = st.text_area("Enter the text to classify:")

# Candidate labels offered to the model
classes = ["Positive", "Negative", "Neutral"]

if st.button("Classify"):
    # Whitespace-only input has nothing to classify; treat it like empty input.
    if user_input.strip():
        # The model runs on CPU (see load_model), so generation can be slow;
        # show progress while it runs.
        with st.spinner("Classifying..."):
            predicted_class = classify_text(user_input, classes)
        st.subheader("Predicted Class:")
        st.write(predicted_class)
    else:
        st.warning("Please enter some text to classify.")

st.markdown("---")
st.write("🔍 This app classifies text using the LLaMA 2 Chat model with FP16.")