Evan-Lin committed on
Commit
6d60d62
·
1 Parent(s): 15306ce

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +10 -5
handler.py CHANGED
@@ -16,6 +16,7 @@ class EndpointHandler():
16
  # self.pipeline = pipeline(task= "automatic-speech-recognition", model=path)
17
  # self.pipeline.model.config.forced_decoder_ids = self.pipeline.tokenizer.get_decoder_prompt_ids(language="Chinese", task="transcribe")
18
  # self.pipeline.model.generation_config.forced_decoder_ids = self.pipeline.model.config.forced_decoder_ids # just to be sure!
 
19
  peft_model_id = "cathyi/openai-whisper-large-v2-Lora"
20
  language = "Chinese"
21
  task = "transcribe"
@@ -23,13 +24,13 @@ class EndpointHandler():
23
  model = WhisperForConditionalGeneration.from_pretrained(
24
  peft_config.base_model_name_or_path
25
  )
26
- model = PeftModel.from_pretrained(model, peft_model_id)
27
  tokenizer = WhisperTokenizer.from_pretrained(peft_config.base_model_name_or_path, language=language, task=task)
28
  processor = WhisperProcessor.from_pretrained(peft_config.base_model_name_or_path, language=language, task=task)
29
  feature_extractor = processor.feature_extractor
30
  self.forced_decoder_ids = processor.get_decoder_prompt_ids(language=language, task=task)
31
- self.pipeline = AutomaticSpeechRecognitionPipeline(model=model, tokenizer=tokenizer, feature_extractor=feature_extractor)
32
- # self.pipeline = pipeline(task= "automatic-speech-recognition", model=path, tokenizer=tokenizer, feature_extractor = feature_extractor)
33
 
34
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
35
  """
@@ -46,8 +47,12 @@ class EndpointHandler():
46
  print("a1", inputs)
47
  print("a2", inputs, file=sys.stderr)
48
  print("a3", inputs, file=sys.stdout)
49
- prediction = self.pipeline(inputs, generate_kwargs={"forced_decoder_ids": self.forced_decoder_ids}, max_new_tokens=255)["text"]
50
- # prediction = self.pipeline(inputs, return_timestamps=False)
 
 
 
 
51
  print("b1", prediction)
52
  print("b2", predcition, file=sys.stderr)
53
  print("b3", predcition, file=sys.stdout)
 
16
  # self.pipeline = pipeline(task= "automatic-speech-recognition", model=path)
17
  # self.pipeline.model.config.forced_decoder_ids = self.pipeline.tokenizer.get_decoder_prompt_ids(language="Chinese", task="transcribe")
18
  # self.pipeline.model.generation_config.forced_decoder_ids = self.pipeline.model.config.forced_decoder_ids # just to be sure!
19
+
20
  peft_model_id = "cathyi/openai-whisper-large-v2-Lora"
21
  language = "Chinese"
22
  task = "transcribe"
 
24
  model = WhisperForConditionalGeneration.from_pretrained(
25
  peft_config.base_model_name_or_path
26
  )
27
+ model = PeftModel.from_pretrained(model, path)
28
  tokenizer = WhisperTokenizer.from_pretrained(peft_config.base_model_name_or_path, language=language, task=task)
29
  processor = WhisperProcessor.from_pretrained(peft_config.base_model_name_or_path, language=language, task=task)
30
  feature_extractor = processor.feature_extractor
31
  self.forced_decoder_ids = processor.get_decoder_prompt_ids(language=language, task=task)
32
+ # self.pipeline = AutomaticSpeechRecognitionPipeline(model=model, tokenizer=tokenizer, feature_extractor=feature_extractor)
33
+ self.pipeline = pipeline(task= "automatic-speech-recognition", model=model, tokenizer=tokenizer, feature_extractor = feature_extractor)
34
 
35
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
36
  """
 
47
  print("a1", inputs)
48
  print("a2", inputs, file=sys.stderr)
49
  print("a3", inputs, file=sys.stdout)
50
+ try:
51
+ # prediction = self.pipeline(inputs, generate_kwargs={"forced_decoder_ids": self.forced_decoder_ids}, max_new_tokens=255)["text"]
52
+ prediction = self.pipeline(inputs, return_timestamps=False)
53
+ except :
54
+ print("error")
55
+
56
  print("b1", prediction)
57
  print("b2", predcition, file=sys.stderr)
58
  print("b3", predcition, file=sys.stdout)