saikiranmansa committed on
Commit
4846bb0
·
verified ·
1 Parent(s): aa3466e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -13
app.py CHANGED
@@ -14,7 +14,7 @@ if not hf_token:
14
  login(token=hf_token)
15
 
16
  # Load Model & Tokenizer
17
- model_name = "meta-llama/LLaMA-2-7b" # Use the correct model name
18
 
19
  @st.cache_resource
20
  def load_model():
@@ -32,31 +32,49 @@ def load_model():
32
 
33
  tokenizer, model = load_model()
34
 
35
- # Function to generate text (example for usage)
36
- def generate_text(prompt):
 
 
 
 
 
 
 
 
37
  inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
38
 
 
39
  with torch.no_grad():
40
- outputs = model.generate(**inputs, max_length=50) # Adjust max_length as needed
 
 
 
41
 
42
- return tokenizer.decode(outputs[0], skip_special_tokens=True)
 
 
43
 
44
  # Streamlit UI
45
- st.title("📝 Text Generation with LLaMA 2")
46
  st.write("Powered by LLaMA 2 & Hugging Face")
47
 
48
  # User Input
49
- user_input = st.text_area("Enter your prompt:")
 
 
 
50
 
51
- if st.button("Generate"):
52
  if user_input:
53
- generated_text = generate_text(user_input)
 
54
 
55
  # Display result
56
- st.subheader("Generated Text:")
57
- st.write(generated_text)
58
  else:
59
- st.warning("Please enter a prompt.")
60
 
61
  st.markdown("---")
62
- st.write("🔍 This app generates text using the LLaMA 2 model.")
 
14
  login(token=hf_token)
15
 
16
  # Load Model & Tokenizer
17
+ model_name = "meta-llama/LLaMA-2-7b-chat-hf" # Use the chat version for better instruction-following
18
 
19
  @st.cache_resource
20
  def load_model():
 
32
 
33
  tokenizer, model = load_model()
34
 
35
+ # Function to classify text using a prompt-based approach
36
+ def classify_text(text, classes):
37
+ # Create a prompt for classification
38
+ prompt = f"""
39
+ Classify the following text into one of these categories: {", ".join(classes)}.
40
+ Text: {text}
41
+ Category:
42
+ """
43
+
44
+ # Tokenize the prompt
45
  inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
46
 
47
+ # Generate the output
48
  with torch.no_grad():
49
+ outputs = model.generate(**inputs, max_length=100, num_return_sequences=1)
50
+
51
+ # Decode the output
52
+ decoded_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
53
 
54
+ # Extract the predicted class
55
+ predicted_class = decoded_output.split("Category:")[-1].strip()
56
+ return predicted_class
57
 
58
  # Streamlit UI
59
+ st.title("📝 Text Classification with LLaMA 2")
60
  st.write("Powered by LLaMA 2 & Hugging Face")
61
 
62
  # User Input
63
+ user_input = st.text_area("Enter the text to classify:")
64
+
65
+ # Define classes for classification
66
+ classes = ["Positive", "Negative", "Neutral"]
67
 
68
+ if st.button("Classify"):
69
  if user_input:
70
+ # Perform classification
71
+ predicted_class = classify_text(user_input, classes)
72
 
73
  # Display result
74
+ st.subheader("Predicted Class:")
75
+ st.write(predicted_class)
76
  else:
77
+ st.warning("Please enter some text to classify.")
78
 
79
  st.markdown("---")
80
+ st.write("🔍 This app classifies text using the LLaMA 2 model.")