# Hugging Face Space status when this file was captured: Runtime error
import logging
import os

import gradio as gr
from PIL import Image
from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor
# Hub identifier of the checkpoint. It is public, so no authentication is
# needed; a private checkpoint would additionally require an access token.
# NOTE(review): "textcaps" suggests an image-captioning checkpoint rather than
# a math solver -- confirm it is the intended model for this app.
model_name = "google/pix2struct-textcaps-base"

# Load the processor (used below for preprocessing and decoding) and the model
# weights from the same checkpoint so they stay consistent with each other.
processor = Pix2StructProcessor.from_pretrained(model_name)
model = Pix2StructForConditionalGeneration.from_pretrained(model_name)
def solve_math_problem(image):
    """Run the Pix2Struct checkpoint on an uploaded image and format its output.

    The instruction text is given to the processor as ``text``. For a non-VQA
    checkpoint, Pix2StructProcessor turns that text into ``decoder_input_ids``
    (a prefix the generated text must start with) and never produces
    ``input_ids``, so the prompt is reported directly instead of being decoded
    back from the processor output.

    NOTE(review): the loaded checkpoint may be a captioning model rather than
    a math solver; confirm the quality of the "Solution" it produces.

    Args:
        image: PIL image from the Gradio ``Image`` input, or ``None`` when the
            form is submitted without an upload.

    Returns:
        A two-line string ``Problem: ...`` / ``Solution: ...`` on success, or
        a one-line ``Error processing image: ...`` message. Errors are
        reported in the textbox rather than raised, so the UI stays usable.
    """
    if image is None:
        # Gradio passes None when the user submits without uploading a file.
        return "Error processing image: no image was uploaded."

    prompt = "Solve the following math problem:"
    try:
        # Ensure the image is in RGB format.
        image = image.convert("RGB")
        # add_special_tokens=False stops an end-of-sequence token from being
        # appended to the decoder prefix, which would otherwise signal to the
        # decoder that the text is already finished.
        inputs = processor(
            images=[image],
            text=prompt,
            return_tensors="pt",
            max_patches=2048,
            add_special_tokens=False,
        )
        # Beam search is deterministic; `temperature` only affects sampling
        # (do_sample=True), so it is intentionally not passed.
        predictions = model.generate(
            **inputs,
            max_new_tokens=200,
            early_stopping=True,
            num_beams=4,
        )
        solution = processor.decode(
            predictions[0],
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        ).strip()
        # The generated sequence echoes the forced prefix; drop it so only the
        # model's continuation is shown as the solution.
        if solution.startswith(prompt):
            solution = solution[len(prompt):].strip()
        return f"Problem: {prompt}\nSolution: {solution}"
    except Exception as e:
        # Top-level UI boundary: keep the full traceback in the server logs
        # and show a short message in the textbox instead of failing the call.
        logging.exception("solve_math_problem failed")
        return f"Error processing image: {str(e)}"
# Set up the Gradio interface.
# Example files are optional. Only offer the ones that actually exist (paths
# are relative to the working directory). Gradio loads or caches listed
# examples when the Interface is built, so a missing file can break startup.
_EXAMPLE_FILES = ["example_addition.png", "example_algebra.jpg"]
_examples = [[path] for path in _EXAMPLE_FILES if os.path.isfile(path)]

demo = gr.Interface(
    fn=solve_math_problem,
    inputs=gr.Image(
        type="pil",
        label="Upload Handwritten Math Problem",
        image_mode="RGB",  # This forces the input to be RGB.
    ),
    outputs=gr.Textbox(label="Solution", show_copy_button=True),
    title="Handwritten Math Problem Solver",
    description="Upload an image of a handwritten math problem (algebra, arithmetic, etc.) and get the solution",
    # None (rather than an empty list) hides the examples section entirely.
    examples=_examples or None,
    theme="soft",
    # NOTE(review): newer Gradio releases deprecate `allow_flagging` in favour
    # of `flagging_mode`; switch once the pinned Gradio version supports it.
    allow_flagging="never",
)
if __name__ == "__main__":
    # Launch the web server only when the file is executed directly; importing
    # the module just builds `demo` without starting it.
    demo.launch()