Spaces:
Runtime error
Runtime error
| from utilities import ( | |
| PDFExtractor, | |
| get_simple_logger, | |
| ModuleException, | |
| ) | |
| from create_dataset import clean_text | |
| from model import PDFDataSet | |
| import pickle | |
| import argparse | |
class Inference:
    """Classify a PDF, given its URL, as a Lighting or Non-Lighting product.

    Pipeline run by :meth:`predict`:

    1. extract text from the PDF with ``PDFExtractor`` (at most 3 pages,
       ``min_characters=5``),
    2. embed that text with ``PDFDataSet.sentences_to_embedding``,
    3. score the embedding with a pickled classifier's ``predict_proba``
       (the saved model is described by the original author as a
       RandomForest; any unpickled object exposing ``predict_proba`` works).
    """

    def __init__(self, pdf_url, model_path="final_model.pkl") -> None:
        """Store configuration; the PDF and the model are not read here.

        Args:
            pdf_url: URL of the PDF to classify; passed to ``PDFExtractor``.
            model_path: Path of the pickled classifier. A relative path is
                opened relative to the current working directory.
        """
        self.pdf_url = pdf_url
        self.model_path = model_path
        self.logger = get_simple_logger("inference", level="debug")

    def _extract_text(self):
        """Extract text from the PDF at ``self.pdf_url``.

        Returns:
            A ``(data, final)`` tuple. ``final`` is the raw text returned by
            ``PDFExtractor.extract_pages()``. ``data`` is a record with keys
            ``status``, ``id``, ``page_contents``, ``final_content`` (the
            text after ``clean_text``) and ``url``.

        Raises:
            Any exception raised during extraction or cleaning is logged and
            then re-raised unchanged.
        """
        self.logger.debug("Extracting text from the PDF.")
        pdf_extractor = PDFExtractor(
            self.pdf_url,
            min_characters=5,
            maximum_pages=3,
        )
        try:
            final = pdf_extractor.extract_pages()
            data = {
                "status": "ok",
                "id": "001",
                "page_contents": pdf_extractor.page_contents,
                "final_content": clean_text(final),
                "url": self.pdf_url,
            }
        except ModuleException as e:
            self.logger.warning("A module exception is raised.")
            self.logger.error(e)
            # Bare `raise` re-raises the active exception without adding a
            # re-raise frame to the traceback.
            raise
        except Exception as e:
            self.logger.error(e)
            raise
        return data, final

    def _load_model(self):
        """Unpickle the classifier at ``self.model_path``.

        The result is stored on ``self.model`` and returned.

        Security: unpickling runs arbitrary code, so ``model_path`` must point
        to a trusted file.
        """
        with open(self.model_path, "rb") as f:
            model = pickle.load(f)
        self.model = model
        return self.model

    def _create_embedding(self, text):
        """Embed ``text`` and return it as a single-row 2-D array.

        ``reshape(1, -1)`` makes the result one sample wide, presumably the
        ``(n_samples, n_features)`` layout ``predict_proba`` expects
        (TODO confirm against the model).
        """
        dataset = PDFDataSet()
        embedding = dataset.sentences_to_embedding([text])
        return embedding.reshape(1, -1)

    def _pretty_print_probability(self, p):
        """Print the winning class and its probability as a percentage.

        Args:
            p: Probability row for one sample. Index 0 is treated as
                Non-Lighting and index 1 as Lighting.
                NOTE(review): assumes the model's class order is
                [non-lighting, lighting] — confirm against ``model.classes_``.

        On an exact tie, Non-Lighting is reported.
        """
        non_lighting_probability = p[0] * 100
        lighting_probability = p[1] * 100
        if lighting_probability > non_lighting_probability:
            print(
                f"This is a Lighting product with a probability of {lighting_probability:.2f}%"
            )
        else:
            print(
                f"This is a Non-Lighting product with a probability of {non_lighting_probability:.2f}%"
            )

    def predict(self):
        """Run the full pipeline and return the class-probability row.

        Prints a human-readable verdict as a side effect and returns
        ``predict_proba(embedding)[0]``, the probability row of the single
        sample.
        """
        # NOTE(review): this embeds the raw extracted text; the cleaned text
        # (`clean_text(final)`) only goes into the discarded `data` record.
        # Confirm this matches how the training embeddings were built.
        _, sentence = self._extract_text()
        embedding = self._create_embedding(sentence)
        model = self._load_model()
        prediction = model.predict_proba(embedding)
        prediction = prediction[0]
        self._pretty_print_probability(prediction)
        return prediction
def main(args):
    """Classify the PDF at ``args.url`` and print its probability row.

    Only the ``url`` attribute of ``args`` is read; the classifier is built
    with ``Inference``'s default model path. ``Inference.predict`` prints its
    own verdict first, then this function prints the returned probabilities.
    """
    classifier = Inference(args.url)
    print(classifier.predict())
if __name__ == "__main__":
    # Command-line entry point: run with `--url <pdf-url>`.
    cli_parser = argparse.ArgumentParser()
    cli_parser.add_argument("--url", type=str, required=True)
    main(cli_parser.parse_args())