Evan-Lin committed on
Commit
d23f124
·
1 Parent(s): f834f45

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +31 -31
handler.py CHANGED
@@ -11,35 +11,35 @@ WhisperProcessor
11
  from peft import LoraConfig, PeftModel, LoraModel, LoraConfig, get_peft_model, PeftConfig
12
 
13
  class EndpointHandler():
14
- def __init__(self, path=""):
15
- # self.pipeline = pipeline(task= "automatic-speech-recognition", model=path)
16
- # self.pipeline.model.config.forced_decoder_ids = self.pipeline.tokenizer.get_decoder_prompt_ids(language="Chinese", task="transcribe")
17
- # self.pipeline.model.generation_config.forced_decoder_ids = self.pipeline.model.config.forced_decoder_ids # just to be sure!
18
- peft_model_id = "cathyi/openai-whisper-large-v2-Lora"
19
- language = "Chinese"
20
- task = "transcribe"
21
- peft_config = PeftConfig.from_pretrained(peft_model_id)
22
- model = WhisperForConditionalGeneration.from_pretrained(
23
- peft_config.base_model_name_or_path
24
- )
25
- model = PeftModel.from_pretrained(model, peft_model_id)
26
- tokenizer = WhisperTokenizer.from_pretrained(peft_config.base_model_name_or_path, language=language, task=task)
27
- processor = WhisperProcessor.from_pretrained(peft_config.base_model_name_or_path, language=language, task=task)
28
- feature_extractor = processor.feature_extractor
29
- self.forced_decoder_ids = processor.get_decoder_prompt_ids(language=language, task=task)
30
- pipeline = AutomaticSpeechRecognitionPipeline(model=model, tokenizer=tokenizer, feature_extractor=feature_extractor)
31
-
32
- def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
33
- """
34
- data args:
35
- inputs (:obj: `str`)
36
- date (:obj: `str`)
37
- Return:
38
- A :obj:`list` | `dict`: will be serialized and returned
39
- """
40
- # get inputs
41
 
42
- # run normal prediction
43
- inputs = data.pop("inputs", data)
44
- prediction = self.pipeline(inputs, generate_kwargs={"forced_decoder_ids": self.forced_decoder_ids}, max_new_tokens=255)["text"]
45
- return prediction
 
 
 
 
 
 
 
 
 
 
 
11
  from peft import LoraConfig, PeftModel, LoraModel, LoraConfig, get_peft_model, PeftConfig
12
 
13
  class EndpointHandler():
14
+ def __init__(self, path=""):
15
+ # self.pipeline = pipeline(task= "automatic-speech-recognition", model=path)
16
+ # self.pipeline.model.config.forced_decoder_ids = self.pipeline.tokenizer.get_decoder_prompt_ids(language="Chinese", task="transcribe")
17
+ # self.pipeline.model.generation_config.forced_decoder_ids = self.pipeline.model.config.forced_decoder_ids # just to be sure!
18
+ peft_model_id = "cathyi/openai-whisper-large-v2-Lora"
19
+ language = "Chinese"
20
+ task = "transcribe"
21
+ peft_config = PeftConfig.from_pretrained(path)
22
+ model = WhisperForConditionalGeneration.from_pretrained(
23
+ peft_config.base_model_name_or_path
24
+ )
25
+ model = PeftModel.from_pretrained(model, peft_model_id)
26
+ tokenizer = WhisperTokenizer.from_pretrained(peft_config.base_model_name_or_path, language=language, task=task)
27
+ processor = WhisperProcessor.from_pretrained(peft_config.base_model_name_or_path, language=language, task=task)
28
+ feature_extractor = processor.feature_extractor
29
+ self.forced_decoder_ids = processor.get_decoder_prompt_ids(language=language, task=task)
30
+ pipeline = AutomaticSpeechRecognitionPipeline(model=model, tokenizer=tokenizer, feature_extractor=feature_extractor)
 
 
 
 
 
 
 
 
 
 
31
 
32
+ def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
33
+ """
34
+ data args:
35
+ inputs (:obj: `str`)
36
+ date (:obj: `str`)
37
+ Return:
38
+ A :obj:`list` | `dict`: will be serialized and returned
39
+ """
40
+ # get inputs
41
+
42
+ # run normal prediction
43
+ inputs = data.pop("inputs", data)
44
+ prediction = self.pipeline(inputs, generate_kwargs={"forced_decoder_ids": self.forced_decoder_ids}, max_new_tokens=255)["text"]
45
+ return prediction