mkoot007 commited on
Commit
f11fbf2
·
1 Parent(s): 4560624

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -20
app.py CHANGED
@@ -5,42 +5,26 @@ import torch
5
  from transformers import AutoTokenizer, AutoModelForCausalLM
6
  from easyocr import Reader
7
 
8
- # Load the OCR model and text explanation model (GPT-2 as an example)
9
- ocr_reader = Reader(['en'])
10
- text_generator = AutoModelForCausalLM.from_pretrained("gpt2")
11
- text_tokenizer = AutoTokenizer.from_pretrained("gpt2")
12
 
13
- # Define a function to extract text from an image
 
 
14
  def extract_text(image):
15
  return ocr_reader.readtext(image)
16
-
17
- # Define a function to explain the extracted text
18
  def explain_text(text):
19
- # Generate an explanation using the text generation model (GPT-2)
20
  input_ids = text_tokenizer.encode(text, return_tensors="pt")
21
- explanation_ids = text_generator.generate(input_ids, max_length=50, num_return_sequences=1)
22
  explanation = text_tokenizer.decode(explanation_ids[0], skip_special_tokens=True)
23
  return explanation
24
 
25
- # Create a Streamlit layout
26
  st.title("Text Classification Model")
27
-
28
- # Allow users to upload an image
29
  uploaded_file = st.file_uploader("Upload an image:")
30
 
31
- # Extract text from the uploaded image
32
  if uploaded_file is not None:
33
- # Read the uploaded image
34
  image = Image.open(uploaded_file)
35
-
36
- # Extract text from the image
37
  ocr_results = extract_text(image)
38
  extracted_text = " ".join([res[1] for res in ocr_results])
39
-
40
- # Explain the extracted text
41
  explanation = explain_text(extracted_text)
42
-
43
- # Display the extracted text and explanation
44
  st.markdown("**Extracted text:**")
45
  st.markdown(extracted_text)
46
 
 
5
  from transformers import AutoTokenizer, AutoModelForCausalLM
6
  from easyocr import Reader
7
 
 
 
 
 
8
 
9
+ ocr_reader = Reader(['en'])
10
+ text_generator = AutoModelForCausalLM.from_pretrained("gpt3")
11
+ text_tokenizer = AutoTokenizer.from_pretrained("gpt3")
12
  def extract_text(image):
13
  return ocr_reader.readtext(image)
 
 
14
  def explain_text(text):
 
15
  input_ids = text_tokenizer.encode(text, return_tensors="pt")
16
+ explanation_ids = text_generator.generate(input_ids, max_length=100, num_return_sequences=1)
17
  explanation = text_tokenizer.decode(explanation_ids[0], skip_special_tokens=True)
18
  return explanation
19
 
 
20
  st.title("Text Classification Model")
 
 
21
  uploaded_file = st.file_uploader("Upload an image:")
22
 
 
23
  if uploaded_file is not None:
 
24
  image = Image.open(uploaded_file)
 
 
25
  ocr_results = extract_text(image)
26
  extracted_text = " ".join([res[1] for res in ocr_results])
 
 
27
  explanation = explain_text(extracted_text)
 
 
28
  st.markdown("**Extracted text:**")
29
  st.markdown(extracted_text)
30