# NOTE: recovered from a Hugging Face Spaces page whose status banner read
# "Runtime error" — the streaming/vLLM bugs fixed below are the likely cause.
import logging
import os

import torch
from huggingface_hub import login
from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer
from vllm import LLM, SamplingParams

# Authenticate with the Hugging Face Hub at import time; required for gated
# checkpoints (e.g. meta-llama). Expects HF_TOKEN in the environment — login()
# will fail fast if the token is missing or invalid.
login(token=os.getenv('HF_TOKEN'))
class Model(torch.nn.Module):
    """Uniform wrapper over a vLLM engine or a transformers seq2seq model.

    Every supported checkpoint except the T5 model is served through vLLM;
    T5 is an encoder-decoder and is loaded with plain transformers instead.
    Both backends expose the same ``gen`` (batch) and ``streaming``
    (token-by-token generator) interface.
    """

    # Class-wide count of Model instances constructed so far.
    number_of_models = 0

    # Checkpoints this wrapper is known to work with.
    __model_list__ = [
        "Qwen/Qwen2-1.5B-Instruct",
        "lmsys/vicuna-7b-v1.5",
        "google-t5/t5-large",
        "mistralai/Mistral-7B-Instruct-v0.1",
        "meta-llama/Meta-Llama-3.1-8B-Instruct",
    ]

    def __init__(self, model_name="Qwen/Qwen2-1.5B-Instruct") -> None:
        """Load *model_name* with the appropriate backend.

        Args:
            model_name: Hub id of the checkpoint to load.
        """
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.name = model_name
        # T5 is the only encoder-decoder here; everything else goes through vLLM.
        self.use_vllm = model_name != "google-t5/t5-large"
        logging.info(f'Start loading model {self.name}')
        if self.use_vllm:
            # Load the model with vLLM (offline engine).
            self.llm = LLM(
                model=model_name,
                dtype="bfloat16",
                tokenizer=model_name,
                trust_remote_code=True,
            )
        else:
            # Load the original transformers seq2seq model.
            self.model = AutoModelForSeq2SeqLM.from_pretrained(
                model_name,
                torch_dtype=torch.bfloat16,
                device_map="auto",
            )
            self.model.eval()
        logging.info(f'Loaded model {self.name}')
        self.update()

    @classmethod
    def update(cls):
        """Increment the class-wide instance counter.

        BUGFIX: this was a plain method, so its ``cls`` parameter was actually
        the *instance* and ``cls.number_of_models += 1`` created a shadowing
        instance attribute instead of updating the shared class attribute.
        ``self.update()`` in ``__init__`` still works with a classmethod.
        """
        cls.number_of_models += 1

    def gen(self, content_list, temp=0.001, max_length=500, do_sample=True):
        """Generate one completion per prompt.

        Args:
            content_list: list of prompt strings.
            temp: sampling temperature.
            max_length: maximum number of new tokens.
            do_sample: if True, nucleus sampling (top_p=0.95 on vLLM).

        Returns:
            List of decoded completion strings, one per prompt.
        """
        if self.use_vllm:
            sampling_params = SamplingParams(
                temperature=temp,
                max_tokens=max_length,
                top_p=0.95 if do_sample else 1.0,
                stop_token_ids=[self.tokenizer.eos_token_id],
            )
            outputs = self.llm.generate(content_list, sampling_params)
            return [output.outputs[0].text for output in outputs]
        input_ids = self.tokenizer(
            content_list, return_tensors="pt", padding=True, truncation=True
        ).input_ids.to(self.model.device)
        outputs = self.model.generate(
            input_ids,
            max_new_tokens=max_length,
            do_sample=do_sample,
            temperature=temp,
            eos_token_id=self.tokenizer.eos_token_id,
        )
        # BUGFIX: seq2seq generate() returns decoder-only token ids — they are
        # NOT prefixed with the encoder input — so the old
        # ``outputs[:, input_ids.shape[1]:]`` slice discarded real output
        # tokens. Decode the full sequences instead.
        return self.tokenizer.batch_decode(outputs, skip_special_tokens=True)

    def streaming(self, content_list, temp=0.001, max_length=500, do_sample=True):
        """Yield ``(prompt_index, token_text)`` pairs as tokens become available.

        ``prompt_index`` refers to the position of the prompt in
        ``content_list``. Same sampling knobs as :meth:`gen`.
        """
        if self.use_vllm:
            sampling_params = SamplingParams(
                temperature=temp,
                max_tokens=max_length,
                top_p=0.95 if do_sample else 1.0,
                stop_token_ids=[self.tokenizer.eos_token_id],
            )
            # BUGFIX: the offline ``LLM.generate`` API has no ``stream=True``
            # keyword (streaming belongs to the async engine), so the old call
            # raised TypeError at runtime. Generate fully, then emit per token.
            # Also: the old loop enumerated ``output.outputs`` (candidate
            # completions of ONE request) as if it were the prompt index.
            outputs = self.llm.generate(content_list, sampling_params)
            for prompt_idx, request_output in enumerate(outputs):
                for token_id in request_output.outputs[0].token_ids:
                    yield prompt_idx, self.tokenizer.decode(
                        token_id, skip_special_tokens=True
                    )
        else:
            # NOTE(review): feeding previously generated tokens back in as
            # *encoder* input is questionable for an encoder-decoder (T5)
            # model — preserved from the original design; confirm intent.
            input_ids = self.tokenizer(
                content_list, return_tensors="pt", padding=True, truncation=True
            ).input_ids.to(self.model.device)
            eos_id = self.tokenizer.eos_token_id
            # Row r of the current batch corresponds to prompt active[r].
            # BUGFIX: the original kept *original* prompt indices in
            # ``active_sequences`` but used them both to index the shrinking
            # batch and to compare against ``completed`` (which holds *row*
            # positions) — mis-indexing after the first prompt finished.
            active = list(range(input_ids.shape[0]))
            current_ids = input_ids
            for _ in range(max_length):
                if not active:
                    break
                with torch.no_grad():
                    out = self.model.generate(
                        current_ids,
                        do_sample=do_sample,
                        temperature=temp,
                        eos_token_id=eos_id,
                        max_new_tokens=1,
                        return_dict_in_generate=True,
                        output_scores=True,
                    )
                next_tokens = out.sequences[:, -1].unsqueeze(-1)
                for prompt_idx, token in zip(active, next_tokens):
                    yield prompt_idx, self.tokenizer.decode(
                        token[0], skip_special_tokens=True
                    )
                current_ids = torch.cat([current_ids, next_tokens], dim=-1)
                # Drop rows that just emitted EOS; remap rows -> prompt ids.
                keep_rows = [
                    r for r, t in enumerate(next_tokens.squeeze(-1).tolist())
                    if t != eos_id
                ]
                active = [active[r] for r in keep_rows]
                current_ids = current_ids[keep_rows] if keep_rows else current_ids[:0]