SallySims commited on
Commit
38f3073
·
verified ·
1 Parent(s): c8f9c5c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -1
app.py CHANGED
@@ -33,15 +33,18 @@ def load_model():
33
  model, tokenizer = load_model()
34
 
35
  # Prediction function
 
 
36
  def get_prediction(prompt):
37
  messages = [{"role": "user", "content": prompt}]
38
  inputs = tokenizer.apply_chat_template(
39
  messages, tokenize=True, add_generation_prompt=True, return_tensors="pt"
40
- ).to("cuda")
41
  output = model.generate(inputs, max_new_tokens=150, temperature=0.7, top_p=0.95)
42
  decoded = tokenizer.decode(output[0], skip_special_tokens=True)
43
  return decoded.split("###")[-1].strip()
44
 
 
45
  # UI Header
46
  st.title("🧠 AnthroBot")
47
  st.write("Enter your anthropometric estimates to receive an interpreted summary inputs — manually or via CSV upload.")
 
33
  model, tokenizer = load_model()
34
 
35
  # Prediction function
36
+ device = "cuda" if torch.cuda.is_available() else "cpu"
37
+
38
  def get_prediction(prompt):
39
  messages = [{"role": "user", "content": prompt}]
40
  inputs = tokenizer.apply_chat_template(
41
  messages, tokenize=True, add_generation_prompt=True, return_tensors="pt"
42
+ ).to(device)
43
  output = model.generate(inputs, max_new_tokens=150, temperature=0.7, top_p=0.95)
44
  decoded = tokenizer.decode(output[0], skip_special_tokens=True)
45
  return decoded.split("###")[-1].strip()
46
 
47
+
48
  # UI Header
49
  st.title("🧠 AnthroBot")
50
  st.write("Enter your anthropometric estimates to receive an interpreted summary inputs — manually or via CSV upload.")