Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import os | |
| import numpy as np | |
| from sentence_transformers import SentenceTransformer | |
| import faiss | |
| from openai import OpenAI | |
| from PIL import Image | |
class IntegratedChatSystem:
    """Chat assistant that enriches user questions with indexed image context.

    Each image is registered together with a free-text description. The
    description is embedded with a sentence-transformer model and stored in a
    FAISS inner-product index. At chat time the user message is embedded the
    same way, the closest descriptions are looked up, and those scoring at or
    above a similarity threshold are appended to the prompt sent to the
    OpenAI chat-completion endpoint.
    """

    def __init__(self, api_key: str, model_name: str, embedding_dim: int = 384):
        """Create the embedding model, the vector index and the OpenAI client.

        Args:
            api_key: OpenAI API key passed to the client.
            model_name: Name of the chat-completion model to query.
            embedding_dim: Width of the vectors stored in the index. Vectors
                of any other width are rejected when they are added.
        """
        self.api_key = api_key
        self.model_name = model_name
        self.embedding_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
        self.embedding_dim = embedding_dim
        self.index = faiss.IndexFlatIP(embedding_dim)
        # metadata[i] describes the i-th vector added to the index.
        self.metadata = []
        self.client = OpenAI(api_key=api_key)

    def _embed(self, text: str) -> np.ndarray:
        """Return a unit-length float32 embedding of *text* with shape (1, dim).

        The index scores by inner product, which equals cosine similarity
        only for unit-length vectors, so the vector is L2-normalised here
        rather than relying on the embedding model to do it.
        """
        vector = np.asarray(self.embedding_model.encode(text), dtype=np.float32)
        vector = vector.reshape(1, -1)
        norm = float(np.linalg.norm(vector))
        if norm > 0.0:
            vector = vector / norm
        return np.ascontiguousarray(vector, dtype=np.float32)

    def _add_to_index(self, vector: np.ndarray, metadata: dict):
        """Append one vector to the index and record its metadata.

        Args:
            vector: A single embedding, either 1-D or shaped (1, embedding_dim).
            metadata: Arbitrary description stored at the same position as
                the vector.

        Raises:
            ValueError: If *vector* is not exactly one row of width
                ``embedding_dim``. Adding several rows at once would leave
                ``self.metadata`` out of step with the index.
        """
        vector = np.ascontiguousarray(vector, dtype=np.float32)
        if vector.ndim == 1:
            vector = vector.reshape(1, -1)
        if vector.shape != (1, self.embedding_dim):
            raise ValueError(
                f"Expected one vector of dimension {self.embedding_dim}, "
                f"got array of shape {vector.shape}"
            )
        self.index.add(vector)
        self.metadata.append(metadata)

    def add_image(self, image_path: str, context_text: str):
        """Register an image under the embedding of its description.

        Only the file's base name is stored, so callers must know the
        directory in order to load the image again later.

        Args:
            image_path: Path to an existing image file.
            context_text: Description that is embedded and later matched
                against user messages.

        Raises:
            FileNotFoundError: If *image_path* does not exist.
        """
        if not os.path.exists(image_path):
            raise FileNotFoundError(f"Image not found: {image_path}")
        filename = os.path.basename(image_path)
        self._add_to_index(
            self._embed(context_text),
            {"filepath": filename, "context": context_text},
        )

    def chat(self, user_message: str, similarity_threshold: float = 0.7, top_k: int = 3):
        """Answer *user_message*, adding context from matching images.

        Args:
            user_message: The question typed by the user.
            similarity_threshold: Minimum index score an image description
                must reach to be included.
            top_k: Maximum number of image descriptions to retrieve.

        Returns:
            A dict with two keys: ``"response"`` holds the model's reply (or
            an apology if the API call failed) and ``"images"`` holds the
            list of matching metadata dicts, or ``None`` when nothing matched
            or the API call failed.
        """
        relevant_images = []
        # Nothing to retrieve from an empty index or with a non-positive k.
        if self.metadata and top_k > 0:
            scores, positions = self.index.search(self._embed(user_message), top_k)
            # The index pads with -1 when it holds fewer than top_k vectors.
            relevant_images = [
                self.metadata[pos]
                for pos, score in zip(positions[0], scores[0])
                if 0 <= pos < len(self.metadata) and score >= similarity_threshold
            ]

        system_prompt = """You are an assistant chatbot. You should help the user by answering their question."""
        enhanced_message = user_message
        if relevant_images:
            image_contexts = "\n".join(f"- {img['context']}" for img in relevant_images)
            enhanced_message = f"{user_message}\n\nContext from relevant images:\n{image_contexts}"

        try:
            completion = self.client.chat.completions.create(
                model=self.model_name,
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": enhanced_message},
                ],
            )
            response = completion.choices[0].message.content
        except Exception as e:
            # Deliberately broad: this is the boundary to the UI, which should
            # show a friendly message instead of a stack trace.
            print(f"Error calling OpenAI API: {str(e)}")
            return {
                "response": "I apologize, but I encountered an error processing your request.",
                "images": None,
            }

        return {
            "response": response,
            "images": relevant_images if relevant_images else None,
        }
# Initialize the chat system
# The key is read from the environment so the secret never has to be written
# into this file; an unset variable falls back to an empty key.
api_key = os.environ.get("OPENAI_API_KEY", "")
model_name = "ft:gpt-3.5-turbo-0125:brenin::AlVMkeUb"
image_folder = "images"

# (file name, description) pairs registered with the chat system at start-up.
_IMAGE_CONTEXTS = (
    ("sequence diagram.png", "A diagram showing the sequence of how it is supposed to work. What is the sequence?"),
    ("UX workflow.png", "A flowchart of showing the UX workflow.What is the UX workflow"),
    ("UI.png", "A diagram the UI. What is the UI? "),
    ("workflow.png", "A flowchart of showing the workflow. What is the workflow?"),
)


@st.cache_resource
def _load_chat_system(api_key: str, model_name: str, image_folder: str):
    """Build the chat system and index the bundled images once per process.

    Streamlit re-executes this script on every interaction; caching keeps the
    embedding model and the populated index alive between reruns instead of
    rebuilding them each time.
    """
    system = IntegratedChatSystem(api_key, model_name)
    # Add images
    for file_name, description in _IMAGE_CONTEXTS:
        system.add_image(os.path.join(image_folder, file_name), description)
    return system


chat_system = _load_chat_system(api_key, model_name, image_folder)

# Streamlit UI
st.title("Chat with Integrated Image Context")
st.sidebar.title("Chat System")
user_message = st.text_input("Your message:", placeholder="Type your message here...")
if st.button("Send"):
    if user_message.strip():
        result = chat_system.chat(user_message)
        st.write(f"**Assistant:** {result['response']}")
        if result["images"]:
            st.write("Relevant Images:")
            for img in result["images"]:
                # Only the base name is stored in the metadata, so the folder
                # has to be prepended again here.
                image_path = os.path.join(image_folder, img["filepath"])
                if os.path.exists(image_path):
                    # Close the file handle once Streamlit has rendered it.
                    with Image.open(image_path) as picture:
                        st.image(picture, caption=img["context"])
                else:
                    st.write(f"Image not found: {img['filepath']}")
    else:
        st.error("Please enter a message.")