| | import os |
| | import torch |
| | from transformers import AutoProcessor, AutoModelForCausalLM |
| | from PIL import Image |
| | from groq import Groq |
| |
|
def enhance_medical_text(text):
    """Post-correct prescription OCR output using the Groq LLM API.

    Best-effort: returns *text* unchanged when it is empty, when no
    GROQ_API_KEY is configured, or when the API call fails for any reason,
    so the UI never crashes on enhancement problems.

    Args:
        text: Raw OCR text to correct.

    Returns:
        The corrected text, or the original *text* on any failure.
    """
    if not text:
        # Nothing to correct — skip the API round-trip entirely.
        return text

    api_key = os.getenv("GROQ_API_KEY")
    if not api_key:
        # Without a key the Groq client raises a confusing error; report
        # the real cause and fall back to the raw OCR output.
        print("Groq enhancement skipped: GROQ_API_KEY is not set")
        return text

    try:
        client = Groq(api_key=api_key)
        chat_completion = client.chat.completions.create(
            messages=[
                {
                    "role": "system",
                    "content": "You are a medical prescription expert. Correct OCR errors in medicine names, dosages and medical terms..."
                },
                {
                    "role": "user",
                    "content": f"Correct this medical prescription OCR output:\n{text}"
                }
            ],
            model="llama3-8b-8192",
            temperature=0.1,  # near-deterministic: we want corrections, not creativity
            max_tokens=1024
        )
        return chat_completion.choices[0].message.content
    except Exception as e:  # broad on purpose: enhancement must never break the pipeline
        print(f"Groq enhancement error: {str(e)}")
        return text
| |
|
| | |
# Hugging Face access token used to authenticate model downloads
# (None is acceptable for public repos).
HF_TOKEN = os.getenv("HF_TOKEN")

# Prefer the first GPU when available; otherwise run on CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
# float16 halves GPU memory use; CPU inference generally needs float32.
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

# Florence-2 vision-language model, used here for its OCR task mode.
# trust_remote_code=True is required because Florence-2 ships custom
# modeling code on the Hub — NOTE(review): this executes repo-provided
# code at load time; consider pinning a revision.
model = AutoModelForCausalLM.from_pretrained(
    "microsoft/Florence-2-large",
    token=HF_TOKEN,
    torch_dtype=torch_dtype,
    trust_remote_code=True
).to(device)

# Matching processor: tokenizes the task prompt and preprocesses images
# into the tensors Florence-2 expects.
processor = AutoProcessor.from_pretrained(
    "microsoft/Florence-2-large",
    token=HF_TOKEN,
    trust_remote_code=True
)
| |
|
def run_ocr(image, task_prompt="<OCR>"):
    """Run Florence-2 on *image* and return the decoded text.

    Args:
        image: PIL.Image to read text from.
        task_prompt: Florence-2 task token; defaults to plain "<OCR>".

    Returns:
        The decoded generation output as a single string (special tokens
        stripped).
    """
    inputs = processor(text=task_prompt, images=image, return_tensors="pt").to(device, torch_dtype)
    # inference_mode disables autograd tracking — generation never needs
    # gradients, so this cuts memory use and speeds up beam-search decoding.
    with torch.inference_mode():
        generated_ids = model.generate(
            input_ids=inputs["input_ids"],
            pixel_values=inputs["pixel_values"],
            max_new_tokens=1024,
            num_beams=3,      # small beam search for more reliable OCR text
            do_sample=False   # deterministic decoding
        )
    return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
| |
|
| | import gradio as gr |
| |
|
def process_single_image(image):
    """OCR a prescription image and return (raw text, LLM-corrected text).

    Args:
        image: numpy array supplied by the Gradio image component, or None
            when the user submits without uploading an image.

    Returns:
        Tuple of (raw OCR output, enhanced text) strings.
    """
    if image is None:
        # Gradio passes None when no image was uploaded; fail gracefully
        # instead of crashing inside Image.fromarray.
        msg = "No image provided."
        return msg, msg
    pil_image = Image.fromarray(image)
    raw_text = run_ocr(pil_image, "<OCR>")
    corrected_text = enhance_medical_text(raw_text)
    return raw_text, corrected_text
| |
|
if __name__ == "__main__":
    # Two text panes: the raw Florence-2 output and the Groq-corrected version.
    result_boxes = [
        gr.Textbox(label="Raw OCR Output"),
        gr.Textbox(label="Enhanced Medical Report"),
    ]
    app = gr.Interface(
        fn=process_single_image,
        inputs=gr.Image(label="Upload Prescription"),
        outputs=result_boxes,
        title="Medical Prescription OCR",
        description="Upload a medical prescription image for OCR processing and enhancement",
    )
    # Bind to all interfaces so the app is reachable from outside a container.
    app.launch(server_name="0.0.0.0", server_port=7860)