Spaces:
Runtime error
Runtime error
| import torch | |
| from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor | |
| from PIL import Image | |
# Use CUDA when torch reports it as available; otherwise stay on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the "google/deplot" checkpoint with bfloat16 weights.
model_deplot = Pix2StructForConditionalGeneration.from_pretrained(
    "google/deplot",
    torch_dtype=torch.bfloat16,
)
if device == "cuda":
    # Move the model to CUDA device index 0.
    model_deplot = model_deplot.to(0)

# Processor matching the checkpoint loaded above.
processor_deplot = Pix2StructProcessor.from_pretrained("google/deplot")
def add_markup(table):
    """Label the lines of a linearized table with "Title:", "Header:" and "Row N:".

    The table is split into lines and each line is stripped. When the first
    line starts with ``TITLE``, the text between its first and second
    `` | `` separators is emitted as a ``Title: ...`` line and the remaining
    lines are treated as rows; otherwise every line is a row. The first row
    is prefixed with ``Header: `` and the following rows with ``Row 1: ``,
    ``Row 2: `` and so on.

    Args:
        table: Table text with one row per line.

    Returns:
        The labelled table, or ``table`` itself, unchanged, when parsing
        fails (for example an empty string, a non-string value, or a
        ``TITLE`` line without a `` | `` separator).
    """
    try:
        parts = [p.strip() for p in table.splitlines(keepends=False)]
        if parts[0].startswith('TITLE'):
            result = f"Title: {parts[0].split(' | ')[1].strip()}\n"
            rows = parts[1:]
        else:
            result = ''
            rows = parts
        # One label per row: the first row is the header, the rest are numbered
        # from 1. zip() stops at the shorter sequence, so an empty ``rows``
        # simply yields no labelled lines.
        prefixes = ['Header: '] + [f'Row {i}: ' for i in range(1, len(rows))]
        return result + '\n'.join(prefix + row for prefix, row in zip(prefixes, rows))
    except Exception:
        # Best effort: fall back to the raw table if parsing fails. Only
        # Exception is caught so KeyboardInterrupt/SystemExit still propagate.
        return table
def process_image(image):
    """Generate the data-table text for a figure image with the DePlot model.

    Args:
        image: Image passed to ``processor_deplot`` as ``images``.

    Returns:
        The decoded generation as a string, with special tokens skipped and
        every ``<0x0A>`` marker replaced by a newline.
    """
    prompt = "Generate the underlying data table for the figure below:"

    # Build the model inputs, then cast them to bfloat16.
    batch = processor_deplot(images=image, text=prompt, return_tensors="pt")
    batch = batch.to(torch.bfloat16)
    if device == "cuda":
        # Move the inputs to CUDA device index 0.
        batch = batch.to(0)

    generated = model_deplot.generate(**batch, max_new_tokens=512)
    decoded = processor_deplot.decode(generated[0], skip_special_tokens=True)
    return decoded.replace("<0x0A>", "\n")
if __name__ == "__main__":
    # Open the sample image in a context manager so the underlying file is
    # closed as soon as processing is finished.
    with Image.open(r"meat-image.png") as im:
        table = process_image(im)
    # Show the extracted table; previously the result was computed and dropped.
    print(table)