vinzur committed on
Commit 57be59d · verified · 1 Parent(s): 69bd558

Create app.py

Files changed (1)
  1. app.py +123 -0
app.py ADDED
@@ -0,0 +1,123 @@
+ import gradio as gr
+ import os
+ import faiss
+ import pickle
+ import numpy as np
+ import torch
+ from PIL import Image
+ from sentence_transformers import SentenceTransformer
+ from transformers import CLIPProcessor, CLIPModel
+ from azure.ai.inference import ChatCompletionsClient
+ from azure.ai.inference.models import SystemMessage, UserMessage
+ from azure.core.credentials import AzureKeyCredential
+
+ # Load models: MiniLM for text embeddings, CLIP for image embeddings
+ text_encoder = SentenceTransformer('all-MiniLM-L6-v2')
+ clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
+ clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
+
+ # Embedding functions
+ def embed_text(text):
+     return text_encoder.encode(text)
+
+ def embed_image(image_path):
+     image = Image.open(image_path).convert("RGB")
+     inputs = clip_processor(images=image, return_tensors="pt")
+     with torch.no_grad():
+         outputs = clip_model.get_image_features(**inputs)
+     return outputs.squeeze().cpu().numpy()
+
+ # Search + prompt: embed the query, retrieve top-k chunks from FAISS, build the LLM prompt
+ def semantic_search_and_prompt(query, top_k=5):
+     if isinstance(query, str) and os.path.exists(query):
+         query_embedding = embed_image(query).astype('float32').reshape(1, -1)
+         index = faiss.read_index("image_vector.index")
+         metadata_path = "image_vector.metadata"
+     else:
+         query_embedding = embed_text(query).astype('float32').reshape(1, -1)
+         index = faiss.read_index("text_vector.index")
+         metadata_path = "text_vector.metadata"
+
+     with open(metadata_path, "rb") as f:
+         metadata = pickle.load(f)
+
+     D, I = index.search(query_embedding, top_k)
+     top_k_chunks = [dict(metadata[i], score=float(D[0][j])) for j, i in enumerate(I[0])]
+
+     context = "\n\n".join([
+         f"[{chunk['type']} from page {chunk['page']} of {chunk['file']}]:\n{chunk.get('content', '')}"
+         for chunk in top_k_chunks
+     ])
+
+     if isinstance(query, str) and not os.path.exists(query):
+         user_query = query
+     else:
+         user_query = "What is shown in this image?"
+
+     prompt = f"""
+ You are an expert assistant helping users answer questions based on a collection of documents.
+ Use the provided context chunks to answer the question accurately and clearly.
+
+ Context: {context}
+
+ Question: {user_query}
+ Answer:"""
+
+     return prompt, top_k_chunks
+
+ # Azure LLM setup (GitHub Models endpoint; token read from the GITHUB_TOKEN env var)
+ endpoint = "https://models.github.ai/inference"
+ model = "deepseek/DeepSeek-V3-0324"
+ token = os.environ["GITHUB_TOKEN"]
+ client = ChatCompletionsClient(endpoint=endpoint, credential=AzureKeyCredential(token))
+
+ # Main pipeline for Gradio: retrieve context, query the LLM, return answer and source references
+ def handle_query(text_input, image_input):
+     if image_input is not None:
+         image_path = "query_image.png"
+         image_input.save(image_path)
+         query = image_path
+     elif text_input:
+         query = text_input
+     else:
+         return "Please provide input", None
+
+     prompt, chunks = semantic_search_and_prompt(query)
+
+     response = client.complete(
+         messages=[
+             SystemMessage("You are a helpful assistant."),
+             UserMessage(prompt),
+         ],
+         temperature=1.0,
+         top_p=1.0,
+         max_tokens=1000,
+         model=model
+     )
+
+     answer = response.choices[0].message.content
+     references = "\n".join([
+         f"{chunk['file']} | Page {chunk['page']} | Type: {chunk['type']} | Score: {chunk['score']:.2f}"
+         for chunk in chunks
+     ])
+
+     return answer, references
+
+ # Gradio UI
+ def launch_app():
+     with gr.Blocks() as demo:
+         gr.Markdown("## 📄 Semantic Search + Chat Interface")
+         with gr.Row():
+             text_input = gr.Textbox(label="Enter your query")
+             image_input = gr.Image(label="Upload an image", type="pil")
+         with gr.Row():
+             btn = gr.Button("Submit")
+         answer_output = gr.Textbox(label="LLM Response", lines=8)
+         reference_output = gr.Textbox(label="Source References", lines=6)
+
+         btn.click(fn=handle_query, inputs=[text_input, image_input], outputs=[answer_output, reference_output])
+
+     demo.launch()
+
+ if __name__ == "__main__":
+     launch_app()
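
Note that this file expects artifacts the commit itself does not create: the FAISS index files (text_vector.index, image_vector.index), the pickled metadata files next to them, and a GITHUB_TOKEN environment variable for the GitHub Models endpoint. As a minimal sketch (not part of this commit), the text-side index and metadata could be built as below, assuming the metadata is simply a pickled list of chunk dicts with the file/page/type/content keys that semantic_search_and_prompt reads; the image-side pair would be produced the same way from CLIP image embeddings.

import pickle
import faiss
import numpy as np
from sentence_transformers import SentenceTransformer

text_encoder = SentenceTransformer('all-MiniLM-L6-v2')

# Hypothetical chunks; in practice these would come from a document parser.
chunks = [
    {"file": "manual.pdf", "page": 1, "type": "text", "content": "Installation steps ..."},
    {"file": "manual.pdf", "page": 2, "type": "text", "content": "Troubleshooting tips ..."},
]

# Embed each chunk with the same encoder app.py uses for text queries.
embeddings = np.array(
    [text_encoder.encode(c["content"]) for c in chunks], dtype="float32"
)

# Flat L2 index; index positions line up with positions in the metadata list,
# which is how semantic_search_and_prompt maps search hits back to chunks.
index = faiss.IndexFlatL2(embeddings.shape[1])
index.add(embeddings)

faiss.write_index(index, "text_vector.index")
with open("text_vector.metadata", "wb") as f:
    pickle.dump(chunks, f)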