ShobhitKori committed on
Commit
17fdc74
·
1 Parent(s): 4df0d17

Updated codet5_model.py

Browse files
Files changed (2) hide show
  1. codet5_model.py +11 -3
  2. whisper_model.py +5 -4
codet5_model.py CHANGED
@@ -34,8 +34,11 @@
34
  # output = model.generate(**inputs, max_length=256)
35
  # return tokenizer.decode(output[0], skip_special_tokens=True)
36
 
 
37
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
38
 
 
 
39
  # Global variables but not initialized
40
  tokenizer = None
41
  model = None
@@ -43,9 +46,14 @@ model = None
43
  def load_model():
44
  global tokenizer, model
45
  if tokenizer is None or model is None:
46
- print("Loading CodeT5 model...")
47
- tokenizer = AutoTokenizer.from_pretrained("Salesforce/codet5p-770m-py")
48
- model = AutoModelForSeq2SeqLM.from_pretrained("Salesforce/codet5p-770m-py")
 
 
 
 
 
49
  print("Model loaded.")
50
 
51
  def generate_code(instruction: str) -> str:
 
# output = model.generate(**inputs, max_length=256)
# return tokenizer.decode(output[0], skip_special_tokens=True)

import os

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Persistent cache location for the CodeT5 checkpoint (survives container restarts).
MODEL_DIR = "/data/codet5-model"

# Lazily-initialized globals; populated by load_model() on first use so that
# importing this module does not trigger a multi-GB download.
tokenizer = None
model = None
 
46
def load_model():
    """Lazily load the CodeT5 tokenizer and model into the module globals.

    Idempotent: does nothing if both globals are already populated. The
    Hugging Face cache is pointed at MODEL_DIR so the weights land on
    persistent storage the first time and are reused on later runs.
    """
    global tokenizer, model
    if tokenizer is None or model is None:
        # The directory check only decides which message to print; the actual
        # download-vs-reuse decision is made by the Hugging Face cache itself.
        if not os.path.exists(MODEL_DIR):
            print("Downloading CodeT5 model to persistent /data directory...")
        else:
            print("Loading model from /data directory...")
        # FIX: always load by model id with cache_dir. Calling
        # from_pretrained(MODEL_DIR) directly would fail on the second run:
        # cache_dir stores files in the hub cache layout
        # (MODEL_DIR/models--Salesforce--codet5p-770m-py/snapshots/...),
        # so MODEL_DIR itself contains no config.json. With cache_dir set,
        # from_pretrained reuses the cached snapshot without re-downloading.
        tokenizer = AutoTokenizer.from_pretrained(
            "Salesforce/codet5p-770m-py", cache_dir=MODEL_DIR
        )
        model = AutoModelForSeq2SeqLM.from_pretrained(
            "Salesforce/codet5p-770m-py", cache_dir=MODEL_DIR
        )
    print("Model loaded.")
58
 
59
  def generate_code(instruction: str) -> str:
whisper_model.py CHANGED
@@ -1,12 +1,13 @@
1
  import os
 
 
2
 
3
  # Add FFmpeg directory to PATH
4
- os.environ["PATH"] += os.pathsep + r"C:\ffmpeg\bin"
5
 
6
- import whisper
7
- import logging
8
 
9
- model = whisper.load_model("small")
10
 
11
  # def transcribe_audio(file_path: str) -> str:
12
  # result = model.transcribe(file_path)
 
import os
import whisper
import logging

# FFmpeg directory on PATH — Windows-only workaround, disabled for this
# (Linux) deployment:
# os.environ["PATH"] += os.pathsep + r"C:\ffmpeg\bin"

# Persistent location for the downloaded Whisper checkpoint.
WHISPER_MODEL_DIR = "/data/whisper-small"

# Load the model once at import time; download_root keeps the weights on
# persistent storage so later starts skip the download.
model = whisper.load_model("small", download_root=WHISPER_MODEL_DIR)
11
 
12
  # def transcribe_audio(file_path: str) -> str:
13
  # result = model.transcribe(file_path)