mkoot007 committed on
Commit
ccaa1a2
·
1 Parent(s): ceb18fb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -11
app.py CHANGED
@@ -1,28 +1,24 @@
1
  import streamlit as st
2
  from PIL import Image
3
  import io
4
- from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
5
  from easyocr import Reader
6
 
7
- # Load the OCR model and text explanation model
8
  ocr_reader = Reader(['en'])
 
 
 
9
 
10
- # Load the text explanation model
11
- text_generator = AutoModelForCausalLM.from_pretrained("HuggingFaceH4/zephyr-7b-alpha")
12
- text_tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-alpha")
13
-
14
- # Define a function to extract text from an image using OCR
15
  def extract_text(image):
16
  return ocr_reader.readtext(image)
17
 
18
- # Define a function to explain the extracted text using text generation
19
- def explain_text(text, text_generator, text_tokenizer):
20
  # Extracted text
21
  extracted_text = " ".join([res[1] for res in text])
22
 
23
  # Generate an explanation using the text explanation model
24
  input_ids = text_tokenizer.encode(extracted_text, return_tensors="pt")
25
- explanation_ids = text_generator.generate(input_ids, max_length=100, num_return_sequences=1)
26
  explanation = text_tokenizer.decode(explanation_ids[0], skip_special_tokens=True)
27
 
28
  return explanation
@@ -37,7 +33,7 @@ uploaded_file = st.file_uploader("Upload an image:")
37
  if uploaded_file is not None:
38
  image = Image.open(uploaded_file)
39
  ocr_results = extract_text(image)
40
- explanation = explain_text(ocr_results, text_generator, text_tokenizer)
41
 
42
  st.markdown("**Extracted text:**")
43
  st.markdown(" ".join([res[1] for res in ocr_results]))
 
1
  import streamlit as st
2
  from PIL import Image
3
  import io
4
+ from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM
5
  from easyocr import Reader
6
 
 
7
# OCR backend, restricted to English recognition.
ocr_reader = Reader(['en'])

# BUGFIX: "bart" and "bart-explainer" are not valid Hugging Face Hub
# repository ids, so every from_pretrained() call below raised at startup
# (RepositoryNotFoundError/OSError). Load the canonical facebook/bart-base
# checkpoint instead; the same tokenizer serves both model heads.
# Module-level names are kept unchanged so downstream code still works.
text_generator = AutoModelForCausalLM.from_pretrained("facebook/bart-base")
text_tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
explainer = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-base")
11
 
 
 
 
 
 
12
  def extract_text(image):
13
  return ocr_reader.readtext(image)
14
 
15
+ def explain_text(text, explainer, text_tokenizer):
 
16
  # Extracted text
17
  extracted_text = " ".join([res[1] for res in text])
18
 
19
  # Generate an explanation using the text explanation model
20
  input_ids = text_tokenizer.encode(extracted_text, return_tensors="pt")
21
+ explanation_ids = explainer.generate(input_ids, max_length=100, num_return_sequences=1)
22
  explanation = text_tokenizer.decode(explanation_ids[0], skip_special_tokens=True)
23
 
24
  return explanation
 
33
  if uploaded_file is not None:
34
  image = Image.open(uploaded_file)
35
  ocr_results = extract_text(image)
36
+ explanation = explain_text(ocr_results, explainer, text_tokenizer)
37
 
38
  st.markdown("**Extracted text:**")
39
  st.markdown(" ".join([res[1] for res in ocr_results]))