SallySims committed on
Commit
62e15a3
·
verified ·
1 Parent(s): e80f308

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -5
app.py CHANGED
@@ -43,13 +43,17 @@ model, tokenizer = load_model()
43
  device = "cuda" if torch.cuda.is_available() else "cpu"
44
 
45
  def get_prediction(prompt):
46
- messages = [{"role": "user", "content": prompt}]
47
- inputs = tokenizer.apply_chat_template(
48
- messages, tokenize=True, add_generation_prompt=True, return_tensors="pt"
49
- ).to(device)
 
50
  output = model.generate(inputs, max_new_tokens=150, temperature=0.7, top_p=0.95)
 
 
51
  decoded = tokenizer.decode(output[0], skip_special_tokens=True)
52
- return decoded.split("###")[-1].strip()
 
53
 
54
 
55
  # UI Header
@@ -108,3 +112,5 @@ with tab2:
108
 
109
  csv_output = df.to_csv(index=False).encode("utf-8")
110
  st.download_button("📤 Download Predictions", data=csv_output, file_name="predictions.csv")
 
 
 
43
  device = "cuda" if torch.cuda.is_available() else "cpu"
44
 
45
def get_prediction(prompt):
    """Generate a model completion for *prompt* and return the decoded text.

    Logs each intermediate step to the Streamlit UI via ``st.write``
    (these debug writes are user-visible output and are kept intentionally).

    Args:
        prompt: Raw user text to feed to the model.

    Returns:
        The decoded, whitespace-stripped completion text (generated tokens
        only — the echoed prompt is excluded).
    """
    st.write(f"Received prompt: {prompt}") # Log the prompt received
    # Tokenize the input prompt
    inputs = tokenizer.encode(prompt, return_tensors="pt").to(device)
    st.write(f"Tokenized input: {inputs}") # Log the tokenized inputs
    # Generate output from the model.
    # do_sample=True is required for temperature/top_p to have any effect:
    # without it, generate() falls back to greedy decoding and transformers
    # warns that both sampling parameters are ignored.
    output = model.generate(
        inputs,
        max_new_tokens=150,
        temperature=0.7,
        top_p=0.95,
        do_sample=True,
    )
    st.write(f"Output: {output}") # Log the raw output from the model
    # Decode only the newly generated tokens: output[0] also contains the
    # prompt tokens, and returning them would echo the user's input back
    # (the earlier revision stripped the prefix with split("###")[-1]).
    decoded = tokenizer.decode(output[0][inputs.shape[-1]:], skip_special_tokens=True)
    st.write(f"Decoded output: {decoded}") # Log the decoded output
    return decoded.strip()
57
 
58
 
59
  # UI Header
 
112
 
113
  csv_output = df.to_csv(index=False).encode("utf-8")
114
  st.download_button("📤 Download Predictions", data=csv_output, file_name="predictions.csv")
115
+
116
+