Spaces:
Runtime error
Runtime error
| import pandas as pd | |
| from langchain.document_loaders import PyPDFLoader | |
| from langchain.text_splitter import CharacterTextSplitter | |
| import torch | |
| from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor | |
| from pathlib import Path | |
def make_descriptions(file, title):
    """Build the list of description dicts for a file.

    A file whose suffix is ``.csv`` (compared case-insensitively) yields two
    entries: a fixed placeholder entry followed by an entry describing
    *file* together with its column names. Any other file yields a single
    entry without a ``'columns'`` key.

    Args:
        file: Path of the file to describe (anything ``pathlib.Path`` accepts).
        title: Title stored in the entry that describes *file*.

    Returns:
        A list of dicts with the keys ``'path'``, ``'number'``, ``'title'``
        and, for CSV files, ``'columns'``.
    """
    # Guard clause: everything that is not a CSV gets the short description.
    if Path(file).suffix.lower() != '.csv':
        return [
            {
                'path': file,
                'number': 1,
                'title': title,
            }
        ]

    # Only the header row is needed, so do not parse the data rows.
    columns = list(pd.read_csv(file, nrows=0).columns)

    # Hard-coded placeholder entry that always precedes the real table.
    # NOTE(review): its purpose is not visible in this file -- confirm with
    # whatever consumes these descriptions before changing it.
    placeholder_description = {
        'path': 'random',
        'number': 1,
        'columns': ["clothes", "animals", "students"],
        'title': "fashionable student clothes",
    }
    table_description = {
        'path': file,
        'number': 2,
        'columns': columns,
        'title': title,
    }
    return [placeholder_description, table_description]
def make_documents(pdf, chunk_size=500, chunk_overlap=0, separator='\n'):
    """Load a PDF and split its content into chunks.

    The PDF is read with ``PyPDFLoader`` and the loaded documents are passed
    through a ``CharacterTextSplitter``. The default splitter settings are
    the values this function previously hard-coded, so existing callers get
    the same result.

    Args:
        pdf: Path of the PDF file handed to ``PyPDFLoader``.
        chunk_size: ``chunk_size`` given to the text splitter.
        chunk_overlap: ``chunk_overlap`` given to the text splitter.
        separator: Separator string the splitter splits on.

    Returns:
        Whatever ``CharacterTextSplitter.split_documents`` returns for the
        loaded documents.
    """
    loaded = PyPDFLoader(pdf).load()
    splitter = CharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        separator=separator,
    )
    return splitter.split_documents(loaded)
class Matcha_model:
    """Wrapper around the ``google/matcha-chartqa`` Pix2Struct checkpoint.

    Construction loads the model and its processor and moves the model to
    CUDA when available, otherwise it stays on the CPU.
    """

    def __init__(self) -> None:
        """Load model and processor and place the model on the chosen device."""
        self.model_name = "google/matcha-chartqa"
        checkpoint = self.model_name
        self.model = Pix2StructForConditionalGeneration.from_pretrained(checkpoint)
        self.processor = Pix2StructProcessor.from_pretrained(checkpoint)

        # Prefer the GPU when one is visible to torch.
        device_name = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device_name)
        self.model.to(self.device)

    def _filter_output(self, output):
        """Return *output* with every literal ``<0x0A>`` marker removed.

        0x0A is the line-feed byte; presumably the marker is how the
        tokenizer spells a newline -- confirm against the checkpoint.
        """
        # Splitting on the marker and re-joining with nothing deletes it.
        return "".join(output.split("<0x0A>"))

    def chart_qa(self, image, question: str) -> str:
        """Answer *question* about the chart in *image*.

        Args:
            image: Chart image passed to the processor as ``images``.
            question: Question text passed to the processor as ``text``.

        Returns:
            The decoded first generated sequence with ``<0x0A>`` markers
            stripped.
        """
        encoded = self.processor(images=image, text=question, return_tensors="pt")
        encoded = encoded.to(self.device)
        generated = self.model.generate(**encoded, max_new_tokens=512)
        answer = self.processor.decode(generated[0], skip_special_tokens=True)
        return self._filter_output(answer)