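"""Multimodal RAG chatbot for FAST-NUCES.

Text queries are embedded with all-MiniLM-L6-v2 and image queries with CLIP,
the nearest document chunks are retrieved from per-modality FAISS indexes,
and the retrieved context is sent to DeepSeek-V3 (via the GitHub Models
Azure AI Inference endpoint). The app is served through a Gradio UI.
"""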
import gradio as gr
import os
import faiss
import pickle
import numpy as np
import torch
from PIL import Image
from sentence_transformers import SentenceTransformer
from transformers import CLIPProcessor, CLIPModel
from azure.ai.inference import ChatCompletionsClient
from azure.ai.inference.models import SystemMessage, UserMessage
from azure.core.credentials import AzureKeyCredential
# Load models
text_encoder = SentenceTransformer('all-MiniLM-L6-v2')
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
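# Note: the two encoders live in different embedding spaces (all-MiniLM-L6-v2
# produces 384-dim vectors; CLIP ViT-B/32 image features are 512-dim), so each
# FAISS index loaded below must have been built with the matching encoder.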
# Embedding functions
def embed_text(text):
    """Embed a text query with the sentence-transformer encoder."""
    return text_encoder.encode(text)

def embed_image(image_path):
    """Embed an image file with the CLIP image encoder."""
    image = Image.open(image_path).convert("RGB")
    inputs = clip_processor(images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = clip_model.get_image_features(**inputs)
    return outputs.squeeze().cpu().numpy()
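# The FAISS indexes and metadata pickles read below are assumed to be built
# offline, roughly like this sketch (hypothetical helper; extraction and
# chunking of the source documents is not shown):
#
#   def build_text_index(chunks, index_path="text_vector.index",
#                        metadata_path="text_vector.metadata"):
#       embeddings = np.stack([embed_text(c["content"]) for c in chunks]).astype("float32")
#       index = faiss.IndexFlatL2(embeddings.shape[1])
#       index.add(embeddings)
#       faiss.write_index(index, index_path)
#       with open(metadata_path, "wb") as f:
#           pickle.dump(chunks, f)  # metadata[i] must align with index row i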
# Search + prompt
def semantic_search_and_prompt(query, top_k=5):
    """Retrieve the top-k chunks for a query and build an LLM prompt from them."""
    # A query string that points at an existing file is treated as an image
    # query; anything else is treated as plain text.
    if isinstance(query, str) and os.path.exists(query):
        query_embedding = embed_image(query).astype('float32').reshape(1, -1)
        index = faiss.read_index("image_vector.index")
        metadata_path = "image_vector.metadata"
    else:
        query_embedding = embed_text(query).astype('float32').reshape(1, -1)
        index = faiss.read_index("text_vector.index")
        metadata_path = "text_vector.metadata"
    with open(metadata_path, "rb") as f:
        metadata = pickle.load(f)
    D, I = index.search(query_embedding, top_k)
    top_k_chunks = [dict(metadata[i], score=float(D[0][j])) for j, i in enumerate(I[0])]
    context = "\n\n".join([
        f"[{chunk['type']} from page {chunk['page']} of {chunk['file']}]:\n{chunk.get('content', '')}"
        for chunk in top_k_chunks
    ])
    if isinstance(query, str) and not os.path.exists(query):
        user_query = query
    else:
        user_query = "What is shown in this image?"
    prompt = f"""
You are an expert assistant helping users answer questions based on a collection of documents.
Use the provided context chunks to answer the question accurately and clearly.
Context: {context}
Question: {user_query}
Answer:"""
    return prompt, top_k_chunks
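# Example (hypothetical call; assumes the index/metadata files exist on disk):
#   prompt, chunks = semantic_search_and_prompt("What degree programs are offered?", top_k=3)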
# Azure LLM setup
endpoint = "https://models.github.ai/inference"
model = "deepseek/DeepSeek-V3-0324"
token = os.getenv("GITHUB_TOKEN")
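# The GitHub Models endpoint authenticates with a personal access token taken
# from the GITHUB_TOKEN environment variable (e.g. `export GITHUB_TOKEN=...`);
# fail early with a clear message, since AzureKeyCredential rejects an empty key.
if not token:
    raise RuntimeError("Set the GITHUB_TOKEN environment variable before starting the app.")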
client = ChatCompletionsClient(endpoint=endpoint, credential=AzureKeyCredential(token))
# Main pipeline for Gradio
def handle_query(text_input, image_input):
    """Route the Gradio inputs through retrieval + generation and format the output."""
    if image_input is not None:
        # Persist the uploaded PIL image so the retrieval path can re-read it.
        image_path = "query_image.png"
        image_input.save(image_path)
        query = image_path
    elif text_input:
        query = text_input
    else:
        return "Please provide a text query or upload an image.", None
    prompt, chunks = semantic_search_and_prompt(query)
    response = client.complete(
        messages=[
            SystemMessage(content="You are a helpful assistant."),
            UserMessage(content=prompt),
        ],
        temperature=1.0,
        top_p=1.0,
        max_tokens=1000,
        model=model,
    )
    answer = response.choices[0].message.content
    references = "\n".join([
        f"- **{chunk['file']}** | Page {chunk['page']} | Type: *{chunk['type']}* | Score: `{chunk['score']:.2f}`"
        for chunk in chunks
    ])
    return answer, references
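# Example (hypothetical direct call, bypassing the Gradio UI):
#   answer, refs = handle_query("When was the university founded?", None)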
# Gradio UI
def launch_app():
    with gr.Blocks() as demo:
        gr.Markdown("## 📄🎓 Multimodal Chatbot for FAST-NUCES")
        with gr.Row():
            text_input = gr.Textbox(label="Enter your query")
            image_input = gr.Image(label="Upload an image", type="pil")
        with gr.Row():
            btn = gr.Button("Submit")
            btn_clear = gr.Button("Clear")
        gr.Markdown("### 🧠 LLM Response:")
        answer_output = gr.Markdown()
        with gr.Accordion("📚 Source References", open=False):
            reference_output = gr.Markdown()
        # Submit button
        btn.click(fn=handle_query, inputs=[text_input, image_input],
                  outputs=[answer_output, reference_output])
        # Clear button
        btn_clear.click(fn=lambda: ("", None, "", ""), inputs=[],
                        outputs=[text_input, image_input, answer_output, reference_output])
    demo.launch()
if __name__ == "__main__":
    launch_app()