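"""Multimodal RAG chatbot for FAST-NUCES.

Embeds text queries with SentenceTransformers and image queries with CLIP,
retrieves the nearest document chunks from prebuilt FAISS indexes, and asks
DeepSeek-V3 (via the GitHub Models endpoint) to answer from that context
through a Gradio UI.

Usage: build the *.index / *.metadata files, export GITHUB_TOKEN, then run
this script and open the local URL that Gradio prints.
"""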
import os
import pickle

import faiss
import gradio as gr
import numpy as np
import torch
from azure.ai.inference import ChatCompletionsClient
from azure.ai.inference.models import SystemMessage, UserMessage
from azure.core.credentials import AzureKeyCredential
from PIL import Image
from sentence_transformers import SentenceTransformer
from transformers import CLIPModel, CLIPProcessor

# Encoders: MiniLM for text queries/chunks, CLIP ViT-B/32 for image queries.
text_encoder = SentenceTransformer("all-MiniLM-L6-v2")
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

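# Note: the first run downloads the model weights from the Hugging Face Hub
# (roughly 90 MB for MiniLM and 600 MB for CLIP ViT-B/32) and caches them.
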
def embed_text(text):
    """Encode a text query into a MiniLM embedding."""
    return text_encoder.encode(text)


def embed_image(image_path):
    """Encode an image file into a CLIP image embedding."""
    image = Image.open(image_path).convert("RGB")
    inputs = clip_processor(images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = clip_model.get_image_features(**inputs)
    return outputs.squeeze().cpu().numpy()

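# The two embedding spaces are not interchangeable: all-MiniLM-L6-v2 emits
# 384-dim vectors while CLIP ViT-B/32 image features are 512-dim, which is
# why text and image queries are routed to separate FAISS indexes below.
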
def semantic_search_and_prompt(query, top_k=5):
    """Retrieve the top-k chunks for a text or image query and build an LLM prompt."""
    # A query string that points at an existing file is treated as an image
    # query; anything else is treated as plain text.
    is_image_query = isinstance(query, str) and os.path.exists(query)

    if is_image_query:
        query_embedding = embed_image(query).astype("float32").reshape(1, -1)
        index = faiss.read_index("image_vector.index")
        metadata_path = "image_vector.metadata"
    else:
        query_embedding = embed_text(query).astype("float32").reshape(1, -1)
        index = faiss.read_index("text_vector.index")
        metadata_path = "text_vector.metadata"

    with open(metadata_path, "rb") as f:
        metadata = pickle.load(f)

    # D holds distances, I the row indices of the nearest neighbours.
    D, I = index.search(query_embedding, top_k)
    top_k_chunks = [dict(metadata[i], score=float(D[0][j])) for j, i in enumerate(I[0])]

    context = "\n\n".join(
        f"[{chunk['type']} from page {chunk['page']} of {chunk['file']}]:\n{chunk.get('content', '')}"
        for chunk in top_k_chunks
    )

    user_query = "What is shown in this image?" if is_image_query else query

    prompt = f"""
You are an expert assistant helping users answer questions based on a collection of documents.
Use the provided context chunks to answer the question accurately and clearly.

Context: {context}

Question: {user_query}
Answer:"""

    return prompt, top_k_chunks

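# --- Illustrative sketch only: this script assumes the "*_vector.index" and
# --- "*_vector.metadata" files above were built by a separate offline
# --- ingestion step. One plausible shape for that step is sketched below;
# --- `chunks` (dicts with "file", "page", "type", and "content" keys) and the
# --- choice of IndexFlatL2 are assumptions, not part of the original script.
def build_index_sketch(chunks, embed_fn, index_path, metadata_path):
    vectors = np.array([embed_fn(c["content"]) for c in chunks], dtype="float32")
    index = faiss.IndexFlatL2(vectors.shape[1])  # exact L2 search, no training step
    index.add(vectors)
    faiss.write_index(index, index_path)
    with open(metadata_path, "wb") as f:
        pickle.dump(chunks, f)
# e.g. build_index_sketch(text_chunks, embed_text, "text_vector.index", "text_vector.metadata")
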
# GitHub Models endpoint, accessed through the Azure AI Inference SDK;
# authentication uses a GitHub personal access token from the environment.
endpoint = "https://models.github.ai/inference"
model = "deepseek/DeepSeek-V3-0324"
token = os.getenv("GITHUB_TOKEN")
if not token:
    raise RuntimeError("GITHUB_TOKEN environment variable is not set.")
client = ChatCompletionsClient(endpoint=endpoint, credential=AzureKeyCredential(token))

def handle_query(text_input, image_input):
    """Gradio callback: answer a text or image query using retrieved context."""
    if image_input is not None:
        # Persist the uploaded PIL image so the retrieval step can re-read it.
        # Note: a fixed filename is not safe for concurrent users.
        image_path = "query_image.png"
        image_input.save(image_path)
        query = image_path
    elif text_input:
        query = text_input
    else:
        return "Please provide a text query or an image.", None

    prompt, chunks = semantic_search_and_prompt(query)

    response = client.complete(
        messages=[
            SystemMessage(content="You are a helpful assistant."),
            UserMessage(content=prompt),
        ],
        temperature=1.0,
        top_p=1.0,
        max_tokens=1000,
        model=model,
    )

    answer = response.choices[0].message.content
    references = "\n".join(
        f"- **{chunk['file']}** | Page {chunk['page']} | Type: *{chunk['type']}* | Score: `{chunk['score']:.2f}`"
        for chunk in chunks
    )

    return answer, references

def launch_app():
    with gr.Blocks() as demo:
        gr.Markdown("## 📄🎓 Multimodal Chatbot for FAST-NUCES")

        with gr.Row():
            text_input = gr.Textbox(label="Enter your query")
            image_input = gr.Image(label="Upload an image", type="pil")

        with gr.Row():
            btn = gr.Button("Submit")
            btn_clear = gr.Button("Clear")

        gr.Markdown("### 🧠 LLM Response:")
        answer_output = gr.Markdown()

        with gr.Accordion("📚 Source References", open=False):
            reference_output = gr.Markdown()

        btn.click(
            fn=handle_query,
            inputs=[text_input, image_input],
            outputs=[answer_output, reference_output],
        )
        btn_clear.click(
            fn=lambda: ("", None, "", ""),
            inputs=[],
            outputs=[text_input, image_input, answer_output, reference_output],
        )

    demo.launch()


if __name__ == "__main__":
    launch_app()