SallySims committed on
Commit
250d65c
·
verified ·
1 Parent(s): e6456b5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -22
app.py CHANGED
@@ -1,5 +1,4 @@
1
  ## Deploying on HuggingFace
2
-
3
  import streamlit as st
4
  import pandas as pd
5
  import torch
@@ -56,35 +55,45 @@ def get_prediction(prompt):
56
  return_tensors="pt",
57
  max_length=512,
58
  truncation=True
59
- ).to(device)
60
  except Exception as e:
61
- st.error(f"Error during tokenization: {str(e)}")
62
- return None
63
-
 
 
 
 
 
 
 
64
  # Debug: Log inputs structure
65
  st.write(f"Inputs type: {type(inputs)}")
66
- if isinstance(inputs, dict):
67
- st.write(f"Inputs keys: {list(inputs.keys())}")
68
- if 'input_ids' in inputs:
69
- st.write(f"Input IDs shape: {inputs['input_ids'].shape}")
70
- else:
71
- st.error("No 'input_ids' in tokenized inputs")
72
- return None
 
 
 
 
 
 
 
 
 
 
73
  else:
74
  st.error(f"Unexpected inputs format: {type(inputs)}")
75
  return None
76
 
77
- # Extract input_ids safely
78
- input_ids = inputs['input_ids']
79
- if len(input_ids.shape) == 3 and input_ids.shape[0] == 1:
80
- input_ids = input_ids.squeeze(0) # Remove batch dimension if 3D
81
- elif len(input_ids.shape) == 2:
82
- pass # Already 2D, no squeeze needed
83
- else:
84
- st.error(f"Invalid input_ids shape: {input_ids.shape}")
85
- return None
86
 
87
- st.write(f"Final input_ids shape: {input_ids.shape}")
 
88
 
89
  # Generate output
90
  try:
 
1
  ## Deploying on HuggingFace
 
2
  import streamlit as st
3
  import pandas as pd
4
  import torch
 
55
  return_tensors="pt",
56
  max_length=512,
57
  truncation=True
58
+ )
59
  except Exception as e:
60
+ st.warning(f"apply_chat_template failed: {str(e)}. Falling back to manual tokenization.")
61
+ # Fallback: Manual tokenization
62
+ inputs = tokenizer(
63
+ prompt,
64
+ return_tensors="pt",
65
+ max_length=512,
66
+ truncation=True,
67
+ padding=False
68
+ )
69
+
70
  # Debug: Log inputs structure
71
  st.write(f"Inputs type: {type(inputs)}")
72
+
73
+ # Handle inputs (tensor or dict)
74
+ if isinstance(inputs, torch.Tensor):
75
+ # Direct tensor (likely input_ids)
76
+ input_ids = inputs
77
+ if len(input_ids.shape) == 1:
78
+ input_ids = input_ids.unsqueeze(0) # Add batch dimension: [sequence_length] -> [1, sequence_length]
79
+ elif len(input_ids.shape) > 2:
80
+ input_ids = input_ids.squeeze() # Remove extra dimensions if any
81
+ if len(input_ids.shape) == 1:
82
+ input_ids = input_ids.unsqueeze(0)
83
+ elif isinstance(inputs, dict) and 'input_ids' in inputs:
84
+ input_ids = inputs['input_ids']
85
+ if len(input_ids.shape) == 3 and input_ids.shape[0] == 1:
86
+ input_ids = input_ids.squeeze(0)
87
+ elif len(input_ids.shape) == 1:
88
+ input_ids = input_ids.unsqueeze(0)
89
  else:
90
  st.error(f"Unexpected inputs format: {type(inputs)}")
91
  return None
92
 
93
+ st.write(f"Input IDs shape: {input_ids.shape}")
 
 
 
 
 
 
 
 
94
 
95
+ # Ensure input_ids is on the correct device
96
+ input_ids = input_ids.to(device)
97
 
98
  # Generate output
99
  try: