# OCR / app.py
# Last updated by namanjais ("Update app.py", commit a6aeb63).
import os
import torch
from transformers import AutoProcessor, AutoModelForCausalLM
from PIL import Image
from groq import Groq
def enhance_medical_text(text, *, model="llama3-8b-8192", temperature=0.1, max_tokens=1024):
    """Ask a Groq-hosted LLM to correct OCR errors in prescription text.

    This is a best-effort post-processing step. On any failure (missing
    ``GROQ_API_KEY``, network error, API error) the original OCR text is
    returned unchanged, so the caller always has something to display.

    Args:
        text: Raw OCR output to correct.
        model: Groq chat model name. NOTE(review): ``llama3-8b-8192`` may
            have been retired by Groq; confirm it is still served, or pass
            a current model name.
        temperature: Sampling temperature. It is kept low so the model makes
            conservative corrections rather than creative rewrites.
        max_tokens: Upper bound on the length of the corrected text.

    Returns:
        The corrected text. Returns ``text`` itself when the input is blank
        or ``None``, when the model returns no content, or when the request
        fails.
    """
    # Nothing to correct: skip the network round-trip, and don't let the
    # model invent a prescription from an empty prompt.
    if not text or not text.strip():
        return text
    try:
        client = Groq(
            api_key=os.getenv("GROQ_API_KEY")  # Get API key from environment variable
        )
        chat_completion = client.chat.completions.create(
            messages=[
                {
                    "role": "system",
                    "content": "You are a medical prescription expert. Correct OCR errors in medicine names, dosages and medical terms..."
                },
                {
                    "role": "user",
                    "content": f"Correct this medical prescription OCR output:\n{text}"
                }
            ],
            model=model,
            temperature=temperature,
            max_tokens=max_tokens,
        )
        corrected = chat_completion.choices[0].message.content
    except Exception as e:
        # Deliberate catch-all at the service boundary: enhancement is
        # optional, so any failure degrades to the raw OCR text instead of
        # breaking the UI.
        print(f"Groq enhancement error: {str(e)}")
        return text
    # The message content can be None or empty; fall back to the raw text.
    return corrected if corrected else text
# The Hugging Face access token comes from the environment (None when unset)
# instead of being hard-coded in the source.
HF_TOKEN = os.getenv("HF_TOKEN")
_FLORENCE_MODEL_ID = "microsoft/Florence-2-large"

# Use the first GPU with half precision when available; otherwise run on CPU
# in full precision.
if torch.cuda.is_available():
    device, torch_dtype = "cuda:0", torch.float16
else:
    device, torch_dtype = "cpu", torch.float32

# trust_remote_code=True lets transformers execute Python code published in
# the model repository. Only use it with a trusted repo id.
model = AutoModelForCausalLM.from_pretrained(
    _FLORENCE_MODEL_ID,
    token=HF_TOKEN,
    torch_dtype=torch_dtype,
    trust_remote_code=True,
).to(device)
processor = AutoProcessor.from_pretrained(
    _FLORENCE_MODEL_ID,
    token=HF_TOKEN,
    trust_remote_code=True,
)
def run_ocr(image, task_prompt="<OCR>"):
    """Run the Florence-2 model on ``image`` and return the decoded text.

    Args:
        image: Image handed to the Florence-2 processor. The caller in this
            file passes a PIL image.
        task_prompt: Florence-2 task token; defaults to ``"<OCR>"``.

    Returns:
        Decoded text for the single input image, with special tokens removed.
    """
    # BatchFeature.to(device, dtype) casts only floating-point tensors to
    # torch_dtype, so input_ids keep their integer dtype.
    encoded = processor(text=task_prompt, images=image, return_tensors="pt").to(device, torch_dtype)
    # Beam search without sampling, so the output is deterministic.
    generation_kwargs = dict(max_new_tokens=1024, num_beams=3, do_sample=False)
    output_ids = model.generate(
        input_ids=encoded["input_ids"],
        pixel_values=encoded["pixel_values"],
        **generation_kwargs,
    )
    decoded = processor.batch_decode(output_ids, skip_special_tokens=True)
    return decoded[0]
import gradio as gr
def process_single_image(image):
    """Gradio callback: OCR an uploaded prescription, then LLM-correct the text.

    Args:
        image: Value from the ``gr.Image`` input. This is presumably a numpy
            array (gr.Image's default ``type``), or ``None`` when nothing was
            uploaded. An existing ``PIL.Image.Image`` is also accepted.

    Returns:
        A ``(raw_ocr_text, corrected_text)`` tuple for the two output
        textboxes, or ``("", "")`` when no image was supplied.
    """
    # Submitting without an upload passes None, and Image.fromarray(None)
    # would raise and surface an error in the UI instead of empty outputs.
    if image is None:
        return "", ""
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)
    result = run_ocr(image, "<OCR>")
    corrected_text = enhance_medical_text(result)
    return result, corrected_text
if __name__ == "__main__":
    # One image input, plus two text outputs in the order that
    # process_single_image returns them: raw OCR first, then corrected text.
    upload = gr.Image(label="Upload Prescription")
    raw_box = gr.Textbox(label="Raw OCR Output")
    enhanced_box = gr.Textbox(label="Enhanced Medical Report")
    demo = gr.Interface(
        fn=process_single_image,
        inputs=upload,
        outputs=[raw_box, enhanced_box],
        title="Medical Prescription OCR",
        description="Upload a medical prescription image for OCR processing and enhancement",
    )
    # server_name="0.0.0.0" binds every network interface on port 7860.
    demo.launch(server_name="0.0.0.0", server_port=7860)