Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,16 +1,17 @@
|
|
| 1 |
import gradio as gr
|
| 2 |
from transformers import AutoModelForQuestionAnswering, AutoTokenizer, pipeline
|
| 3 |
import json
|
|
|
|
| 4 |
|
| 5 |
# Define hyperparameters
|
| 6 |
-
learning_rate = 3e-5
|
| 7 |
-
batch_size =
|
| 8 |
-
epochs =
|
| 9 |
-
max_seq_length =
|
| 10 |
-
warmup_steps =
|
| 11 |
-
weight_decay = 0.01
|
| 12 |
-
dropout_prob = 0.
|
| 13 |
-
gradient_clip_value = 1.0
|
| 14 |
|
| 15 |
context_val = ''
|
| 16 |
|
|
@@ -35,10 +36,19 @@ def q_n_a_fn(context, text):
|
|
| 35 |
with torch.no_grad():
|
| 36 |
outputs = q_n_a_model(**inputs)
|
| 37 |
|
| 38 |
-
#
|
| 39 |
start_idx, end_idx = torch.argmax(outputs.start_logits), torch.argmax(outputs.end_logits)
|
| 40 |
-
|
| 41 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
return answer
|
| 43 |
|
| 44 |
def classification_fn(text):
|
|
@@ -75,4 +85,4 @@ with gr.Blocks(theme='gradio/soft') as demo:
|
|
| 75 |
gr.Interface(fn=classification_fn, inputs=[context], outputs="text")
|
| 76 |
|
| 77 |
if __name__ == "__main__":
|
| 78 |
-
demo.launch()
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
from transformers import AutoModelForQuestionAnswering, AutoTokenizer, pipeline
|
| 3 |
import json
|
| 4 |
+
import torch
|
| 5 |
|
| 6 |
# Define hyperparameters
|
| 7 |
+
learning_rate = 3e-5 # Learning rate (value unchanged from the previous revision)
|
| 8 |
+
batch_size = 8 # Smaller batch size to allow for more precise updates
|
| 9 |
+
epochs = 4 # Slightly more training epochs
|
| 10 |
+
max_seq_length = 256 # Smaller sequence length, especially if the majority of your questions and contexts are shorter
|
| 11 |
+
warmup_steps = 200 # Longer warmup phase
|
| 12 |
+
weight_decay = 0.01 # Keep weight decay as it is
|
| 13 |
+
dropout_prob = 0.2 # Slightly higher dropout for regularization
|
| 14 |
+
gradient_clip_value = 1.0 # Keep gradient clip value as it is
|
| 15 |
|
| 16 |
context_val = ''
|
| 17 |
|
|
|
|
| 36 |
with torch.no_grad():
|
| 37 |
outputs = q_n_a_model(**inputs)
|
| 38 |
|
| 39 |
+
# Get the predicted answer span indices
|
| 40 |
start_idx, end_idx = torch.argmax(outputs.start_logits), torch.argmax(outputs.end_logits)
|
| 41 |
+
|
| 42 |
+
# Ensure indices are within bounds
|
| 43 |
+
start_idx = min(start_idx, len(inputs["input_ids"][0]) - 1)
|
| 44 |
+
end_idx = min(end_idx, len(inputs["input_ids"][0]) - 1)
|
| 45 |
+
|
| 46 |
+
# Find the answer tokens in the input
|
| 47 |
+
answer_tokens = inputs["input_ids"][0][start_idx : end_idx + 1]
|
| 48 |
+
|
| 49 |
+
# Decode the answer tokens into a human-readable answer
|
| 50 |
+
answer = tokenizer.decode(answer_tokens, skip_special_tokens=True)
|
| 51 |
+
|
| 52 |
return answer
|
| 53 |
|
| 54 |
def classification_fn(text):
|
|
|
|
| 85 |
gr.Interface(fn=classification_fn, inputs=[context], outputs="text")
|
| 86 |
|
| 87 |
if __name__ == "__main__":
|
| 88 |
+
demo.launch()
|