SallySims commited on
Commit
3a9ffb7
·
verified ·
1 Parent(s): 62e15a3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -2
app.py CHANGED
@@ -44,11 +44,26 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
44
 
45
  def get_prediction(prompt):
46
  st.write(f"Received prompt: {prompt}") # Log the prompt received
 
47
  # Tokenize the input prompt
48
  inputs = tokenizer.encode(prompt, return_tensors="pt").to(device)
49
  st.write(f"Tokenized input: {inputs}") # Log the tokenized inputs
 
 
 
 
50
  # Generate output from the model
51
- output = model.generate(inputs, max_new_tokens=150, temperature=0.7, top_p=0.95)
 
 
 
 
 
 
 
 
 
 
52
  st.write(f"Output: {output}") # Log the raw output from the model
53
  # Decode the output to readable text
54
  decoded = tokenizer.decode(output[0], skip_special_tokens=True)
@@ -56,6 +71,7 @@ def get_prediction(prompt):
56
  return decoded.strip()
57
 
58
 
 
59
  # UI Header
60
  st.title("🧠 AnthroBot")
61
  st.write("Enter your anthropometric estimates to receive an interpreted summary inputs — manually or via CSV upload.")
@@ -113,4 +129,3 @@ with tab2:
113
  csv_output = df.to_csv(index=False).encode("utf-8")
114
  st.download_button("📤 Download Predictions", data=csv_output, file_name="predictions.csv")
115
 
116
-
 
44
 
45
  def get_prediction(prompt):
46
  st.write(f"Received prompt: {prompt}") # Log the prompt received
47
+
48
  # Tokenize the input prompt
49
  inputs = tokenizer.encode(prompt, return_tensors="pt").to(device)
50
  st.write(f"Tokenized input: {inputs}") # Log the tokenized inputs
51
+
52
+ # Check if model is on the correct device
53
+ model.to(device)
54
+
55
  # Generate output from the model
56
+ output = model.generate(
57
+ inputs,
58
+ max_length=200, # Set a reasonable max length for output
59
+ max_new_tokens=150, # Limit output to avoid too long generations
60
+ temperature=0.7, # Control randomness
61
+ top_p=0.95, # Top-p sampling for diversity
62
+ do_sample=True, # Enable sampling (for more diverse answers)
63
+ pad_token_id=tokenizer.eos_token_id, # Ensure padding is handled
64
+ num_return_sequences=1 # Only generate 1 sequence
65
+ )
66
+
67
  st.write(f"Output: {output}") # Log the raw output from the model
68
  # Decode the output to readable text
69
  decoded = tokenizer.decode(output[0], skip_special_tokens=True)
 
71
  return decoded.strip()
72
 
73
 
74
+
75
  # UI Header
76
  st.title("🧠 AnthroBot")
77
  st.write("Enter your anthropometric estimates to receive an interpreted summary inputs — manually or via CSV upload.")
 
129
  csv_output = df.to_csv(index=False).encode("utf-8")
130
  st.download_button("📤 Download Predictions", data=csv_output, file_name="predictions.csv")
131