Suriyaaan commited on
Commit
ea1e277
·
verified ·
1 Parent(s): fcf0a61

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +188 -0
app.py ADDED
@@ -0,0 +1,188 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# app.py
# Streamlit RAG chatbot: retrieves chunks of a SystemVerilog document with a
# FAISS cosine-similarity index and answers questions via Google Gemini.
import os
import streamlit as st
import numpy as np
import faiss
from sentence_transformers import SentenceTransformer
from typing import List, Tuple
from pathlib import Path
from langchain.schema import SystemMessage, HumanMessage
from dotenv import load_dotenv
load_dotenv()  # pull API keys from a local .env file into the environment

import google.generativeai as genai  # official Google GenAI SDK

# -------------------------
# Configuration
# -------------------------
st.set_page_config(page_title="SystemVerilog Chatbot (Modern LangChain)", layout="wide")

# Set API keys from environment (ensure these are set in your environment / .env)
GENAI_API_KEY = os.getenv("googleapikey")
HF_TOKEN = os.getenv("hf")  # if you need HuggingFace for other tasks

# Fail fast with a visible in-app error when the Gemini key is missing;
# st.stop() halts this script run so nothing below executes without a key.
if not GENAI_API_KEY:
    st.error("Please set environment variable googleapikey")
    st.stop()

genai.configure(api_key=GENAI_API_KEY)
29
+
30
+ # -------------------------
31
+ # Utilities: text load + splitting
32
+ # -------------------------
33
def load_text(path: str) -> str:
    """Read a UTF-8 text file, halting the Streamlit app if it is missing."""
    source = Path(path)
    if source.exists():
        return source.read_text(encoding="utf-8")
    # Surface the problem in the UI and stop this script run.
    st.error(f"File not found: {path}")
    st.stop()
39
+
40
def simple_recursive_split(text: str, chunk_size: int = 1000, overlap: int = 200) -> List[str]:
    """Split *text* into overlapping character chunks, preferring whitespace cuts.

    Args:
        text: Full document text to split.
        chunk_size: Maximum characters per chunk.
        overlap: Characters of shared context between consecutive chunks.

    Returns:
        List of non-empty, stripped chunks covering the whole text.
    """
    chunks: List[str] = []
    start = 0
    text_len = len(text)
    while start < text_len:
        end = min(start + chunk_size, text_len)
        # Try not to cut a word: back up to the last space inside the window.
        if end < text_len:
            last_space = text.rfind(" ", start, end)
            if last_space > start:
                end = last_space
        chunk = text[start:end].strip()
        if chunk:
            chunks.append(chunk)
        # BUG FIX: the original set `start = end - overlap` unconditionally,
        # so once `end` reached text_len, start stayed at text_len - overlap
        # and the loop appended the tail chunk forever. Stop after the final
        # chunk has been emitted.
        if end >= text_len:
            break
        # Step forward keeping `overlap` characters of context, while
        # guaranteeing progress even if overlap >= the chunk just produced.
        start = max(end - overlap, start + 1)
    return chunks
61
+
62
+ # -------------------------
63
+ # Embeddings + FAISS index
64
+ # -------------------------
65
@st.cache_resource
def load_embedding_model(model_name: str = "sentence-transformers/all-MiniLM-L6-v2") -> SentenceTransformer:
    """Load and cache the sentence-embedding model once per server process.

    FIX: `st.experimental_singleton` was deprecated in Streamlit 1.18 and
    removed in 1.29; `st.cache_resource` is its direct replacement for
    unserializable resources like models.
    """
    return SentenceTransformer(model_name)
68
+
69
@st.cache_resource
def build_faiss_index(texts: List[str], _embed_model: SentenceTransformer) -> Tuple[faiss.IndexFlatIP, np.ndarray]:
    """Embed *texts* and build a cosine-similarity FAISS index (cached).

    FIX: replaces the removed `st.experimental_singleton` with
    `st.cache_resource`, and prefixes the model parameter with `_` so
    Streamlit's cache skips hashing it (SentenceTransformer instances are
    unhashable and would raise at call time). The in-file caller passes the
    model positionally, so the rename is safe.

    Returns:
        (index, embeddings) — an IndexFlatIP over L2-normalized vectors
        (inner product on normalized vectors == cosine similarity) and the
        normalized embedding matrix.
    """
    embeddings = _embed_model.encode(texts, show_progress_bar=True, convert_to_numpy=True)
    # Normalize rows; guard zero vectors to avoid division by zero.
    norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
    norms[norms == 0] = 1.0
    embeddings = embeddings / norms
    dim = embeddings.shape[1]
    index = faiss.IndexFlatIP(dim)
    index.add(embeddings.astype("float32"))
    return index, embeddings
