SallySims commited on
Commit
7d1653c
·
verified ·
1 Parent(s): 199d0ca

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -9
app.py CHANGED
@@ -1,5 +1,6 @@
1
  ## Deploying on HuggingFace
2
 
 
3
  import streamlit as st
4
  import pandas as pd
5
  import torch
@@ -57,28 +58,32 @@ def get_prediction(prompt):
57
  add_generation_prompt=True, # This is needed for generation
58
  return_tensors="pt",
59
  ).to(device)
60
-
61
  # Log the tokenized input
62
  st.write(f"Tokenized input: {inputs}")
63
 
64
- # Initialize TextStreamer for real-time streaming
65
- text_streamer = TextStreamer(tokenizer)
66
 
67
- # Generate output using the model with streaming
 
 
 
 
68
  output = model.generate(
69
- inputs["input_ids"], # Use the tokenized input
70
- max_new_tokens=250, # Limit the number of tokens
71
  temperature=0.7, # Control randomness of output
72
  top_p=0.95, # Sampling parameter
73
  do_sample=True, # Ensure sampling for diverse output
74
- streamer=text_streamer, # Use the TextStreamer for output
75
  )
76
 
77
- st.write(f"Output: {output}") # Log the raw output from the model
78
  # Decode the output
79
  decoded = tokenizer.decode(output[0], skip_special_tokens=True)
80
- # Log decoded output
 
81
  st.write(f"Decoded output: {decoded}")
 
82
  return decoded
83
 
84
 
 
1
  ## Deploying on HuggingFace
2
 
3
+
4
  import streamlit as st
5
  import pandas as pd
6
  import torch
 
58
  add_generation_prompt=True, # This is needed for generation
59
  return_tensors="pt",
60
  ).to(device)
61
+
62
  # Log the tokenized input
63
  st.write(f"Tokenized input: {inputs}")
64
 
65
+ # Verify the shape of the tokenized input
66
+ st.write(f"Shape of tokenized input: {inputs['input_ids'].shape}")
67
 
68
+ # Ensure that input_ids has the correct shape
69
+ input_ids = inputs["input_ids"].squeeze(0) # Remove the batch dimension if it's there
70
+ st.write(f"Corrected tokenized input shape: {input_ids.shape}")
71
+
72
+ # Generate output using the model
73
  output = model.generate(
74
+ input_ids, # Use the tokenized input
75
+ max_new_tokens=150, # Limit the number of tokens
76
  temperature=0.7, # Control randomness of output
77
  top_p=0.95, # Sampling parameter
78
  do_sample=True, # Ensure sampling for diverse output
 
79
  )
80
 
 
81
  # Decode the output
82
  decoded = tokenizer.decode(output[0], skip_special_tokens=True)
83
+
84
+ # Log the decoded output
85
  st.write(f"Decoded output: {decoded}")
86
+
87
  return decoded
88
 
89