Deevyankar committed on
Commit
8ba2439
·
1 Parent(s): 3c9300b

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +187 -0
app.py ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import subprocess
3
+ from typing import List, Dict, Any
4
+
5
+ import gradio as gr
6
+ import chromadb
7
+
8
+ from llama_index.core import VectorStoreIndex, StorageContext
9
+ from llama_index.vector_stores.chroma import ChromaVectorStore
10
+ from llama_index.embeddings.huggingface import HuggingFaceEmbedding
11
+ from llama_index.llms.openai import OpenAI
12
+
13
# Name of the Chroma collection that the ingestion step (ingest.py) writes
# into and that load_index() reads back at query time.
COLLECTION_NAME = "neuro_course"
# Process-wide cache for the lazily built VectorStoreIndex (see get_index()).
INDEX = None
15
+
16
+
17
def get_persist_dir():
    """Return the directory where the Chroma vector store is persisted.

    Prefers the /data persistent volume when it is mounted (as on a
    Hugging Face Space); otherwise falls back to a local storage path.
    """
    if os.path.exists("/data"):
        return "/data/chroma"
    return "storage/chroma"
19
+
20
+
21
def processed_text_exists(chapter_dir="processed/chapters"):
    """Return True if extracted chapter text is already available.

    Args:
        chapter_dir: Directory to inspect for extracted chapters. Defaults
            to the pipeline's standard output location, so existing callers
            are unaffected.

    Returns:
        bool: True when the directory exists and contains at least one
        ``.txt`` file.
    """
    # isdir (not exists) so that a stray regular file at this path reports
    # "no processed text" instead of os.listdir raising NotADirectoryError.
    return os.path.isdir(chapter_dir) and any(
        name.endswith(".txt") for name in os.listdir(chapter_dir)
    )
26
+
27
+
28
def vector_db_exists(persist_dir=None):
    """Return True if a persisted Chroma vector store is present.

    Args:
        persist_dir: Directory to check. When None (the default), the app's
            standard persist location from get_persist_dir() is used,
            preserving the original no-argument behavior.

    Returns:
        bool: True when the directory exists and is non-empty.
    """
    if persist_dir is None:
        persist_dir = get_persist_dir()
    # isdir guards against a stray regular file at the persist path, which
    # would make os.listdir raise instead of reporting "no DB yet".
    return os.path.isdir(persist_dir) and len(os.listdir(persist_dir)) > 0
31
+
32
+
33
def run_extract_if_needed():
    """Run the PDF-to-chapter extraction script unless its output exists.

    Idempotent: skips the (slow) extraction subprocess when processed
    chapter text is already on disk.
    """
    if processed_text_exists():
        print("Processed chapter text already exists. Skipping extraction.")
        return
    print("No processed chapter text found. Running extraction...")
    subprocess.check_call(["python", "extract_all_pdfs_chapterwise.py"])
39
+
40
+
41
def run_ingest_if_needed():
    """Build the Chroma vector DB via ingest.py unless one already exists.

    Idempotent: skips the ingestion subprocess when a persisted vector
    store is already on disk.
    """
    if vector_db_exists():
        print("Vector DB already exists. Skipping ingestion.")
        return
    print("No vector DB found. Running ingestion...")
    subprocess.check_call(["python", "ingest.py"])
47
+
48
+
49
def ensure_everything_ready():
    """Run the one-time data pipeline: PDF extraction, then vector ingestion.

    Both steps check for their own output and skip themselves when the work
    is already done, so calling this repeatedly is cheap after first run.
    """
    run_extract_if_needed()
    run_ingest_if_needed()
52
+
53
+
54
def load_index():
    """Open the persisted Chroma collection and wrap it in a LlamaIndex index.

    Returns:
        A VectorStoreIndex backed by the on-disk Chroma collection, using
        the multilingual E5 embedding model for query-time embeddings.
    """
    db_client = chromadb.PersistentClient(path=get_persist_dir())
    chroma_collection = db_client.get_or_create_collection(COLLECTION_NAME)

    store = ChromaVectorStore(chroma_collection=chroma_collection)
    context = StorageContext.from_defaults(vector_store=store)

    # NOTE(review): the query-time embedding model must match the one used
    # during ingestion — presumably ingest.py uses this same model; confirm.
    embedder = HuggingFaceEmbedding(
        model_name="intfloat/multilingual-e5-base"
    )

    return VectorStoreIndex.from_vector_store(
        vector_store=store,
        storage_context=context,
        embed_model=embedder,
    )
72
+
73
+
74
def get_index():
    """Return the process-wide index, building the data pipeline on first use.

    Lazily runs extraction/ingestion and loads the index exactly once,
    caching the result in the module-level INDEX global.
    """
    global INDEX
    if INDEX is not None:
        return INDEX
    ensure_everything_ready()
    INDEX = load_index()
    return INDEX
