RizwanSajad commited on
Commit
d2edfae
·
verified ·
1 Parent(s): 4cc8eb3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -34
app.py CHANGED
@@ -6,25 +6,23 @@ import pytesseract
6
  from transformers import AutoTokenizer, AutoModel
7
  import faiss
8
  import numpy as np
9
- import torch
10
  from groq import Groq
11
 
12
- # Configure Streamlit app
13
  st.title("RAG-Based Application")
14
- st.write("Upload an image to extract and query content.")
15
 
16
- # Initialize Groq API
17
def get_groq_client():
    """Return a Groq client configured from the GROQ_API_KEY environment variable."""
    api_key = os.environ.get("GROQ_API_KEY")
    return Groq(api_key=api_key)
19
 
20
-
21
- # Load embedding model
22
  st.write("Loading embedding model...")
23
  try:
24
  tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-mpnet-base-v2")
25
  model = AutoModel.from_pretrained("sentence-transformers/all-mpnet-base-v2")
26
  except Exception as e:
27
- st.error(f"Error loading model: {e}")
28
 
29
  # Initialize FAISS index
30
  dimension = model.config.hidden_size
@@ -35,61 +33,55 @@ def extract_text_from_image(image_path):
35
  try:
36
  return pytesseract.image_to_string(Image.open(image_path))
37
  except pytesseract.TesseractNotFoundError:
38
- st.error("Tesseract is not installed. It is required for text extraction.")
39
  return ""
40
 
41
def get_embeddings(text_chunks):
    """Get embeddings for text chunks.

    Tokenizes the chunks, runs the encoder without gradient tracking, and
    mean-pools the last hidden state over the sequence dimension.
    """
    encoded = tokenizer(text_chunks, return_tensors="pt", padding=True, truncation=True)
    with torch.no_grad():
        hidden = model(**encoded).last_hidden_state
        pooled = hidden.mean(dim=1).numpy()
    return pooled
48
 
49
def query_groq(question, model="llama-3.3-70b-versatile"):
    """Query Groq model for a response.

    Args:
        question: The user prompt to send to the chat completion API.
        model: Groq model identifier to use for the completion.

    Returns:
        The model's reply text, or "" if the API call fails (the error is
        surfaced to the UI via st.error).
    """
    try:
        # FIX: the original referenced an undefined global GROQ_API_KEY
        # (NameError on every call); use the module's client factory, which
        # reads the key from the environment.
        client = get_groq_client()
        response = client.chat.completions.create(
            messages=[{"role": "user", "content": question}],
            model=model,
        )
        return response.choices[0].message.content
    except Exception as e:
        st.error(f"Error querying Groq model: {e}")
        return ""
61
 
62
- # File uploader for image input
63
# File uploader for image input
uploaded_file = st.file_uploader("Upload an image (JPG, PNG):", type=["jpg", "jpeg", "png"])

if uploaded_file:
    # Persist the upload to disk so the OCR step can open it by path.
    with tempfile.NamedTemporaryFile(delete=False) as tmp:
        tmp.write(uploaded_file.read())
        temp_image_path = tmp.name

    # Extract text from the image
    st.write("Extracting text from the uploaded image...")
    extracted_text = extract_text_from_image(temp_image_path)
    st.text_area("Extracted Text:", extracted_text, height=200)

    if extracted_text.strip():
        # Chunk and process the text
        st.write("Processing text into chunks...")
        chunk_size = 512
        text_chunks = [
            extracted_text[pos:pos + chunk_size]
            for pos in range(0, len(extracted_text), chunk_size)
        ]

        try:
            embeddings = get_embeddings(text_chunks)
            st.write("Storing data in FAISS database...")
            index.add(np.array(embeddings))
            st.success("Data processed and stored successfully!")
        except Exception as e:
            st.error(f"Error during embedding creation: {e}")

        # Query interface
        user_question = st.text_input("Ask a question based on the uploaded content:")
        if user_question:
            answer = query_groq(user_question)
            st.write("Answer from Groq:")
            st.write(answer)
    else:
        st.warning("No text was extracted from the image. Please try again with a different file.")
 
