mkoot007 committed on
Commit
3c97a0a
·
1 Parent(s): 79be51e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -2
app.py CHANGED
@@ -5,22 +5,48 @@ import torch
5
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
6
  from easyocr import Reader
7
 
 
8
  ocr_reader = Reader(['en'])
9
  explainer = AutoModelForSequenceClassification.from_pretrained("gpt2")
 
 
10
  def extract_text(image):
11
  return ocr_reader.readtext(image)
 
 
12
  def explain_text(text):
13
  tokenizer = AutoTokenizer.from_pretrained("gpt2")
14
- input_ids = tokenizer.encode(text, return_tensors="pt")
15
- explanation = explainer(input_ids)
 
 
 
 
 
 
16
  return explanation
 
 
17
  st.title("Text Classification Model")
 
 
18
  uploaded_file = st.file_uploader("Upload an image:")
 
 
19
  if uploaded_file is not None:
 
20
  image = Image.open(uploaded_file)
21
 
 
 
 
 
22
  extracted_text = extract_text(image)
 
 
23
  explanation = explain_text(extracted_text)
 
 
24
  st.markdown("**Extracted text:**")
25
  st.markdown(extracted_text)
26
 
 
5
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
6
  from easyocr import Reader
7
 
8
+ # Load the OCR model and text explanation model (gpt-2 as an example)
9
  ocr_reader = Reader(['en'])
10
  explainer = AutoModelForSequenceClassification.from_pretrained("gpt2")
11
+
12
+ # Define a function to extract text from an image
13
  def extract_text(image):
14
  return ocr_reader.readtext(image)
15
+
16
+ # Define a function to explain the extracted text
17
  def explain_text(text):
18
  tokenizer = AutoTokenizer.from_pretrained("gpt2")
19
+
20
+ # Encode the text and convert to PyTorch tensors
21
+ inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
22
+
23
+ input_ids = inputs["input_ids"]
24
+ attention_mask = inputs["attention_mask"]
25
+
26
+ explanation = explainer(input_ids, attention_mask=attention_mask)
27
  return explanation
28
+
29
+ # Create a Streamlit layout
30
  st.title("Text Classification Model")
31
+
32
+ # Allow users to upload an image
33
  uploaded_file = st.file_uploader("Upload an image:")
34
+
35
+ # Extract text from the uploaded image
36
  if uploaded_file is not None:
37
+ # Read the uploaded image
38
  image = Image.open(uploaded_file)
39
 
40
+ # Process the image and convert to NumPy array if necessary
41
+ # image = process_image(image)
42
+
43
+ # Extract text from the image
44
  extracted_text = extract_text(image)
45
+
46
+ # Explain the extracted text
47
  explanation = explain_text(extracted_text)
48
+
49
+ # Display the extracted text and explanation
50
  st.markdown("**Extracted text:**")
51
  st.markdown(extracted_text)
52