saikiranmansa committed on
Commit
e57cc98
·
verified ·
1 Parent(s): 9046320

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -30
app.py CHANGED
@@ -1,8 +1,8 @@
1
  import streamlit as st
2
  import torch
3
- from transformers import LlamaTokenizer, AutoModelForSequenceClassification
4
- from huggingface_hub import login
5
  import os
 
6
 
7
  # Hugging Face Authentication
8
  hf_token = os.getenv("HUGGINGFACE_TOKEN", "").strip()
@@ -14,59 +14,50 @@ if not hf_token:
14
  login(token=hf_token)
15
 
16
# Load Model & Tokenizer
# Hub repo id passed to from_pretrained() below together with hf_token
# (the repo is gated, hence the authenticated download — see login above).
model_name = "meta-llama/Llama-2-7b"
18
 
19
@st.cache_resource
def load_model():
    """Load and cache the LLaMA-2 tokenizer, classifier model, and device.

    Cached by Streamlit so reruns of the script reuse the same objects.
    Returns a (tokenizer, model, device) triple; inference is pinned to
    the CPU string "cpu" regardless of where device_map places weights.
    """
    # Both downloads authenticate with the module-level hf_token.
    tok = LlamaTokenizer.from_pretrained(model_name, token=hf_token)

    # Full-precision load (no bitsandbytes quantization).
    clf = AutoModelForSequenceClassification.from_pretrained(
        model_name,
        device_map="auto",  # automatically maps model to available devices
        token=hf_token,
    )

    return tok, clf, "cpu"
34
# Unpack the cached (tokenizer, model, device) triple once at import time;
# @st.cache_resource makes repeated Streamlit reruns a no-op.
tokenizer, model, device = load_model()
36
 
37
# Sentiment labels, indexed by the classifier's output logit positions.
class_labels = ["Negative", "Neutral", "Positive"]


def classify_text(user_input):
    """Classify *user_input* and return (label, per-class probabilities).

    The probabilities are a plain Python list aligned with class_labels.
    """
    encoded = tokenizer(
        user_input, return_tensors="pt", truncation=True, padding=True
    ).to(device)

    with torch.no_grad():
        result = model(**encoded)

    probs = torch.nn.functional.softmax(result.logits, dim=-1)
    best_idx = torch.argmax(probs, dim=-1).item()

    return class_labels[best_idx], probs[0].cpu().tolist()
52
 
53
# Streamlit UI
# NOTE: the emoji below were mojibake ("πŸ“"/"πŸ”" — UTF-8 bytes decoded as
# cp1252); restored to the intended 📝/🔍 characters.
st.title("📝 Text Classification with LLaMA 2")
st.write("Powered by LLaMA 2 & Hugging Face")

# User Input
user_input = st.text_area("Enter your text for classification:")

if st.button("Classify"):
    if user_input:
        predicted_class, probs = classify_text(user_input)

        # Display result
        st.subheader(f"Predicted Class: {predicted_class}")
        st.write(f"Confidence Scores: {dict(zip(class_labels, probs))}")

    else:
        st.warning("Please enter some text to classify.")

st.markdown("---")
st.write("🔍 This app classifies text using a fine-tuned LLaMA 2 model.")
 
 
1
  import streamlit as st
2
  import torch
3
+ from transformers import LlamaTokenizer, AutoModelForCausalLM
 
4
  import os
5
+ from huggingface_hub import login
6
 
7
  # Hugging Face Authentication
8
  hf_token = os.getenv("HUGGINGFACE_TOKEN", "").strip()
 
14
  login(token=hf_token)
15
 
16
  # Load Model & Tokenizer
17
+ model_name = "meta-llama/LLaMA-2-7b" # Using the specified model
18
 
19
@st.cache_resource
def load_model():
    """Download and cache the LLaMA-2 tokenizer and causal-LM weights.

    Cached by Streamlit so script reruns reuse the same objects.
    Returns a (tokenizer, model) pair.
    """
    # Both downloads authenticate with the module-level hf_token.
    tok = LlamaTokenizer.from_pretrained(model_name, token=hf_token)

    # Causal language model for text generation.
    lm = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map="auto",  # automatically maps model to available devices
        token=hf_token,
    )

    return tok, lm
 
 
 
 
32
 
33
# Materialize the cached (tokenizer, model) pair once at import time;
# @st.cache_resource makes repeated Streamlit reruns a no-op.
tokenizer, model = load_model()
 
34
 
35
# Function to generate a continuation of the user's prompt.
def generate_text(prompt, max_new_tokens=50):
    """Generate up to *max_new_tokens* new tokens continuing *prompt*.

    Returns the decoded text (prompt plus continuation) with special
    tokens stripped.
    """
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    with torch.no_grad():
        # Use max_new_tokens rather than max_length: max_length caps the
        # TOTAL sequence (prompt included), so a prompt of >=50 tokens
        # would previously produce no new text at all.
        outputs = model.generate(**inputs, max_new_tokens=max_new_tokens)

    return tokenizer.decode(outputs[0], skip_special_tokens=True)
 
 
 
 
43
 
44
# Streamlit UI
# NOTE: the emoji below were mojibake ("πŸ“"/"πŸ”" — UTF-8 bytes decoded as
# cp1252); restored to the intended 📝/🔍 characters.
st.title("📝 Text Generation with LLaMA 2")
st.write("Powered by LLaMA 2 & Hugging Face")

# User Input
user_input = st.text_area("Enter your prompt:")

if st.button("Generate"):
    if user_input:
        generated_text = generate_text(user_input)

        # Display result
        st.subheader("Generated Text:")
        st.write(generated_text)
    else:
        st.warning("Please enter a prompt.")

st.markdown("---")
st.write("🔍 This app generates text using the LLaMA 2 model.")