6
  from transformers import AutoTokenizer, AutoModel
7
  import faiss
8
  import numpy as np
 
9
  from groq import Groq
10
 
11
+ # Configure the application
12
  st.title("RAG-Based Application")
13
+ st.write("Upload an image, and extract and query its content.")
14
 
15
# Groq API setup
def get_groq_client():
    """Create a Groq client authenticated via the GROQ_API_KEY env var."""
    key = os.environ.get("GROQ_API_KEY")
    return Groq(api_key=key)
18
 
19
# Model for embedding generation
st.write("Loading embedding model...")
try:
    tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-mpnet-base-v2")
    model = AutoModel.from_pretrained("sentence-transformers/all-mpnet-base-v2")
except Exception as e:
    st.error(f"Failed to load embedding model: {e}")
    # FIX: the original fell through after reporting the error, so the
    # `model.config` access below raised an unrelated NameError; halt the
    # Streamlit script cleanly instead.
    st.stop()

# Initialize FAISS index
dimension = model.config.hidden_size
 
33
  try:
34
  return pytesseract.image_to_string(Image.open(image_path))
35
  except pytesseract.TesseractNotFoundError:
36
+ st.error("Tesseract is not installed. Install it via the setup script.")
37
  return ""
38
 
39
def get_embeddings(text_chunks):
    """Generate embeddings for text chunks using the model.

    Args:
        text_chunks: List of strings to embed (each truncated by the tokenizer).

    Returns:
        A NumPy array of mean-pooled last-hidden-state embeddings,
        one row per chunk.
    """
    # FIX: this commit removed `import torch` from the module imports, but
    # torch.no_grad() is still required here — import it locally so the
    # function keeps working.
    import torch

    inputs = tokenizer(text_chunks, return_tensors="pt", padding=True, truncation=True)
    with torch.no_grad():
        embeddings = model(**inputs).last_hidden_state.mean(dim=1).numpy()
    return embeddings
45
 
46
def query_groq(question):
    """Query the Groq API to generate answers.

    Args:
        question: The user prompt sent as a single chat message.

    Returns:
        The reply text, or "" if the API call fails (the error is shown
        in the UI via st.error).
    """
    try:
        # FIX: `client` was undefined in this scope after the commit removed
        # the local `Groq(...)` construction, causing a NameError on every
        # call; obtain the client from the module factory instead.
        client = get_groq_client()
        response = client.chat.completions.create(
            messages=[{"role": "user", "content": question}],
            model="llama-3.3-70b-versatile"
        )
        return response.choices[0].message.content
    except Exception as e:
        st.error(f"Error querying Groq API: {e}")
        return ""
57
 
58
# File uploader
uploaded_file = st.file_uploader("Upload an image (JPG, PNG):", type=["jpg", "jpeg", "png"])

if uploaded_file:
    # Persist the upload to disk so pytesseract can open it by path.
    with tempfile.NamedTemporaryFile(delete=False) as temp_file:
        temp_file.write(uploaded_file.read())
        temp_image_path = temp_file.name

    try:
        # Extract text from image
        st.write("Extracting text from the uploaded image...")
        extracted_text = extract_text_from_image(temp_image_path)
    finally:
        # FIX: the file was created with delete=False and never removed,
        # leaking one temp file per upload; clean it up once OCR is done.
        os.remove(temp_image_path)

    st.text_area("Extracted Text:", extracted_text, height=200)

    if extracted_text.strip():
        # Chunk text for embeddings (fixed window of 512 characters)
        text_chunks = [extracted_text[i:i + 512] for i in range(0, len(extracted_text), 512)]

        # Generate embeddings and persist them in the FAISS index
        embeddings = get_embeddings(text_chunks)
        st.write("Storing extracted data in FAISS database...")
        index.add(np.array(embeddings))
        st.success("Text processed and stored successfully!")

        # Question input for Groq
        user_question = st.text_input("Ask a question based on the uploaded image content:")
        if user_question:
            answer = query_groq(user_question)
            st.write("Answer from Groq:")
            st.write(answer)
    else:
        st.warning("No text could be extracted from the image. Try another file.")