SallySims committed on
Commit
62d3ad2
·
verified ·
1 Parent(s): 7a0a8fa

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -5
app.py CHANGED
@@ -1,4 +1,5 @@
1
  ## Deploying on HuggingFace
 
2
  import streamlit as st
3
  import pandas as pd
4
  import torch
@@ -7,6 +8,7 @@ from huggingface_hub import login
7
  from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer
8
  from peft import PeftModel, PeftConfig
9
  import io
 
10
 
11
  # Login using Hugging Face token
12
  login(token=os.getenv("HUGGINGFACEHUB_TOKEN"))
@@ -58,7 +60,7 @@ def get_prediction(age, sex, height_cm, weight_kg, wc_cm):
58
  return_tensors="pt",
59
  max_length=512,
60
  truncation=True,
61
- return_dict=True # Ensure dictionary output
62
  )
63
  except Exception as e:
64
  st.warning(f"apply_chat_template failed: {str(e)}. Falling back to manual tokenization.")
@@ -73,15 +75,16 @@ def get_prediction(age, sex, height_cm, weight_kg, wc_cm):
73
 
74
  # Debug: Log inputs structure
75
  st.write(f"Inputs type: {type(inputs)}")
 
76
 
77
- # Handle inputs (tensor or dict)
78
  if isinstance(inputs, torch.Tensor):
79
- # Assume tensor is input_ids, create attention_mask
80
  input_ids = inputs
 
81
  if len(input_ids.shape) == 1:
82
  input_ids = input_ids.unsqueeze(0)
83
- attention_mask = torch.ones_like(input_ids) # Create attention_mask
84
- elif isinstance(inputs, dict) and 'input_ids' in inputs:
85
  input_ids = inputs['input_ids']
86
  attention_mask = inputs.get('attention_mask', torch.ones_like(input_ids))
87
  if len(input_ids.shape) == 3 and input_ids.shape[0] == 1:
 
1
  ## Deploying on HuggingFace
2
+
3
  import streamlit as st
4
  import pandas as pd
5
  import torch
 
8
  from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer
9
  from peft import PeftModel, PeftConfig
10
  import io
11
+ from transformers.tokenization_utils_base import BatchEncoding
12
 
13
  # Login using Hugging Face token
14
  login(token=os.getenv("HUGGINGFACEHUB_TOKEN"))
 
60
  return_tensors="pt",
61
  max_length=512,
62
  truncation=True,
63
+ return_dict=True
64
  )
65
  except Exception as e:
66
  st.warning(f"apply_chat_template failed: {str(e)}. Falling back to manual tokenization.")
 
75
 
76
  # Debug: Log inputs structure
77
  st.write(f"Inputs type: {type(inputs)}")
78
+ st.write(f"Inputs content: {inputs}")
79
 
80
+ # Handle inputs (tensor, dict, or BatchEncoding)
81
  if isinstance(inputs, torch.Tensor):
 
82
  input_ids = inputs
83
+ attention_mask = torch.ones_like(input_ids)
84
  if len(input_ids.shape) == 1:
85
  input_ids = input_ids.unsqueeze(0)
86
+ attention_mask = attention_mask.unsqueeze(0)
87
+ elif isinstance(inputs, (dict, BatchEncoding)):
88
  input_ids = inputs['input_ids']
89
  attention_mask = inputs.get('attention_mask', torch.ones_like(input_ids))
90
  if len(input_ids.shape) == 3 and input_ids.shape[0] == 1: