saikiranmansa committed on
Commit
26200f9
·
verified ·
1 Parent(s): a1eca5a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -24
app.py CHANGED
@@ -1,11 +1,11 @@
1
  import streamlit as st
2
  import torch
3
- from transformers import AutoModelForSequenceClassification, AutoTokenizer
4
  import os
5
  from huggingface_hub import login
6
 
7
  # Hugging Face Authentication
8
- hf_token = os.getenv("HUGGINGFACE_TOKEN", "").strip() # Remove any newline characters
9
 
10
  if not hf_token:
11
  st.error("HUGGINGFACE_TOKEN not found. Please set your Hugging Face token.")
@@ -14,48 +14,55 @@ if not hf_token:
14
  login(token=hf_token)
15
 
16
  # Load Model & Tokenizer
17
- model_name = "meta-llama/LLaMA-2-7b" # Replace with the correct LLaMA 2 model name
18
 
19
  @st.cache_resource
20
  def load_model():
21
- tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token, use_fast=False)
22
- model = AutoModelForSequenceClassification.from_pretrained(model_name)
23
- return tokenizer, model
 
 
 
 
 
24
 
25
- tokenizer, model = load_model()
 
 
 
26
 
27
  # Function to classify text
28
  def classify_text(user_input):
29
- # Tokenize the input
30
- inputs = tokenizer(user_input, return_tensors="pt", truncation=True, padding=True)
31
 
32
- # Get model predictions
33
  with torch.no_grad():
34
  outputs = model(**inputs)
35
 
36
- # Get predicted class
37
- predictions = torch.argmax(outputs.logits, dim=-1)
38
- predicted_class = predictions.item() # Assuming single label output
39
-
40
- return predicted_class
41
 
42
  # Streamlit UI
43
- st.title("Text Classification with LLaMA 2")
44
- st.write("Powered by LLaMA 2 and Hugging Face Transformers")
45
 
46
  # User Input
47
- user_input = st.text_area("Type your text here...")
48
 
49
  if st.button("Classify"):
50
  if user_input:
51
- # Get classification result
52
- predicted_class = classify_text(user_input)
53
-
54
  # Display result
55
- st.write(f"Predicted Class: {predicted_class}")
 
 
56
  else:
57
  st.warning("Please enter some text to classify.")
58
 
59
- # Add a footer or additional information if needed
60
  st.markdown("---")
61
- st.write("This application uses LLaMA 2 for text classification.")
 
 
1
  import streamlit as st
2
  import torch
3
+ from transformers import LlamaTokenizer, AutoModelForSequenceClassification
4
  import os
5
  from huggingface_hub import login
6
 
7
  # Hugging Face Authentication
8
+ hf_token = os.getenv("HUGGINGFACE_TOKEN", "").strip()
9
 
10
  if not hf_token:
11
  st.error("HUGGINGFACE_TOKEN not found. Please set your Hugging Face token.")
 
14
  login(token=hf_token)
15
 
16
  # Load Model & Tokenizer
17
+ model_name = "meta-llama/Llama-2-7b-hf" # Ensure this is a fine-tuned classification model
18
 
19
  @st.cache_resource
20
  def load_model():
21
+ tokenizer = LlamaTokenizer.from_pretrained(model_name, token=hf_token)
22
+ model = AutoModelForSequenceClassification.from_pretrained(model_name, token=hf_token)
23
+
24
+ # Move model to GPU if available
25
+ device = "cuda" if torch.cuda.is_available() else "cpu"
26
+ model.to(device)
27
+
28
+ return tokenizer, model, device
29
 
30
+ tokenizer, model, device = load_model()
31
+
32
+ # Define class labels (Update based on your dataset)
33
+ class_labels = ["Negative", "Neutral", "Positive"] # Modify if your model has different classes
34
 
35
  # Function to classify text
36
  def classify_text(user_input):
37
+ inputs = tokenizer(user_input, return_tensors="pt", truncation=True, padding=True).to(device)
 
38
 
 
39
  with torch.no_grad():
40
  outputs = model(**inputs)
41
 
42
+ logits = outputs.logits
43
+ probabilities = torch.nn.functional.softmax(logits, dim=-1)
44
+ predicted_class_idx = torch.argmax(probabilities, dim=-1).item()
45
+
46
+ return class_labels[predicted_class_idx], probabilities[0].cpu().tolist()
47
 
48
  # Streamlit UI
49
+ st.title("📝 Text Classification with LLaMA 2")
50
+ st.write("Powered by LLaMA 2 & Hugging Face")
51
 
52
  # User Input
53
+ user_input = st.text_area("Enter your text for classification:")
54
 
55
  if st.button("Classify"):
56
  if user_input:
57
+ predicted_class, probs = classify_text(user_input)
58
+
 
59
  # Display result
60
+ st.subheader(f"Predicted Class: {predicted_class}")
61
+ st.write(f"Confidence Scores: {dict(zip(class_labels, probs))}")
62
+
63
  else:
64
  st.warning("Please enter some text to classify.")
65
 
 
66
  st.markdown("---")
67
+ st.write("🔍 This app classifies text using a fine-tuned LLaMA 2 model.")
68
+