mkoot007 commited on
Commit
79be51e
·
1 Parent(s): ce2f606

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -16
app.py CHANGED
@@ -1,38 +1,26 @@
1
  import streamlit as st
 
 
2
  import torch
3
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
4
  from easyocr import Reader
5
 
6
- # Load the OCR model and text explanation model (gpt-2 as an example)
7
  ocr_reader = Reader(['en'])
8
  explainer = AutoModelForSequenceClassification.from_pretrained("gpt2")
9
-
10
- # Define a function to extract text from an image
11
  def extract_text(image):
12
  return ocr_reader.readtext(image)
13
-
14
- # Define a function to explain the extracted text
15
  def explain_text(text):
16
  tokenizer = AutoTokenizer.from_pretrained("gpt2")
17
  input_ids = tokenizer.encode(text, return_tensors="pt")
18
  explanation = explainer(input_ids)
19
  return explanation
20
-
21
- # Create a Streamlit layout
22
  st.title("Text Classification Model")
23
-
24
- # Allow users to upload an image
25
  uploaded_file = st.file_uploader("Upload an image:")
26
-
27
- # Extract text from the uploaded image
28
  if uploaded_file is not None:
29
- image = torch.from_numpy(uploaded_file.read()).unsqueeze(0)
 
30
  extracted_text = extract_text(image)
31
-
32
- # Explain the extracted text
33
  explanation = explain_text(extracted_text)
34
-
35
- # Display the extracted text and explanation
36
  st.markdown("**Extracted text:**")
37
  st.markdown(extracted_text)
38
 
 
1
  import streamlit as st
2
+ import io
3
+ from PIL import Image
4
  import torch
5
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
6
  from easyocr import Reader
7
 
 
8
  ocr_reader = Reader(['en'])
9
  explainer = AutoModelForSequenceClassification.from_pretrained("gpt2")
 
 
10
  def extract_text(image):
11
  return ocr_reader.readtext(image)
 
 
12
  def explain_text(text):
13
  tokenizer = AutoTokenizer.from_pretrained("gpt2")
14
  input_ids = tokenizer.encode(text, return_tensors="pt")
15
  explanation = explainer(input_ids)
16
  return explanation
 
 
17
  st.title("Text Classification Model")
 
 
18
  uploaded_file = st.file_uploader("Upload an image:")
 
 
19
  if uploaded_file is not None:
20
+ image = Image.open(uploaded_file)
21
+
22
  extracted_text = extract_text(image)
 
 
23
  explanation = explain_text(extracted_text)
 
 
24
  st.markdown("**Extracted text:**")
25
  st.markdown(extracted_text)
26