RaghuCourage9605 committed on
Commit 80144d7 · verified · 1 Parent(s): c30c12e

Delete 2_Image_QA.py

Files changed (1)
  2_Image_QA.py +0 -160
2_Image_QA.py DELETED
@@ -1,160 +0,0 @@
-import streamlit as st
-from langchain.schema import Document, StrOutputParser
-from langchain.text_splitter import RecursiveCharacterTextSplitter
-from langchain_community.vectorstores import FAISS
-from langchain_community.docstore import InMemoryDocstore
-from langchain_huggingface import HuggingFaceEmbeddings
-from langchain_core.messages import HumanMessage
-from langchain_cerebras import ChatCerebras
-from langchain_mistralai import ChatMistralAI
-from langchain_google_genai import ChatGoogleGenerativeAI
-from langchain.prompts import ChatPromptTemplate
-from uuid import uuid4
-import faiss
-import os
-from dotenv import load_dotenv
-import logging
-import httpx
-import base64
-import asyncio
-
-# Initialize environment variables and logging
-load_dotenv()
-logging.basicConfig(level=logging.INFO)
-logger = logging.getLogger(__name__)
-
-# Run a blocking chain.invoke call off the event loop so the script stays responsive
-async def async_invoke_chain(chain, input_data):
-    loop = asyncio.get_running_loop()
-    return await loop.run_in_executor(None, chain.invoke, input_data)
-
-# Initialize session state for messages and models
-if "messages" not in st.session_state:
-    st.session_state.messages = []
-
-if "models" not in st.session_state:
-    st.session_state.models = {
-        "Gemini": ChatGoogleGenerativeAI(
-            model="gemini-2.0-flash-exp",
-            temperature=0.8,
-            verbose=True,
-            api_key=os.getenv("GOOGLE_AI_STUDIO_API_KEY")
-        ),
-        "Mistral": ChatMistralAI(
-            model_name="open-mistral-nemo",
-            temperature=0.8,
-            verbose=True
-        ),
-        "Llama": ChatCerebras(
-            model="llama-3.3-70b",
-            temperature=0.8,
-            verbose=True,
-            api_key=os.getenv("CEREBRAS_API_KEY")
-        )
-    }
-
-# Initialize the sentence-transformers embedding model (CPU)
-if "embeddings" not in st.session_state:
-    model_name = "sentence-transformers/all-mpnet-base-v2"
-    model_kwargs = {'device': 'cpu'}
-    encode_kwargs = {'normalize_embeddings': False}
-    st.session_state.embeddings = HuggingFaceEmbeddings(
-        model_name=model_name,
-        model_kwargs=model_kwargs,
-        encode_kwargs=encode_kwargs
-    )
-
-st.header("📸📈📊 ֎ Image Content Analysis and Question Answering")
-
-# Brief overview for image content analysis
-description = """
-Upload an image, and the AI will analyze its content and answer your questions.
-It can interpret various types of images, including:
-- General imagery (objects, people, scenes)
-- Diagrams, graphs, and data visualizations
-- Scientific and medical images
-- Text-based images (documents, screenshots)
-"""
-
-# Display the brief description
-st.write(description)
-
-# File upload and URL input
-st.header("Upload Image for Question Answering")
-uploaded_file = st.file_uploader("Upload an image (.jpeg, .jpg, .png, etc.):", type=["jpeg", "jpg", "png"])
-
-st.header("Or Enter the Image URL:")
-image_url = st.text_input("Enter the image URL")
-
-image_data = None
-
-if uploaded_file:
-    # Read the bytes once, then reuse them for display and encoding,
-    # so the upload buffer's file pointer is not exhausted
-    image_bytes = uploaded_file.read()
-    st.image(image_bytes, caption="Uploaded Image", use_column_width=True)
-    image_data = base64.b64encode(image_bytes).decode("utf-8")
-elif image_url:
-    try:
-        with httpx.Client() as client:
-            response = client.get(image_url)
-            response.raise_for_status()
-            st.image(response.content, caption="Image from URL", use_column_width=True)
-            image_data = base64.b64encode(response.content).decode("utf-8")
-    except Exception as e:
-        st.error(f"Error fetching image from URL: {e}")
-
-if image_data:
-    message = HumanMessage(content=[{
-        "type": "text", "text": "Describe what is in the image in detail."
-    }, {
-        "type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{image_data}"}
-    }])
-
-    # Generate a detailed description of the image with the multimodal model
-    response = asyncio.run(async_invoke_chain(st.session_state.models["Gemini"], [message]))
-    knowledge = [Document(page_content=response.content)]
-
-    # Split the description into chunks for indexing (separators must be a list)
-    text_splitter = RecursiveCharacterTextSplitter(separators=["\n\n"], chunk_size=1500, chunk_overlap=200)
-    chunks = text_splitter.split_documents(knowledge)
-
-    # Create a flat L2 FAISS index sized to the embedding dimension
-    index = faiss.IndexFlatL2(len(st.session_state.embeddings.embed_query("hello world")))
-
-    # Create a FAISS vector store for document retrieval
-    vector_store = FAISS(
-        embedding_function=st.session_state.embeddings,
-        index=index,
-        docstore=InMemoryDocstore(),
-        index_to_docstore_id={},
-    )
-
-    # Generate unique IDs and add the chunks to the store;
-    # add_documents also maintains the index-to-docstore-ID mapping
-    ids = [str(uuid4()) for _ in range(len(chunks))]
-    vector_store.add_documents(documents=chunks, ids=ids)
-
-    # Build a retriever over the indexed description chunks
-    image_retriever = vector_store.as_retriever(search_type="similarity", search_kwargs={"k": 6})
-
-    def get_retrieved_context(query):
-        retrieved_documents = image_retriever.invoke(query)
-        return "\n".join(doc.page_content for doc in retrieved_documents)
-
-    # User query for image QA
-    user_input = st.chat_input("Ask a question about the image:")
-
-    prompt = ChatPromptTemplate.from_messages([(
-        "system", "You are an expert in analyzing images. Use the context: {context} to answer the query."
-    ), ("human", "{question}")])
-
-    if user_input:
-        st.session_state.messages.append({"role": "user", "content": user_input})
-        qa_chain = prompt | st.session_state.models["Mistral"] | StrOutputParser()
-        context = get_retrieved_context(user_input)
-        response_message = asyncio.run(async_invoke_chain(qa_chain, {"question": user_input, "context": context}))
-        st.session_state.messages.append({"role": "assistant", "content": response_message})
-        for message in st.session_state.messages:
-            st.chat_message(message["role"]).markdown(message["content"])
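For readers skimming the diff, the deleted page reduces to a four-step pipeline: describe the image with a multimodal model, chunk and index the description, retrieve the relevant chunks, and answer over them. Below is a condensed, non-Streamlit sketch of that pipeline; the model names, prompts, and chunking parameters follow the deleted source, while the `photo.jpg` input and the use of `FAISS.from_documents` in place of the manual index setup are illustrative.

```python
# Condensed sketch of the deleted page's pipeline (assumes the API keys
# named in the source, e.g. GOOGLE_AI_STUDIO_API_KEY, are set in the env).
import base64, os
from langchain_core.messages import HumanMessage
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_mistralai import ChatMistralAI
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.prompts import ChatPromptTemplate
from langchain.schema import Document, StrOutputParser
from langchain.text_splitter import RecursiveCharacterTextSplitter

# Illustrative input file, not part of the deleted source
image_b64 = base64.b64encode(open("photo.jpg", "rb").read()).decode()

# 1. Describe the image with the multimodal model.
vision = ChatGoogleGenerativeAI(model="gemini-2.0-flash-exp",
                                api_key=os.getenv("GOOGLE_AI_STUDIO_API_KEY"))
description = vision.invoke([HumanMessage(content=[
    {"type": "text", "text": "Describe what is in the image in detail."},
    {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{image_b64}"}},
])]).content

# 2. Chunk the description and index it in FAISS.
chunks = RecursiveCharacterTextSplitter(
    separators=["\n\n"], chunk_size=1500, chunk_overlap=200
).split_documents([Document(page_content=description)])
store = FAISS.from_documents(chunks, HuggingFaceEmbeddings(
    model_name="sentence-transformers/all-mpnet-base-v2"))
retriever = store.as_retriever(search_kwargs={"k": 6})

# 3–4. Retrieve context for a question and answer with the text model.
prompt = ChatPromptTemplate.from_messages([
    ("system", "You are an expert in analyzing images. Use the context: {context} to answer the query."),
    ("human", "{question}"),
])
chain = prompt | ChatMistralAI(model_name="open-mistral-nemo") | StrOutputParser()
question = "What objects are visible?"
context = "\n".join(d.page_content for d in retriever.invoke(question))
print(chain.invoke({"question": question, "context": context}))
```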