Spaces:
Build error
Build error
| import gradio as gr | |
| from transformers import Qwen2VLForConditionalGeneration, AutoProcessor, AutoModelForCausalLM, AutoTokenizer | |
| import torch | |
# Checkpoint identifiers, named once so each model and its companion
# processor/tokenizer are always loaded from the same repository.
OCR_MODEL_ID = "Qwen/Qwen2-VL-7B-Instruct"
MATH_MODEL_ID = "Qwen/Qwen2.5-Math-72B-Instruct"

# Load the OCR model and processor
ocr_model = Qwen2VLForConditionalGeneration.from_pretrained(
    OCR_MODEL_ID,
    torch_dtype="auto",
    device_map="auto",
)
ocr_processor = AutoProcessor.from_pretrained(OCR_MODEL_ID)

# Load the Math model and tokenizer
math_model = AutoModelForCausalLM.from_pretrained(
    MATH_MODEL_ID,
    torch_dtype="auto",
    device_map="auto",
)
math_tokenizer = AutoTokenizer.from_pretrained(MATH_MODEL_ID)
# OCR extraction function
def ocr_and_query(image, question):
    """Ask the vision-language model a question about an image.

    Builds a single-turn chat prompt containing one image placeholder
    followed by ``question``, generates up to 1024 new tokens with
    ``ocr_model``, and decodes only the newly generated part.

    Args:
        image: Image passed to ``ocr_processor`` (the UI supplies a PIL image).
        question: Text instruction that accompanies the image.

    Returns:
        The decoded reply for the single prompt, as a string.
    """
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": question},
            ],
        }
    ]

    # Render the chat template, then tokenize text and image together
    text_prompt = ocr_processor.apply_chat_template(messages, add_generation_prompt=True)
    inputs = ocr_processor(
        text=[text_prompt],
        images=[image],
        padding=True,
        return_tensors="pt",
    )

    # Follow the device the model actually lives on instead of assuming CUDA
    inputs = inputs.to(ocr_model.device)
    output_ids = ocr_model.generate(**inputs, max_new_tokens=1024)

    # Each returned sequence starts with its prompt tokens; drop that prefix
    completions = [
        sequence[len(prompt_ids):]
        for prompt_ids, sequence in zip(inputs.input_ids, output_ids)
    ]
    decoded = ocr_processor.batch_decode(
        completions,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )
    return decoded[0]
# Math problem solving function
def solve_math_problem(prompt):
    """Solve a math problem given as text using the math language model.

    The problem is wrapped in a chat with a system message asking for
    step-by-step reasoning and a final answer inside ``\\boxed{}``, then up
    to 512 new tokens are generated with ``math_model``.

    Args:
        prompt: The math problem statement.

    Returns:
        The decoded model reply as a string.
    """
    # CoT (Chain of Thought)
    messages = [
        {"role": "system", "content": "Please reason step by step, and put your final answer within \\boxed{}."},
        {"role": "user", "content": prompt},
    ]
    text = math_tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )

    # Follow the device the model actually lives on instead of assuming CUDA
    model_inputs = math_tokenizer([text], return_tensors="pt").to(math_model.device)
    generated = math_model.generate(**model_inputs, max_new_tokens=512)

    # Each returned sequence starts with its prompt tokens; drop that prefix
    completions = [
        sequence[len(prompt_ids):]
        for prompt_ids, sequence in zip(model_inputs.input_ids, generated)
    ]
    return math_tokenizer.batch_decode(completions, skip_special_tokens=True)[0]
# Function to clear inputs and output
def clear_inputs():
    """Return the reset values for the image, text box, and output, in that order."""
    blank_image = None
    blank_text = ""
    blank_output = ""
    return blank_image, blank_text, blank_output
# Gradio interface setup
def gradio_app(image, question, task):
    """Route a UI request to the handler for the selected task.

    Args:
        image: Uploaded image, or None when nothing was uploaded.
        question: Contents of the text box; None is treated as empty.
        task: Label of the selected task. Any unrecognised value (including
            None) yields a prompt to select a task.

    Returns:
        A 3-tuple ``(image, text, output)``. ``image`` is passed through
        unchanged; ``text`` is the question, or the text extracted from the
        image for the image-based math task; ``output`` is the model reply
        or a message asking the user for the missing input.
    """
    # Guard against a missing text value so the .strip() check cannot raise
    question = question or ""

    if task == "OCR and Query":
        # Without this guard the processor is called with no image and fails
        if image is None:
            return image, question, "Please upload an image."
        return image, question, ocr_and_query(image, question)

    if task == "Solve Math Problem from Image":
        if image is None:
            return image, question, "Please upload an image."
        # An empty instruction gives the vision model nothing to act on,
        # so ask explicitly for a transcription of the problem.
        extraction_prompt = "Extract the math problem from this image as plain text."
        extracted_text = ocr_and_query(image, extraction_prompt)
        math_solution = solve_math_problem(extracted_text)
        return image, extracted_text, math_solution

    if task == "Solve Math Problem from Text":
        if not question.strip():
            return image, question, "Please enter a math problem."
        return image, question, solve_math_problem(question)

    return image, question, "Please select a task."
# Gradio interface
with gr.Blocks() as app:
    # Page heading and usage hint
    gr.Markdown(value="# Image OCR and Math Solver")
    gr.Markdown(
        value="Upload an image, enter your question or math problem, and select the appropriate task."
    )

    # Inputs: the image and the free-text question/problem side by side
    with gr.Row():
        image_input = gr.Image(type="pil", label="Upload Image")
        text_input = gr.Textbox(
            lines=2,
            placeholder="Enter your question or math problem here...",
            label="Input",
        )

    # Task selector; the labels must match the strings gradio_app compares against
    with gr.Row():
        task_radio = gr.Radio(
            choices=[
                "OCR and Query",
                "Solve Math Problem from Image",
                "Solve Math Problem from Text",
            ],
            label="Task",
        )

    # Action buttons
    with gr.Row():
        complete_button = gr.Button(value="Complete")
        clear_button = gr.Button(value="Clear")

    output = gr.Markdown(label="Output")

    # Event listeners: both handlers write back to the image, text box, and output
    complete_button.click(
        fn=gradio_app,
        inputs=[image_input, text_input, task_radio],
        outputs=[image_input, text_input, output],
    )
    clear_button.click(
        fn=clear_inputs,
        outputs=[image_input, text_input, output],
    )

# Launch the app
app.launch(share=True)