Rishabh2234 commited on
Commit
a846416
·
1 Parent(s): 72ca0c5

files for inference generation

Browse files
Files changed (2) hide show
  1. app.py +1 -0
  2. model.py +11 -0
app.py CHANGED
@@ -4,6 +4,7 @@ from PIL import Image
4
  import torchvision.transforms as transforms
5
  import torch
6
  from model import load_model
 
7
 
8
  # Initialize FastAPI
9
  app = FastAPI()
 
4
  import torchvision.transforms as transforms
5
  import torch
6
  from model import load_model
7
+ import os
8
 
9
  # Initialize FastAPI
10
  app = FastAPI()
model.py CHANGED
@@ -1,10 +1,21 @@
1
  import os
 
 
 
 
 
 
 
 
 
2
  import torch
3
  import torch.nn as nn
4
  from transformers import ViTModel, T5ForConditionalGeneration, T5Tokenizer
5
  from transformers.modeling_outputs import BaseModelOutput
6
  import requests
7
 
 
 
8
  class ViTT5(nn.Module):
9
  def __init__(self, vit_encoder, t5_decoder):
10
  super(ViTT5, self).__init__()
 
1
  import os
2
+
3
+
4
+
5
+ # Set cache directory to a writable location within your app's directory
6
+ os.environ["TRANSFORMERS_CACHE"] = "./cache"
7
+ os.makedirs("./cache", exist_ok=True)
8
+
9
+
10
+
11
  import torch
12
  import torch.nn as nn
13
  from transformers import ViTModel, T5ForConditionalGeneration, T5Tokenizer
14
  from transformers.modeling_outputs import BaseModelOutput
15
  import requests
16
 
17
+
18
+
19
  class ViTT5(nn.Module):
20
  def __init__(self, vit_encoder, t5_decoder):
21
  super(ViTT5, self).__init__()