import error corrected
Browse files- model_functions.py +2 -1
model_functions.py
CHANGED
|
@@ -1,12 +1,13 @@
|
|
| 1 |
import torch
|
| 2 |
from torch.nn import CrossEntropyLoss, MSELoss
|
| 3 |
import re
|
|
|
|
| 4 |
#!pip install rouge_score
|
| 5 |
|
| 6 |
#from rouge_score import rouge_scorer
|
| 7 |
|
| 8 |
from nltk.translate.meteor_score import meteor_score
|
| 9 |
-
|
| 10 |
|
| 11 |
def forward_batch(images, input_ids, attention_mask, answers, question_classes=None,qtype_classifier=None,fusion_module=None,q_types=None,q_types_mapping=None,task_heads=None,device=None,image_encoder=None,question_encoder=None):
|
| 12 |
# Image encoding
|
|
|
|
| 1 |
import torch
|
| 2 |
from torch.nn import CrossEntropyLoss, MSELoss
|
| 3 |
import re
|
| 4 |
+
from models import disease_model
|
| 5 |
#!pip install rouge_score
|
| 6 |
|
| 7 |
#from rouge_score import rouge_scorer
|
| 8 |
|
| 9 |
from nltk.translate.meteor_score import meteor_score
|
| 10 |
+
|
| 11 |
|
| 12 |
def forward_batch(images, input_ids, attention_mask, answers, question_classes=None,qtype_classifier=None,fusion_module=None,q_types=None,q_types_mapping=None,task_heads=None,device=None,image_encoder=None,question_encoder=None):
|
| 13 |
# Image encoding
|