RP-Azul committed on
Commit
9eca730
·
verified ·
1 Parent(s): e8951e6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -17
app.py CHANGED
@@ -26,16 +26,16 @@ def extract_docx_text(uploaded_file):
26
  # docx2txt.process accepts a path or a file-like object
27
  return docx2txt.process(uploaded_file)
28
  # --- Image model setup ---
29
- MODEL_NAME = "google/vit-base-patch16-224"
30
 
31
- @st.cache_resource
32
- def load_image_model():
33
- proc = AutoProcessor.from_pretrained(MODEL_NAME)
34
- mdl = AutoModelForImageClassification.from_pretrained(MODEL_NAME)
35
- return proc, mdl
36
 
37
- processor, model = load_image_model()
 
 
 
 
38
 
 
39
  # --- Main UI ---
40
  input_type = st.selectbox(
41
  "Select the type of input:",
@@ -68,23 +68,23 @@ elif input_type == "Text":
68
  st.text_area("Content", notes, height=300)
69
 
70
  elif input_type == "Image":
71
- uploaded_img = st.file_uploader("Upload a PNG image", type=["png"])
72
  if uploaded_img is not None:
73
  img = Image.open(uploaded_img).convert("RGB")
74
  st.image(img, caption="🖼️ Uploaded Image", use_column_width=True)
75
 
76
- # preprocess & inference
77
- inputs = processor(images=img, return_tensors="pt")
 
 
78
  with torch.no_grad():
79
- outputs = model(**inputs)
80
- probs = torch.softmax(outputs.logits, dim=-1)[0]
81
- top5 = torch.topk(probs, k=5)
82
 
83
- st.subheader("🔍 Top 5 Predictions")
84
- for idx, score in zip(top5.indices.tolist(), top5.values.tolist()):
85
- label = model.config.id2label[idx]
86
- st.write(f"- **{label}**: {score*100:.1f}%")
87
 
88
  else:
89
  st.info("Please select an input type to get started.")
90
 
 
 
26
  # docx2txt.process accepts a path or a file-like object
27
  return docx2txt.process(uploaded_file)
28
  # --- Image model setup ---
29
+ OCR_MODEL = "microsoft/trocr-base-printed"
30
 
 
 
 
 
 
31
 
32
+ @st.cache_resource
33
+ def load_ocr_model():
34
+ processor = TrOCRProcessor.from_pretrained(OCR_MODEL)
35
+ model = VisionEncoderDecoderModel.from_pretrained(OCR_MODEL)
36
+ return processor, model
37
 
38
+ ocr_processor, ocr_model = load_ocr_model()
39
  # --- Main UI ---
40
  input_type = st.selectbox(
41
  "Select the type of input:",
 
68
  st.text_area("Content", notes, height=300)
69
 
70
  elif input_type == "Image":
71
+ uploaded_img = st.file_uploader("Upload a PNG/JPG image", type=["png", "jpg", "jpeg"])
72
  if uploaded_img is not None:
73
  img = Image.open(uploaded_img).convert("RGB")
74
  st.image(img, caption="🖼️ Uploaded Image", use_column_width=True)
75
 
76
+ # 1. Preprocess for OCR
77
+ pixel_values = ocr_processor(images=img, return_tensors="pt").pixel_values
78
+
79
+ # 2. Generate and decode
80
  with torch.no_grad():
81
+ generated_ids = ocr_model.generate(pixel_values)
82
+ extracted_text = ocr_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
 
83
 
84
+ st.subheader("🖋️ Extracted Text from Image")
85
+ st.text_area("OCR Result", extracted_text, height=300)
 
 
86
 
87
  else:
88
  st.info("Please select an input type to get started.")
89
 
90
+