Spaces:
Build error
Build error
zhenyundeng
committed on
Commit
·
f58b8d2
1
Parent(s):
c052247
update app.py
Browse files
app.py
CHANGED
|
@@ -85,13 +85,9 @@ veracity_model = SequenceClassificationModule.load_from_checkpoint(veracity_chec
|
|
| 85 |
# Justification
|
| 86 |
justification_tokenizer = BartTokenizer.from_pretrained('facebook/bart-large', add_prefix_space=True)
|
| 87 |
bart_model = BartForConditionalGeneration.from_pretrained("facebook/bart-large")
|
| 88 |
-
best_checkpoint = os.getcwd()+ '/averitec/pretrained_models/bart_justifications_verdict-epoch=13-val_loss=2.03-val_meteor=0.28.ckpt'
|
| 89 |
justification_model = JustificationGenerationModule.load_from_checkpoint(best_checkpoint, tokenizer=justification_tokenizer, model=bart_model).to('cuda')
|
| 90 |
# justification_model = JustificationGenerationModule.load_from_checkpoint(best_checkpoint, tokenizer=justification_tokenizer, model=bart_model).to(device)
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
print("veracity_model_device_0:{}".format(veracity_model.device))
|
| 94 |
-
print("justification_model_device_0:{}".format(justification_model.device))
|
| 95 |
# ---------------------------------------------------------------------------
|
| 96 |
|
| 97 |
# ----------------------------------------------------------------------------
|
|
@@ -285,9 +281,8 @@ def veracity_prediction(claim, evidence):
|
|
| 285 |
return pred_label
|
| 286 |
|
| 287 |
tokenized_strings, attention_mask = dataLoader.tokenize_strings(evidence_strings)
|
| 288 |
-
example_support = torch.argmax(veracity_model(tokenized_strings, attention_mask=attention_mask).logits, axis=1)
|
| 289 |
# example_support = torch.argmax(veracity_model(tokenized_strings.to(device), attention_mask=attention_mask.to(device)).logits, axis=1)
|
| 290 |
-
print("veracity_model_device_1:{}".format(veracity_model.device))
|
| 291 |
|
| 292 |
has_unanswerable = False
|
| 293 |
has_true = False
|
|
@@ -349,9 +344,8 @@ def justification_generation(claim, evidence, verdict_label):
|
|
| 349 |
#
|
| 350 |
claim_str = extract_claim_str(claim, evidence, verdict_label)
|
| 351 |
claim_str.strip()
|
| 352 |
-
pred_justification = justification_model.generate(claim_str)
|
| 353 |
# pred_justification = justification_model.generate(claim_str, device=device)
|
| 354 |
-
print("justification_model_device_1:{}".format(justification_model.device))
|
| 355 |
|
| 356 |
return pred_justification.strip()
|
| 357 |
|
|
|
|
| 85 |
# Justification
|
| 86 |
justification_tokenizer = BartTokenizer.from_pretrained('facebook/bart-large', add_prefix_space=True)
|
| 87 |
bart_model = BartForConditionalGeneration.from_pretrained("facebook/bart-large")
|
| 88 |
+
best_checkpoint = os.getcwd() + '/averitec/pretrained_models/bart_justifications_verdict-epoch=13-val_loss=2.03-val_meteor=0.28.ckpt'
|
| 89 |
justification_model = JustificationGenerationModule.load_from_checkpoint(best_checkpoint, tokenizer=justification_tokenizer, model=bart_model).to('cuda')
|
| 90 |
# justification_model = JustificationGenerationModule.load_from_checkpoint(best_checkpoint, tokenizer=justification_tokenizer, model=bart_model).to(device)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 91 |
# ---------------------------------------------------------------------------
|
| 92 |
|
| 93 |
# ----------------------------------------------------------------------------
|
|
|
|
| 281 |
return pred_label
|
| 282 |
|
| 283 |
tokenized_strings, attention_mask = dataLoader.tokenize_strings(evidence_strings)
|
| 284 |
+
example_support = torch.argmax(veracity_model(tokenized_strings.to('cuda'), attention_mask=attention_mask.to('cuda')).logits, axis=1)
|
| 285 |
# example_support = torch.argmax(veracity_model(tokenized_strings.to(device), attention_mask=attention_mask.to(device)).logits, axis=1)
|
|
|
|
| 286 |
|
| 287 |
has_unanswerable = False
|
| 288 |
has_true = False
|
|
|
|
| 344 |
#
|
| 345 |
claim_str = extract_claim_str(claim, evidence, verdict_label)
|
| 346 |
claim_str.strip()
|
| 347 |
+
pred_justification = justification_model.generate(claim_str, device='cuda')
|
| 348 |
# pred_justification = justification_model.generate(claim_str, device=device)
|
|
|
|
| 349 |
|
| 350 |
return pred_justification.strip()
|
| 351 |
|