# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Chameleon License found in the
# LICENSE file in the root directory of this source tree.

import io
import json
from typing import AsyncGenerator, Generator

import PIL.Image
import torch
import transformers
from tokenizers import Tokenizer
from transformers import (
    MaxLengthCriteria,
    RepetitionPenaltyLogitsProcessor,
    TemperatureLogitsWarper,
    TopPLogitsWarper,
)

from chameleon.inference.alignment import AlignPromptRight
from chameleon.inference.generation import ChameleonGenerator
from chameleon.inference.image_tokenizer import ImageTokenizer
from chameleon.inference.loader import load_model
from chameleon.inference.logits_processor import (
    AllowOnlyTokensAfterIndexLogitsProcessor,
    AllowOnlyTokensLogitsProcessor,
    InBatchInstructCFGLogitsProcessor,
)
from chameleon.inference.model_adapter import ChameleonModelAdapter
from chameleon.inference.stopping_criteria import StopOnEOS, StopOnEOSAfterBatchIndex
from chameleon.inference.token_selector import (
    MultinomialTokenSelector,
    ReplicatedInputTokenSelector,
)
from chameleon.inference.vocab import VocabInfo, VocabTranslation
from chameleon.viewer.backend.models.abstract_model import (
    DEFAULT_IMAGE_CFG_IMAGE,
    DEFAULT_IMAGE_CFG_TEXT,
    DEFAULT_MULTIMODAL_CFG_IMAGE,
    DEFAULT_MULTIMODAL_CFG_TEXT,
    AbstractMultimodalGenerator,
    MixedSequenceType,
    StreamingImage,
)
from chameleon.viewer.backend.utils import get_logger

logger = get_logger(__name__)


def set_seed(seed: int) -> None:
    transformers.enable_full_determinism(seed, warn_only=True)


def get_rank() -> int:
    if torch.distributed.is_initialized():
        return torch.distributed.get_rank()
    else:
        return 0
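

# Text and images share one token stream: an image is bracketed by the
# begin_image/end_image sentinel tokens and always occupies exactly 1024
# VQGAN codebook tokens, while text uses ordinary BPE tokens. This mixin
# translates between that unified BPE space and PIL images.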
class ChameleonTokenizationMixin:
    def png_from_bpe_tokens(self, bpe_tokens: torch.Tensor) -> bytes:
        img = self.pillow_from_bpe_tokens(bpe_tokens)
        img_io = io.BytesIO()
        img.save(img_io, format="PNG")
        return img_io.getvalue()

    def pillow_from_bpe_tokens(self, bpe_tokens: torch.Tensor) -> PIL.Image.Image:
        image_tensor = VocabTranslation(self.vocab).convert_bpe2img(bpe_tokens)
        if image_tensor.shape[0] < 1024:
            # Pad partial images (mid-stream previews) out to the full
            # 1024-token grid by repeating the first token, so the VQGAN
            # decoder always receives a complete image.
            padding = (
                torch.ones([1024 - image_tensor.shape[0]], dtype=int) * image_tensor[0]
            )
            image_tensor = torch.cat((image_tensor, padding)).unsqueeze(0)
        return self.image_tokenizer.pil_from_img_toks(image_tensor)

    def tokens_from_inputs(
        self,
        inputs: MixedSequenceType,
        suffix_tokens: list[str] | None = None,
    ) -> list[int]:
        tokens = [self.vocab.bos_id]
        for input_ in inputs:
            if isinstance(input_, str):
                tokens.extend(self.tokenizer.encode(input_.strip()).ids)
            elif isinstance(input_, PIL.Image.Image):
                tokens.append(self.vocab.begin_image)
                imgtoks = self.image_tokenizer.img_tokens_from_pil(input_)
                tokens.extend(VocabTranslation(self.vocab).convert_img2bp2(imgtoks))
                tokens.append(self.vocab.end_image)
            else:
                raise ValueError(f"Unknown input type: {type(input_)}")
        if suffix_tokens is not None:
            for t in suffix_tokens:
                tokens.extend(self.tokenizer.encode(t).ids)
        # Image tokens come back as 0-dim tensors; normalize everything to ints.
        sanitized_tokens = []
        for t in tokens:
            if isinstance(t, torch.Tensor):
                sanitized_tokens.append(t.item())
            else:
                sanitized_tokens.append(t)
        return sanitized_tokens


class GeneratorWrapper:
    def __init__(self, gen):
        self.gen = gen

    def __iter__(self):
        return self

    def __next__(self):
        return next(self.gen)
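

# Decoders form a small state machine over the shared token stream. Each
# Decoder.__next__ yields a 5-tuple of (message_type, tokens to append to the
# running prompt, tokens for the frontend, is_final, next_decoder_class); a
# non-None next_decoder_class tells the caller to construct that decoder and
# continue, which is what the GeneratorWrapper.gen swap in
# _generate_multimodal_streaming enables mid-iteration.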
class Decoder:
    def __init__(
        self,
        chameleon_generator: "ChameleonLocalGenerator",
        input_ids: list[int],
    ):
        ...

    def __next__(
        self,
    ) -> tuple[str, list[int], list[int], bool, type["Decoder"] | None]:
        ...
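

# Samples text tokens one at a time. Generation stops at EOS (or any
# configured additional EOS token) and, in this version, also stops when the
# model samples begin_image rather than handing off to ImageDecoder.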
class TextDecoder(Decoder):
    def __init__(
        self,
        chameleon_generator: "ChameleonLocalGenerator",
        input_ids: list[int],
        *,
        temp: float,
        top_p: float,
        max_seq_len: int,
        # TODO: Propagate setting upwards
        repetition_penalty: float,
        **kwargs,
    ):
        self.chameleon_generator = chameleon_generator
        assert chameleon_generator.vocab.eos_id is not None
        stopping_criteria = [
            StopOnEOS(chameleon_generator.vocab.eos_id),
            MaxLengthCriteria(max_seq_len),
        ]
        if chameleon_generator.additional_eos_tokens is not None:
            for token in chameleon_generator.additional_eos_tokens:
                stopping_criteria.append(
                    StopOnEOSAfterBatchIndex(
                        chameleon_generator.tokenizer.token_to_id(token),
                        [len(input_ids)],
                    )
                )
        logits_processors = [
            AllowOnlyTokensLogitsProcessor(
                chameleon_generator.vocab.text_tokens
                + [chameleon_generator.vocab.eos_id, chameleon_generator.vocab.begin_image]
            ),
            # Don't allow any more images near the end since there isn't enough room.
            AllowOnlyTokensAfterIndexLogitsProcessor(
                chameleon_generator.vocab.text_tokens + [chameleon_generator.vocab.eos_id],
                # TODO: Calculate exact
                1024 * 3 - 3,
            ),
            RepetitionPenaltyLogitsProcessor(repetition_penalty),
            TemperatureLogitsWarper(temp),
            TopPLogitsWarper(top_p),
        ]
        self.gen = ChameleonGenerator(
            model=ChameleonModelAdapter(
                chameleon_generator.model, max_seq_len=max_seq_len
            ),
            input_ids=[input_ids],
            stopping_criteria=stopping_criteria,
            logits_processors=logits_processors,
        )
        # The generator first re-emits the prompt tokens; consume them so that
        # iteration starts at the first newly generated token.
        for _ in range(len(input_ids)):
            next(self.gen)

    def __next__(
        self,
    ) -> tuple[str, list[int], list[int], bool, type[Decoder] | None]:
        gpu_tok = next(self.gen).id.item()
        cpu_tok = gpu_tok
        if cpu_tok == self.chameleon_generator.vocab.begin_image:
            # Handing off to ImageDecoder is disabled here; stop instead.
            # return "TEXT", [cpu_tok], [], False, ImageDecoder
            raise StopIteration()
        return (
            "TEXT",
            [cpu_tok],
            [cpu_tok],
            False,
            None,
        )
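

# Generates exactly 1024 image tokens under classifier-free guidance, emitting
# a partial image every yield_every_n tokens and switching back to TextDecoder
# once the image is complete.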
class ImageDecoder(Decoder):
    def __init__(
        self,
        chameleon_generator: "ChameleonLocalGenerator",
        input_ids: list[int],
        *,
        cfg_image_weight: float,
        cfg_text_weight: float,
        temp: float,
        top_p: float,
        yield_every_n: int,
        **kwargs,
    ):
        self.yield_every_n = yield_every_n
        self.chameleon_generator = chameleon_generator
        logits_processors = [
            InBatchInstructCFGLogitsProcessor(cfg_text_weight, cfg_image_weight),
            AllowOnlyTokensLogitsProcessor(chameleon_generator.vocab.image_tokens),
            TemperatureLogitsWarper(temp),
            TopPLogitsWarper(top_p),
        ]
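
        # Classifier-free guidance runs three prompt variants in one batch:
        # the full prompt, an image-only conditioning (image tokens and
        # sentinels, with text stripped), and an unconditioned prompt. The
        # InBatchInstructCFGLogitsProcessor above blends their logits using
        # cfg_text_weight and cfg_image_weight.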
        image_conditioned_allowed = set(chameleon_generator.vocab.image_tokens) | {
            chameleon_generator.vocab.bos_id,
            chameleon_generator.vocab.begin_image,
            chameleon_generator.vocab.end_image,
        }
        full_conditioned = input_ids
        image_conditioned = [
            in_id for in_id in input_ids if in_id in image_conditioned_allowed
        ]
        unconditioned = [
            chameleon_generator.vocab.bos_id,
            chameleon_generator.vocab.begin_image,
        ]
        self.gen = ChameleonGenerator(
            model=ChameleonModelAdapter(
                chameleon_generator.model, max_seq_len=len(input_ids) + 1024
            ),
            input_ids=[full_conditioned, image_conditioned, unconditioned],
            logits_processors=logits_processors,
            alignment=AlignPromptRight(chameleon_generator.vocab.pad_id),
            token_selector=ReplicatedInputTokenSelector(
                MultinomialTokenSelector(), n=3
            ),
        )
        # Consume the generator's echo of the prompt tokens.
        for _ in range(len(input_ids)):
            next(self.gen)
        self.image_builder: list[torch.LongTensor] = []
        self.gpu_tok_batch: list[torch.LongTensor] = []
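
    # Accumulate one image token per step; flush a partial image every
    # yield_every_n tokens, and after exactly 1024 tokens return the finished
    # image (plus end_image) and hand back to TextDecoder.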
    def __next__(
        self,
    ) -> tuple[str, list[int], list[int], bool, type[Decoder] | None]:
        while True:
            # The token selector replicates one sampled token across the
            # three CFG streams, so keep only the first chunk.
            gpu_tok = next(self.gen).id
            gpu_tok = torch.chunk(gpu_tok, chunks=3, dim=0)[0]
            self.image_builder.append(gpu_tok)
            self.gpu_tok_batch.append(gpu_tok)
            if len(self.image_builder) == 1024:
                return (
                    "IMAGE",
                    torch.tensor(self.gpu_tok_batch).tolist()
                    + [self.chameleon_generator.vocab.end_image],
                    torch.tensor(self.image_builder).tolist(),
                    True,
                    TextDecoder,
                )
            elif len(self.image_builder) % self.yield_every_n == 0:
                cpu_toks = torch.tensor(self.gpu_tok_batch).tolist()
                self.gpu_tok_batch = []
                return (
                    "IMAGE",
                    cpu_toks,
                    torch.tensor(self.image_builder).tolist(),
                    False,
                    None,
                )
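

# Low-level streaming loops shared by the public generate_* methods below;
# they assume self.model, self.vocab, self.tokenizer, and
# self.additional_eos_tokens are provided by the concrete generator class.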
class ChameleonForwardMixin:
    def _generate_text_streaming(
        self,
        input_ids: list[int],
        max_gen_tokens: int = 256,
        temp: float = 1.0,
        top_p: float = 0.8,
        repetition_penalty: float = 1.2,
        seed: int | None = None,
    ) -> Generator[list[int], None, None]:
        if seed is not None:
            set_seed(seed)
            logger.info("Rank: %s, set seed: %s", get_rank(), seed)
        logits_processors = [
            # Only allow text tokens and end-of-sequence.
            AllowOnlyTokensLogitsProcessor(
                self.vocab.text_tokens + [self.vocab.eos_id]
            ),
            # Don't allow the first token to be end-of-sequence.
            # DisallowTokensAtIndexLogitProcessor([self.vocab.eos_id], len()),
            RepetitionPenaltyLogitsProcessor(repetition_penalty),
            TemperatureLogitsWarper(temp),
            TopPLogitsWarper(top_p),
        ]
        stopping_criteria = [
            StopOnEOS(self.vocab.eos_id),
            MaxLengthCriteria(len(input_ids) + max_gen_tokens),
        ]
        if self.additional_eos_tokens is not None:
            for token in self.additional_eos_tokens:
                stopping_criteria.append(
                    StopOnEOSAfterBatchIndex(
                        self.tokenizer.token_to_id(token), [len(input_ids)]
                    )
                )
        for tok in ChameleonGenerator(
            model=ChameleonModelAdapter(
                self.model,
                max_seq_len=len(input_ids) + max_gen_tokens,
            ),
            input_ids=[input_ids],
            stopping_criteria=stopping_criteria,
            logits_processors=logits_processors,
        ):
            yield tok.id.tolist()

    def _generate_batched_text_streaming(
        self,
        batch: list[list[int]],
        max_gen_tokens: int = 256,
        temp: float = 1.0,
        top_p: float = 0.8,
        repetition_penalty: float = 1.2,
        seed: int | None = None,
    ) -> Generator[list[list[int]], None, None]:
        if seed is not None:
            set_seed(seed)
        logits_processors = [
            # Only allow text tokens and end-of-sequence.
            AllowOnlyTokensLogitsProcessor(
                self.vocab.text_tokens + [self.vocab.eos_id]
            ),
            # Don't allow the first token to be end-of-sequence.
            # DisallowTokensAtIndexLogitProcessor([self.vocab.eos_id], len()),
            RepetitionPenaltyLogitsProcessor(repetition_penalty),
            TemperatureLogitsWarper(temp),
            TopPLogitsWarper(top_p),
        ]
        max_prompt_len = max(len(p) for p in batch)
        stopping_criteria = [
            StopOnEOS(self.vocab.eos_id),
            MaxLengthCriteria(max_prompt_len + max_gen_tokens),
        ]
        if self.additional_eos_tokens is not None:
            for token in self.additional_eos_tokens:
                stopping_criteria.append(
                    StopOnEOSAfterBatchIndex(
                        self.tokenizer.token_to_id(token), [len(x) for x in batch]
                    )
                )
        for tok in ChameleonGenerator(
            model=ChameleonModelAdapter(
                self.model,
                max_seq_len=max_prompt_len + max_gen_tokens,
            ),
            input_ids=batch,
            stopping_criteria=stopping_criteria,
            logits_processors=logits_processors,
        ):
            yield tok.id.unsqueeze(1).tolist()

    def _generate_image_streaming(
        self,
        tokenized_prompt: list[int],
        temp: float = 1.0,
        top_p: float = 0.8,
        cfg_image_weight: float = DEFAULT_IMAGE_CFG_IMAGE,
        cfg_text_weight: float = DEFAULT_IMAGE_CFG_TEXT,
        yield_every_n: int = 32,
        seed: int | None = None,
    ) -> Generator[tuple[list[int], bool], None, None]:
        if seed is not None:
            set_seed(seed)
            logger.info("Rank: %s, set seed: %s", get_rank(), seed)
        decoder = ImageDecoder(
            self,
            tokenized_prompt,
            cfg_image_weight=cfg_image_weight,
            cfg_text_weight=cfg_text_weight,
            temp=temp,
            top_p=top_p,
            yield_every_n=yield_every_n,
        )
        for _, _, frontend_tokens, is_final, next_decoder in GeneratorWrapper(decoder):
            yield torch.tensor(frontend_tokens).tolist(), is_final
            if next_decoder is not None:
                # The decoder handoff marks the completed image; stop here.
                break
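
    # Multimodal streaming starts in TextDecoder and follows decoder handoffs:
    # each step appends the accepted tokens to input_ids, forwards any
    # frontend tokens, and rebuilds the active decoder whenever a handoff
    # class is returned.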
    def _generate_multimodal_streaming(
        self,
        input_ids: list[int],
        temp: float = 1.0,
        top_p: float = 0.8,
        cfg_image_weight: float = DEFAULT_MULTIMODAL_CFG_IMAGE,
        cfg_text_weight: float = DEFAULT_MULTIMODAL_CFG_TEXT,
        yield_every_n: int = 32,
        max_gen_tokens: int = 4096,
        repetition_penalty: float = 1.2,
        seed: int | None = None,
    ) -> Generator[tuple[str, list[int], bool], None, None]:
        if seed is not None:
            set_seed(seed)
            logger.info("Rank: %s, set seed: %s", get_rank(), seed)
        max_seq_len = min(len(input_ids) + max_gen_tokens, 4096)
        gen_wrapper = GeneratorWrapper(
            TextDecoder(
                self,
                input_ids,
                temp=temp,
                top_p=top_p,
                max_seq_len=max_seq_len,
                repetition_penalty=repetition_penalty,
            )
        )
        for (
            message_type,
            cpu_toks,
            frontend_tokens,
            is_final,
            next_decoder,
        ) in gen_wrapper:
            # Keep the running context up to date so that a newly constructed
            # decoder resumes from the complete sequence generated so far.
            input_ids.extend(cpu_toks)
            if len(frontend_tokens) > 0:
                yield message_type, frontend_tokens, is_final
            if next_decoder is not None:
                # Swap the underlying decoder (text <-> image) mid-iteration.
                gen_wrapper.gen = next_decoder(
                    self,
                    input_ids,
                    temp=temp,
                    top_p=top_p,
                    max_seq_len=max_seq_len,
                    cfg_image_weight=cfg_image_weight,
                    cfg_text_weight=cfg_text_weight,
                    yield_every_n=yield_every_n,
                    repetition_penalty=repetition_penalty,
                )
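

# Concrete generator wiring together the transformer checkpoint, the BPE
# tokenizer, and the VQGAN image tokenizer behind the viewer's abstract API.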
class ChameleonLocalGenerator(
    AbstractMultimodalGenerator, ChameleonForwardMixin, ChameleonTokenizationMixin
):
    def __init__(
        self,
        model_path: str,
        tokenizer_path: str,
        vqgan_config_path: str,
        vqgan_ckpt_path: str | None = None,
        additional_eos_tokens: list[str] | None = None,
    ) -> None:
        super().__init__()
        logger.info("Loading model...")
        self.model = load_model(model_path)
        self.additional_eos_tokens = additional_eos_tokens
        logger.info("Loading tokenizer...")
        self.tokenizer = Tokenizer.from_file(str(tokenizer_path))
        with open(tokenizer_path) as f:
            self.vocab = VocabInfo(json.load(f)["model"]["vocab"])
        logger.info("Loading VQGAN...")
        self.image_tokenizer = ImageTokenizer(vqgan_config_path, vqgan_ckpt_path)

    def generate_batched_text(
        self,
        prompts: list[MixedSequenceType],
        max_gen_tokens: int = 256,
        temp: float = 1.0,
        top_p: float = 0.8,
        repetition_penalty: float = 1.2,
        seed: int | None = None,
    ) -> list[str]:
        outputs = [""] * len(prompts)
        for vals in self.generate_batched_text_streaming(
            prompts,
            max_gen_tokens=max_gen_tokens,
            temp=temp,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            seed=seed,
        ):
            for idx, val in enumerate(vals):
                outputs[idx] += val
        return outputs

    def generate_batched_text_streaming(
        self,
        prompts: list[MixedSequenceType],
        max_gen_tokens: int = 256,
        temp: float = 1.0,
        top_p: float = 0.8,
        repetition_penalty: float = 1.2,
        seed: int | None = None,
    ) -> Generator[list[str], None, None]:
        batch = [self.tokens_from_inputs(prompt) for prompt in prompts]
        for tok in self._generate_batched_text_streaming(
            batch,
            max_gen_tokens=max_gen_tokens,
            temp=temp,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            seed=seed,
        ):
            yield self.tokenizer.decode_batch(tok)

    async def generate_text_streaming(
        self,
        prompt: MixedSequenceType,
        max_gen_tokens: int = 256,
        temp: float = 1.0,
        top_p: float = 0.8,
        repetition_penalty: float = 1.2,
        seed: int | None = None,
        debug: dict | None = None,
    ) -> AsyncGenerator[str, None]:
        tokenized_prompt = self.tokens_from_inputs(prompt)
        if len(tokenized_prompt) > (4096 - 3):
            yield (
                "ERROR: Your input exceeds the model's context length of 4096."
                " Note that images consume 1024 tokens whether in input or output."
            )
            return
        for out in self.generate_batched_text_streaming(
            [prompt],
            max_gen_tokens=max_gen_tokens,
            temp=temp,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            seed=seed,
        ):
            yield out[0]

    async def generate_image_streaming(
        self,
        prompt: MixedSequenceType,
        temp: float = 1.0,
        top_p: float = 0.8,
        cfg_image_weight: float = DEFAULT_IMAGE_CFG_IMAGE,
        cfg_text_weight: float = DEFAULT_IMAGE_CFG_TEXT,
        yield_every_n: int = 32,
        seed: int | None = None,
        debug: dict | None = None,
    ) -> AsyncGenerator[StreamingImage | str, None]:
        assert isinstance(prompt, list)
        tokenized_prompt = self.tokens_from_inputs(prompt)
        tokenized_prompt.append(self.vocab.begin_image)
        # Reserve 1024 tokens of the context for the generated image itself.
        if len(tokenized_prompt) > (4096 - 3 - 1024):
            yield (
                "ERROR: Your input exceeds the model's context length of 4096."
                " Note that images consume 1024 tokens whether in input or output."
            )
            return
        for tokens, final in self._generate_image_streaming(
            tokenized_prompt,
            temp=temp,
            top_p=top_p,
            cfg_image_weight=cfg_image_weight,
            cfg_text_weight=cfg_text_weight,
            yield_every_n=yield_every_n,
            seed=seed,
        ):
            yield StreamingImage(
                image=self.pillow_from_bpe_tokens(torch.tensor(tokens)), final=final
            )

    async def generate_multimodal_streaming(
        self,
        prompt: MixedSequenceType,
        temp: float = 1.0,
        top_p: float = 0.8,
        cfg_image_weight: float = DEFAULT_MULTIMODAL_CFG_IMAGE,
        cfg_text_weight: float = DEFAULT_MULTIMODAL_CFG_TEXT,
        yield_every_n: int = 32,
        max_gen_tokens: int = 4096,
        repetition_penalty: float = 1.2,
        suffix_tokens: list[str] | None = None,
        seed: int | None = None,
        debug: dict | None = None,
    ) -> AsyncGenerator[str | StreamingImage, None]:
        input_ids = self.tokens_from_inputs(prompt, suffix_tokens=suffix_tokens)
        if len(input_ids) > (4096 - 3):
            yield (
                "ERROR: Your input exceeds the model's context length of 4096."
                " Note that images consume 1024 tokens."
            )
            return
        for token_type, tokens, is_final in self._generate_multimodal_streaming(
            input_ids,
            temp=temp,
            top_p=top_p,
            cfg_image_weight=cfg_image_weight,
            cfg_text_weight=cfg_text_weight,
            yield_every_n=yield_every_n,
            max_gen_tokens=max_gen_tokens,
            repetition_penalty=repetition_penalty,
            seed=seed,
        ):
            match token_type:
                case "TEXT":
                    yield self.tokenizer.decode(tokens)
                case "IMAGE":
                    yield StreamingImage(
                        image=self.pillow_from_bpe_tokens(torch.tensor(tokens)),
                        final=is_final,
                    )
                case _:
                    raise ValueError(f"Unknown token type: {token_type}")
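

# Example usage (an illustrative sketch, not part of the module's API: the
# checkpoint and config paths below are placeholders, and real deployments may
# additionally need torch.distributed setup depending on how load_model shards
# the model). The generate_*_streaming methods are async generators, so they
# must be consumed from an event loop:
#
#     import asyncio
#
#     async def main() -> None:
#         generator = ChameleonLocalGenerator(
#             model_path="/path/to/model",
#             tokenizer_path="/path/to/tokenizer.json",
#             vqgan_config_path="/path/to/vqgan.yaml",
#         )
#         async for chunk in generator.generate_text_streaming(
#             ["Describe a chameleon in one sentence."], seed=0
#         ):
#             print(chunk, end="", flush=True)
#
#     asyncio.run(main())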