SallySims commited on
Commit
ecb5b5d
·
verified ·
1 Parent(s): fb83f3f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -37
app.py CHANGED
@@ -1,4 +1,5 @@
1
  ## Deploying on HuggingFace
 
2
  import streamlit as st
3
  import pandas as pd
4
  import torch
@@ -39,55 +40,46 @@ def load_model():
39
  raise e
40
  model, tokenizer = load_model()
41
 
 
42
  # Prediction function
43
  device = "cuda" if torch.cuda.is_available() else "cpu"
44
 
45
  def get_prediction(prompt):
46
- # Log the received prompt
47
  st.write(f"Received prompt: {prompt}")
 
 
 
48
 
49
- # Add special tokens (if the model is expecting them)
50
- prompt_with_special_tokens = f"{prompt}<|eot_id|><|start_header_id|>"
 
 
 
 
 
51
 
52
- # Tokenize the input with the added special tokens
53
- inputs = tokenizer.encode(prompt_with_special_tokens, return_tensors="pt").to(device)
54
- st.write(f"Tokenized input: {inputs}") # Log the tokenized inputs
55
 
56
- # Ensure model is on the correct device (CUDA or CPU)
57
- model.to(device)
58
-
59
- # Generate output from the model
60
  output = model.generate(
61
- inputs,
62
- max_length=200, # Set a reasonable max length for output
63
- max_new_tokens=150, # Limit output to avoid too long generations
64
- temperature=0.7, # Control randomness
65
- top_p=0.95, # Top-p sampling for diversity
66
- do_sample=True, # Enable sampling (for more diverse answers)
67
- pad_token_id=tokenizer.eos_token_id, # Ensure padding is handled
68
- num_return_sequences=1 # Only generate 1 sequence
69
  )
70
 
71
- # Log the raw output from the model
72
- st.write(f"Raw output: {output}")
73
-
74
- # Decode the output to readable text
75
- decoded = tokenizer.decode(output[0], skip_special_tokens=True)
76
- st.write(f"Decoded output: {decoded}") # Log the decoded output
77
-
78
- # Ensure the output is properly formatted
79
- if "<|eot_id|>" in decoded:
80
- # If expected token is found, split the output
81
- decoded = decoded.split("<|eot_id|>")[-1].strip()
82
-
83
- return decoded
84
-
85
  st.write(f"Output: {output}") # Log the raw output from the model
86
- # Decode the output to readable text
87
  decoded = tokenizer.decode(output[0], skip_special_tokens=True)
88
- st.write(f"Decoded output: {decoded}") # Log the decoded output
89
- return decoded.strip()
90
-
91
 
92
 
93
  # UI Header
@@ -147,4 +139,3 @@ with tab2:
147
  csv_output = df.to_csv(index=False).encode("utf-8")
148
  st.download_button("📤 Download Predictions", data=csv_output, file_name="predictions.csv")
149
 
150
-
 
1
  ## Deploying on HuggingFace
2
+ ## Deploying on HuggingFace
3
  import streamlit as st
4
  import pandas as pd
5
  import torch
 
40
  raise e
41
  model, tokenizer = load_model()
42
 
43
+ # Prediction function
44
  # Prediction function
45
  device = "cuda" if torch.cuda.is_available() else "cpu"
46
 
47
  def get_prediction(prompt):
 
48
  st.write(f"Received prompt: {prompt}")
49
+
50
+ # Create a message structure
51
+ messages = [{"role": "user", "content": prompt}]
52
 
53
+ # Tokenize the input
54
+ inputs = tokenizer.apply_chat_template(
55
+ messages,
56
+ tokenize=True,
57
+ add_generation_prompt=True, # This is needed for generation
58
+ return_tensors="pt",
59
+ ).to(device)
60
 
61
+ # Log the tokenized input
62
+ st.write(f"Tokenized input: {inputs}")
 
63
 
64
+ # Initialize TextStreamer for real-time streaming
65
+ text_streamer = TextStreamer(tokenizer)
66
+
67
+ # Generate output using the model with streaming
68
  output = model.generate(
69
+ inputs["input_ids"], # Use the tokenized input
70
+ max_new_tokens=150, # Limit the number of tokens
71
+ temperature=0.7, # Control randomness of output
72
+ top_p=0.95, # Sampling parameter
73
+ do_sample=True, # Ensure sampling for diverse output
74
+ streamer=text_streamer, # Use the TextStreamer for output
 
 
75
  )
76
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
  st.write(f"Output: {output}") # Log the raw output from the model
78
+ # Decode the output
79
  decoded = tokenizer.decode(output[0], skip_special_tokens=True)
80
+ # Log decoded output
81
+ st.write(f"Decoded output: {decoded}")
82
+ return decoded
83
 
84
 
85
  # UI Header
 
139
  csv_output = df.to_csv(index=False).encode("utf-8")
140
  st.download_button("📤 Download Predictions", data=csv_output, file_name="predictions.csv")
141