81
+
82
def retrieve_top_k(query: str, k: int, index: faiss.IndexFlatIP, embed_model: SentenceTransformer, texts: List[str]) -> List[Tuple[int, float, str]]:
    """Return up to *k* (chunk_index, cosine_score, chunk_text) triples for *query*.

    May return fewer than *k* results when the index holds fewer vectors.
    """
    q_emb = embed_model.encode([query], convert_to_numpy=True)
    # Guard zero-norm (degenerate) query embeddings before normalizing.
    q_norm = np.linalg.norm(q_emb, axis=1, keepdims=True)
    q_norm[q_norm == 0] = 1.0
    q_emb = q_emb / q_norm
    D, I = index.search(q_emb.astype("float32"), k)
    results = []
    for idx, score in zip(I[0], D[0]):
        # BUG FIX: FAISS pads results with -1 when k exceeds the number of
        # indexed vectors; texts[-1] would silently return the LAST chunk.
        if idx < 0:
            continue
        results.append((int(idx), float(score), texts[idx]))
    return results
90
+
91
+ # -------------------------
92
+ # Google GenAI call (wrap it so easy to adapt if SDK changes)
93
+ # -------------------------
94
def call_gemini(prompt: str, model: str = "gemini-2.0-flash", temperature: float = 0.2, max_output_tokens: int = 600) -> str:
    """Generate text from Gemini via the google.generativeai SDK.

    Args:
        prompt: Full plain-text prompt to send.
        model: Gemini model name.
        temperature: Sampling temperature.
        max_output_tokens: Hard cap on generated tokens.

    Returns:
        The generated text, or str(response) as a defensive fallback when the
        response carries no extractable text (e.g. safety-blocked).
    """
    # BUG FIX: the SDK exposes no top-level genai.generate(); the supported
    # entry point is GenerativeModel(...).generate_content(...) with a
    # GenerationConfig for sampling parameters.
    gemini = genai.GenerativeModel(model)
    resp = gemini.generate_content(
        prompt,
        generation_config=genai.types.GenerationConfig(
            temperature=temperature,
            max_output_tokens=max_output_tokens,
        ),
    )
    # Defensive extraction: resp.text raises ValueError when the response was
    # blocked or contains no parts, so fall back progressively.
    try:
        return resp.text
    except Exception:
        try:
            return resp.candidates[0].content.parts[0].text
        except Exception:
            return str(resp)
123
+
124
+ # -------------------------
125
+ # Load document and build index (once)
126
+ # -------------------------
127
+ DOC_PATH = "pdf_extracted_text.txt"
128
+ raw_text = load_text(DOC_PATH)
129
+ chunks = simple_recursive_split(raw_text, chunk_size=1000, overlap=500)
130
+
131
+ embed_model = load_embedding_model()
132
+ index, embeddings = build_faiss_index(chunks, embed_model)
133
+
134
+ # -------------------------
135
+ # Streamlit UI and chat logic
136
+ # -------------------------
137
st.title("🤖 SystemVerilog Documentation Chatbot (Modern LangChain-style)")

# Persist the conversation across Streamlit script reruns.
if "chat_history" not in st.session_state:
    st.session_state.chat_history = []

user_query = st.chat_input("Enter your SystemVerilog question...")

if user_query:
    # Retrieve the top-k most similar chunks to ground the answer.
    k = 6
    retrieved = retrieve_top_k(user_query, k, index, embed_model, chunks)
    context_text = "\n\n".join([f"Chunk {idx} (score {score:.3f}):\n{txt}" for idx, score, txt in retrieved])

    # Compose system prompt using LangChain message classes for clarity (but we send plain text to Gemini)
    system_msg = SystemMessage(content=(
        "You are an expert SystemVerilog Verification Engineer and technical educator. "
        "Answer the user's question using ONLY the provided context. If the context does not contain the answer, say you don't know. "
        "Be concise, include examples when useful, and annotate code blocks as SystemVerilog."
    ))
    # FIX: an unused HumanMessage local was removed here — the plain-text
    # prompt below already carries the user question and context.

    # Build the final prompt to send to Gemini
    prompt = f"""SYSTEM:
{system_msg.content}

CONTEXT:
{context_text}

USER QUESTION:
{user_query}

INSTRUCTIONS:
- Use only the CONTEXT above for factual content.
- Provide SystemVerilog code blocks where appropriate, label them as `systemverilog`.
- If the context lacks the answer, explicitly say: "I don't have enough information in the document to answer that."
"""

    # Call Gemini (via the wrapper)
    gen_text = call_gemini(prompt, model="gemini-2.0-flash", temperature=0.3, max_output_tokens=700)

    # Save to chat
    st.session_state.chat_history.append({"role": "user", "content": user_query})
    st.session_state.chat_history.append({"role": "assistant", "content": gen_text})

# Render the full chat history on every rerun.
for msg in st.session_state.chat_history:
    role = "user" if msg["role"] == "user" else "assistant"
    with st.chat_message(role):
        st.markdown(msg["content"])