| import streamlit as st |
| from transformers import pipeline, AutoTokenizer, AutoModelForQuestionAnswering |
| from PIL import Image |
| import easyocr |
| import os |
| from groq import Groq |
|
|
| |
@st.cache_resource
def _get_ocr_reader(languages=("en",)):
    """Build an EasyOCR reader once per language tuple and reuse it across reruns."""
    return easyocr.Reader(list(languages))


def extract_text_from_image(image, languages=("en",)):
    """Run OCR on *image* and return all recognized text joined by single spaces.

    Args:
        image: The image to read. A PIL image (what ``main`` passes) is
            converted to an RGB numpy array first. Any other value (file path,
            bytes, numpy array) is handed to EasyOCR unchanged.
        languages: EasyOCR language codes to recognize. Defaults to English,
            matching the previously hard-coded ``['en']``.

    Returns:
        The text of every detection, in the order EasyOCR reports them,
        separated by spaces. Returns an empty string when nothing is detected.
    """
    import numpy as np  # easyocr itself depends on numpy; no new requirement

    if isinstance(image, Image.Image):
        # readtext() only special-cases JPEG PIL images. PNG uploads and
        # non-RGB modes (L, P, RGBA) are otherwise rejected or mis-converted.
        image = np.asarray(image.convert("RGB"))

    # Reader construction loads detection/recognition weights, so it is
    # cached instead of being rebuilt on every call.
    reader = _get_ocr_reader(tuple(languages))
    detections = reader.readtext(image)
    # With the default detail level each detection is (bbox, text, confidence).
    return " ".join(detection[1] for detection in detections)
|
|
| |
@st.cache_resource
def load_qa_model():
    """Return the DistilBERT (SQuAD-distilled) question-answering pipeline.

    Because of ``st.cache_resource``, Streamlit builds the pipeline on the
    first call and hands back the same object on later calls and reruns.
    """
    checkpoint = "distilbert/distilbert-base-cased-distilled-squad"
    qa_tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    qa_weights = AutoModelForQuestionAnswering.from_pretrained(checkpoint)
    return pipeline("question-answering", model=qa_weights, tokenizer=qa_tokenizer)
|
|
def answer_question(context, question, qa_model):
    """Ask *qa_model* a *question* about *context* and return just the answer text.

    ``qa_model`` is called once with a single dict carrying ``question`` and
    ``context``. Only the ``'answer'`` entry of its result is returned; any
    other fields (e.g. a score) are discarded.
    """
    payload = {'question': question, 'context': context}
    prediction = qa_model(payload)
    return prediction['answer']
|
|
| |
def groq_chat(prompt, model="llama-3.3-70b-versatile"):
    """Send *prompt* to Groq as a single user message and return the reply text.

    Args:
        prompt: The user's message.
        model: Groq model id to query. Defaults to the model the app has
            always used.

    Returns:
        The assistant's reply text, with an empty string if the API returned
        no content. Failures are returned as human-readable strings starting
        with ``"Error using Groq API:"`` rather than raised, so the UI can
        display the result directly.
    """
    api_key = (os.environ.get("GROQ_API_KEY") or "").strip()
    if not api_key:
        # Detect the most common misconfiguration explicitly instead of
        # relying on whatever the client raises for a missing key.
        return (
            "Error using Groq API: GROQ_API_KEY is not set. "
            "Please set it in the environment."
        )

    try:
        client = Groq(api_key=api_key)
        chat_completion = client.chat.completions.create(
            messages=[{"role": "user", "content": prompt}],
            model=model,
        )
        # content can be None; avoid rendering the literal string "None".
        return chat_completion.choices[0].message.content or ""
    except Exception as e:  # UI boundary: surface any API/network failure as text
        return f"Error using Groq API: {e}"
|
|
| |
def _render_image_qa(uploaded_file):
    """Show the uploaded image, run OCR on request, and answer questions about its text."""
    image = Image.open(uploaded_file)
    # NOTE(review): newer Streamlit releases deprecate ``use_column_width`` in
    # favor of ``use_container_width``; kept as-is to stay compatible with
    # older installs.
    st.image(image, caption="Uploaded Image", use_column_width=True)

    # Drop OCR output that belongs to a previously uploaded image.
    upload_id = (uploaded_file.name, uploaded_file.size)
    if st.session_state.get("ocr_upload_id") != upload_id:
        st.session_state["ocr_upload_id"] = upload_id
        st.session_state.pop("extracted_text", None)

    if st.button("Extract Text and Enable Question Answering"):
        with st.spinner("Extracting text..."):
            st.session_state["extracted_text"] = extract_text_from_image(image)

    # st.button is True only on the rerun right after the click. UI nested
    # under it vanishes as soon as the user types a question or presses
    # "Answer", so the OCR result is persisted and rendered from session_state.
    extracted_text = st.session_state.get("extracted_text")
    if extracted_text is None:
        return

    st.write("Extracted Text:")
    st.write(extracted_text)

    question = st.text_input("Ask a question about the image text:")
    if st.button("Answer"):
        if not question:
            st.warning("Please enter a question.")
        elif not extracted_text.strip():
            # Don't call the QA pipeline with an empty context (it errors).
            st.warning("No text was extracted from the image, so there is nothing to answer from.")
        else:
            with st.spinner("Answering..."):
                # Cached by st.cache_resource, so this loads weights only once.
                qa_model = load_qa_model()
                answer = answer_question(extracted_text, question, qa_model)
            st.write("Answer:", answer)


def _render_groq_chat():
    """Render the free-form chat box whose replies come from ``groq_chat``."""
    st.subheader("General Chat (Powered by Groq)")
    groq_prompt = st.text_input("Enter your message:")
    if st.button("Send"):
        if groq_prompt:
            with st.spinner("Generating response..."):
                groq_response = groq_chat(groq_prompt)
            st.write("Response:", groq_response)
        else:
            st.warning("Please enter a message.")


def main():
    """Render the app: image OCR with question answering, then a general Groq chat."""
    st.title("Image Text & Question Answering Chatbot")

    uploaded_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
    if uploaded_file is not None:
        _render_image_qa(uploaded_file)

    _render_groq_chat()
|
|
# Render the app when this file is executed as the main script.
if __name__ == "__main__":
    main()