adinarayana committed on
Commit
1465c8f
·
verified ·
1 Parent(s): 190a44b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -4
app.py CHANGED
@@ -44,20 +44,36 @@ def answer_question(text, question, max_length=512):
44
  Returns:
45
  The answer extracted from the text using the model.
46
  """
47
- qa_model_name = "bert-base-uncased" # Replace with your chosen model
 
48
 
49
  qa_model = TFBertForQuestionAnswering.from_pretrained(qa_model_name)
50
  tokenizer = AutoTokenizer.from_pretrained(qa_model_name)
51
 
52
- # Truncate text if necessary
53
  if len(text) > max_length:
54
  text = text[:max_length]
55
 
56
- inputs = tokenizer(question, text, return_tensors="tf") # Tokenize inputs for TensorFlow
 
 
 
57
 
58
  start_logits, end_logits = qa_model(inputs)
59
 
60
- answer_start = tf.math.argmax(start_logits, axis=1) # Get predicted start index
 
 
 
 
 
 
 
 
 
 
 
 
61
  answer_end = tf.math.argmax(end_logits, axis=1) + 1 # Get predicted end index (exclusive)
62
 
63
  answer = tf.gather(text, answer_start, axis=1).numpy()[0][answer_start[0]:answer_end[0]]
 
44
  Returns:
45
  The answer extracted from the text using the model.
46
  """
47
+
48
+ qa_model_name = "bert-base-uncased" # Replace with your model
49
 
50
  qa_model = TFBertForQuestionAnswering.from_pretrained(qa_model_name)
51
  tokenizer = AutoTokenizer.from_pretrained(qa_model_name)
52
 
53
+ # Truncate text if necessary:
54
  if len(text) > max_length:
55
  text = text[:max_length]
56
 
57
+ # Add special tokens and tokenize:
58
+ inputs = tokenizer(
59
+ question, text, return_tensors="tf", padding="max_length", truncation=True
60
+ )
61
 
62
  start_logits, end_logits = qa_model(inputs)
63
 
64
+ # Check and ensure data type of start_logits:
65
+ if start_logits.dtype not in (tf.float32, tf.int32):
66
+ start_logits = tf.cast(start_logits, tf.float32) # Example casting to float32
67
+
68
+ # Verify axis type:
69
+ if not isinstance(axis, tf.Tensor) or axis.dtype not in (tf.int32, tf.int64):
70
+ axis = tf.constant(1, dtype=tf.int32) # Replace with correct axis if needed
71
+
72
+ # Ensure compatibility for argmax (e.g., non-empty tensor):
73
+ if start_logits.shape[0] == 0:
74
+ raise ValueError("start_logits tensor is empty")
75
+
76
+ answer_start = tf.math.argmax(start_logits, axis=axis)
77
  answer_end = tf.math.argmax(end_logits, axis=1) + 1 # Get predicted end index (exclusive)
78
 
79
  answer = tf.gather(text, answer_start, axis=1).numpy()[0][answer_start[0]:answer_end[0]]