SallySims committed on
Commit
fb83f3f
·
verified ·
1 Parent(s): 3a9ffb7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -4
app.py CHANGED
@@ -43,13 +43,17 @@ model, tokenizer = load_model()
43
  device = "cuda" if torch.cuda.is_available() else "cpu"
44
 
45
  def get_prediction(prompt):
46
- st.write(f"Received prompt: {prompt}") # Log the prompt received
 
47
 
48
- # Tokenize the input prompt
49
- inputs = tokenizer.encode(prompt, return_tensors="pt").to(device)
 
 
 
50
  st.write(f"Tokenized input: {inputs}") # Log the tokenized inputs
51
 
52
- # Check if model is on the correct device
53
  model.to(device)
54
 
55
  # Generate output from the model
@@ -64,6 +68,20 @@ def get_prediction(prompt):
64
  num_return_sequences=1 # Only generate 1 sequence
65
  )
66
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
  st.write(f"Output: {output}") # Log the raw output from the model
68
  # Decode the output to readable text
69
  decoded = tokenizer.decode(output[0], skip_special_tokens=True)
@@ -129,3 +147,4 @@ with tab2:
129
  csv_output = df.to_csv(index=False).encode("utf-8")
130
  st.download_button("📤 Download Predictions", data=csv_output, file_name="predictions.csv")
131
 
 
 
43
  device = "cuda" if torch.cuda.is_available() else "cpu"
44
 
45
  def get_prediction(prompt):
46
+ # Log the received prompt
47
+ st.write(f"Received prompt: {prompt}")
48
 
49
+ # Add special tokens (if the model is expecting them)
50
+ prompt_with_special_tokens = f"{prompt}<|eot_id|><|start_header_id|>"
51
+
52
+ # Tokenize the input with the added special tokens
53
+ inputs = tokenizer.encode(prompt_with_special_tokens, return_tensors="pt").to(device)
54
  st.write(f"Tokenized input: {inputs}") # Log the tokenized inputs
55
 
56
+ # Ensure model is on the correct device (CUDA or CPU)
57
  model.to(device)
58
 
59
  # Generate output from the model
 
68
  num_return_sequences=1 # Only generate 1 sequence
69
  )
70
 
71
+ # Log the raw output from the model
72
+ st.write(f"Raw output: {output}")
73
+
74
+ # Decode the output to readable text
75
+ decoded = tokenizer.decode(output[0], skip_special_tokens=True)
76
+ st.write(f"Decoded output: {decoded}") # Log the decoded output
77
+
78
+ # Ensure the output is properly formatted
79
+ if "<|eot_id|>" in decoded:
80
+ # If expected token is found, split the output
81
+ decoded = decoded.split("<|eot_id|>")[-1].strip()
82
+
83
+ return decoded
84
+
85
  st.write(f"Output: {output}") # Log the raw output from the model
86
  # Decode the output to readable text
87
  decoded = tokenizer.decode(output[0], skip_special_tokens=True)
 
147
  csv_output = df.to_csv(index=False).encode("utf-8")
148
  st.download_button("📤 Download Predictions", data=csv_output, file_name="predictions.csv")
149
 
150
+