Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -44,20 +44,36 @@ def answer_question(text, question, max_length=512):
|
|
| 44 |
Returns:
|
| 45 |
The answer extracted from the text using the model.
|
| 46 |
"""
|
| 47 |
-
|
|
|
|
| 48 |
|
| 49 |
qa_model = TFBertForQuestionAnswering.from_pretrained(qa_model_name)
|
| 50 |
tokenizer = AutoTokenizer.from_pretrained(qa_model_name)
|
| 51 |
|
| 52 |
-
# Truncate text if necessary
|
| 53 |
if len(text) > max_length:
|
| 54 |
text = text[:max_length]
|
| 55 |
|
| 56 |
-
|
|
|
|
|
|
|
|
|
|
| 57 |
|
| 58 |
start_logits, end_logits = qa_model(inputs)
|
| 59 |
|
| 60 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
answer_end = tf.math.argmax(end_logits, axis=1) + 1 # Get predicted end index (exclusive)
|
| 62 |
|
| 63 |
answer = tf.gather(text, answer_start, axis=1).numpy()[0][answer_start[0]:answer_end[0]]
|
|
|
|
| 44 |
Returns:
|
| 45 |
The answer extracted from the text using the model.
|
| 46 |
"""
|
| 47 |
+
|
| 48 |
+
qa_model_name = "bert-base-uncased" # Replace with your model
|
| 49 |
|
| 50 |
qa_model = TFBertForQuestionAnswering.from_pretrained(qa_model_name)
|
| 51 |
tokenizer = AutoTokenizer.from_pretrained(qa_model_name)
|
| 52 |
|
| 53 |
+
# Truncate text if necessary:
|
| 54 |
if len(text) > max_length:
|
| 55 |
text = text[:max_length]
|
| 56 |
|
| 57 |
+
# Add special tokens and tokenize:
|
| 58 |
+
inputs = tokenizer(
|
| 59 |
+
question, text, return_tensors="tf", padding="max_length", truncation=True
|
| 60 |
+
)
|
| 61 |
|
| 62 |
start_logits, end_logits = qa_model(inputs)
|
| 63 |
|
| 64 |
+
# Check and ensure data type of start_logits:
|
| 65 |
+
if start_logits.dtype not in (tf.float32, tf.int32):
|
| 66 |
+
start_logits = tf.cast(start_logits, tf.float32) # Example casting to float32
|
| 67 |
+
|
| 68 |
+
# Verify axis type:
|
| 69 |
+
if not isinstance(axis, tf.Tensor) or axis.dtype not in (tf.int32, tf.int64):
|
| 70 |
+
axis = tf.constant(1, dtype=tf.int32) # Replace with correct axis if needed
|
| 71 |
+
|
| 72 |
+
# Ensure compatibility for argmax (e.g., non-empty tensor):
|
| 73 |
+
if start_logits.shape[0] == 0:
|
| 74 |
+
raise ValueError("start_logits tensor is empty")
|
| 75 |
+
|
| 76 |
+
answer_start = tf.math.argmax(start_logits, axis=axis)
|
| 77 |
answer_end = tf.math.argmax(end_logits, axis=1) + 1 # Get predicted end index (exclusive)
|
| 78 |
|
| 79 |
answer = tf.gather(text, answer_start, axis=1).numpy()[0][answer_start[0]:answer_end[0]]
|