| from typing import Any | |
| import torch | |
class VisualGLMBasePostProcessor:
    """Base post processor for VisualGLM.

    Converts the token ids produced by the model into text by delegating to
    the tokenizer's ``decode`` method, with no further post-processing.
    """

    def __init__(self) -> None:
        pass

    def __call__(self, output_token: torch.Tensor, tokenizer: Any) -> str:
        """Decode ``output_token`` into a string.

        Args:
            output_token: Generated token ids, passed unchanged to
                ``tokenizer.decode`` (presumably a 1-D tensor for a single
                sample -- confirm against the caller).
            tokenizer: Any object exposing a ``decode`` method.

        Returns:
            Whatever ``tokenizer.decode`` returns for ``output_token``.
        """
        # The hint is ``torch.Tensor`` (the class); ``torch.tensor`` is a
        # factory function and is not a valid type annotation.
        return tokenizer.decode(output_token)
class VisualGLMVSRPostProcessor(VisualGLMBasePostProcessor):
    """VSR post processor for VisualGLM.

    Decodes the model output and maps it onto one of the labels
    ``'yes'``, ``'no'`` or ``'unknown'``.
    """

    def __init__(self) -> None:
        super().__init__()

    def __call__(self, output_token: torch.Tensor, tokenizer: Any) -> str:
        """Decode ``output_token`` and classify it as yes / no / unknown.

        Args:
            output_token: Generated token ids, decoded via the base class.
            tokenizer: Any object exposing a ``decode`` method.

        Returns:
            ``'yes'`` if the lower-cased text contains ``'yes'``, otherwise
            ``'no'`` if it contains ``'no'``, otherwise ``'unknown'``.
        """
        # Reuse the base-class decoding so both processors decode identically;
        # lower-case once instead of on every check.
        output_text = super().__call__(output_token, tokenizer).lower()
        # 'yes' is tested first, so text containing both words maps to 'yes'.
        # NOTE(review): these are plain substring checks, so e.g. 'eyes' counts
        # as 'yes' and 'know'/'not' count as 'no'. Kept as-is to preserve the
        # existing scoring; consider word-boundary matching if undesired.
        if 'yes' in output_text:
            return 'yes'
        if 'no' in output_text:
            return 'no'
        return 'unknown'