80
+
81
+
82
def format_sources(response, max_sources=3):
    """Render up to *max_sources* retrieved chunks as a markdown footer.

    Args:
        response: Query response object; only its ``source_nodes`` attribute
            (if any) is read.
        max_sources: Maximum number of source entries to render.

    Returns:
        "" when the response carries no source nodes; otherwise a
        "### Sources" section with one numbered file-name + snippet entry
        per node.
    """
    nodes = getattr(response, "source_nodes", None)
    if not nodes:
        return ""
    parts = ["\n\n---\n### Sources\n"]
    for idx, source in enumerate(nodes[:max_sources], start=1):
        metadata = source.node.metadata or {}
        name = metadata.get("file_name", "unknown_file")
        # Flatten newlines so the snippet stays inside one blockquote line.
        excerpt = source.node.get_text()[:250].replace("\n", " ")
        parts.append(f"\n**{idx}. {name}**\n> {excerpt}...\n")
    return "".join(parts)
92
+
93
+
94
def respond(
    message: str,
    history: List[Dict[str, Any]],
    model_name: str,
    temperature: float,
    top_k: int,
    show_sources: bool,
):
    """Gradio chat handler: answer *message* from the course index.

    Builds a new history list (never mutating the incoming one) with the
    user turn and the assistant turn appended, and returns
    ``(new_history, "")`` so the textbox is cleared. Failures anywhere in
    the RAG pipeline are reported as a chat message, never raised.

    Args:
        message: Raw textbox contents; blank input is a no-op.
        history: Chat history as Gradio "messages" dicts (may be None).
        model_name: OpenAI model identifier chosen in the UI.
        temperature: Sampling temperature for the LLM.
        top_k: Number of chunks to retrieve (also caps shown sources at 3).
        show_sources: Whether to append a markdown sources footer.
    """
    if history is None:
        history = []

    text = (message or "").strip()
    if not text:
        return history, ""

    # Fail fast with an in-chat notice when the API key is not configured.
    if not os.getenv("OPENAI_API_KEY"):
        warning = {
            "role": "assistant",
            "content": "OPENAI_API_KEY missing. Add it in Hugging Face Space secrets."
        }
        return history + [warning], ""

    convo = history + [{"role": "user", "content": text}]

    try:
        engine = get_index().as_query_engine(
            llm=OpenAI(model=model_name, temperature=float(temperature)),
            similarity_top_k=int(top_k),
            response_mode="compact"
        )

        prompt = (
            "You are an interactive neurology tutor. "
            "Answer only from the retrieved course material. "
            "If the answer is not found, say: 'Not found in the course material.' "
            "Keep answers concise unless the user asks for detail.\n\n"
            f"Question: {text}"
        )

        result = engine.query(prompt)
        answer = str(result)

        if show_sources:
            answer += format_sources(result, max_sources=min(int(top_k), 3))

    except Exception as exc:  # surface any pipeline failure in the chat
        answer = f"Error: {str(exc)}"

    return convo + [{"role": "assistant", "content": answer}], ""
146
+
147
+
148
def clear_chat():
    """Reset the chatbot: an empty message list clears the Gradio history."""
    return list()
150
+
151
+
152
# --- Gradio UI -----------------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown("# 🧠 Neurology Tutor")
    gr.Markdown("Automatic pipeline: PDF extraction → chapter text → vector DB → chatbot")

    # type="messages" means history is a list of {"role", "content"} dicts,
    # matching what respond() builds and returns.
    chatbot = gr.Chatbot(height=500, type="messages")
    msg = gr.Textbox(placeholder="Ask a question...", lines=1)

    # Model/temperature controls are fed straight into respond().
    with gr.Row():
        model_name = gr.Dropdown(
            ["gpt-4o-mini", "gpt-4.1-mini"],
            value="gpt-4o-mini",
            label="Model"
        )
        temperature = gr.Slider(0.0, 0.8, value=0.2, step=0.1, label="Temperature")

    with gr.Row():
        top_k = gr.Slider(1, 5, value=3, step=1, label="Top-K Chunks")
        show_sources = gr.Checkbox(value=False, label="Show Sources")

    clear_btn = gr.Button("Clear Chat")

    # Enter in the textbox sends the question; respond() returns the
    # updated history plus "" so the textbox is cleared.
    msg.submit(
        respond,
        inputs=[msg, chatbot, model_name, temperature, top_k, show_sources],
        outputs=[chatbot, msg]
    )

    clear_btn.click(
        clear_chat,
        inputs=[],
        outputs=[chatbot]
    )
184
+
185
+
186
if __name__ == "__main__":
    # Extraction/ingestion run lazily on the first question (see
    # get_index()), so launch itself is fast.
    demo.launch()