fragger246 committed on
Commit
2cb9c2c
·
verified ·
1 Parent(s): 4a73579

Upload 6 files

Browse files
dataset_processing.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Tokenize a prompt/response JSON dataset for causal-LM fine-tuning.

Reads tax_train_data.json (a list of {"prompt", "response"} records),
tokenizes prompts as inputs and responses as labels, and saves the
result to disk in the datasets Arrow format.
"""
from datasets import load_dataset
from transformers import AutoTokenizer

# Hub model id. The original "/falcon-7b" was a broken filesystem path —
# the Falcon 7B checkpoint lives under the "tiiuae" organization.
MODEL_NAME = "tiiuae/falcon-7b"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Falcon's tokenizer ships without a pad token; padding="max_length"
# below would raise without one.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Load dataset of {"prompt": ..., "response": ...} records.
dataset = load_dataset("json", data_files="tax_train_data.json")


def preprocess_function(examples):
    """Tokenize a batch: prompts become inputs, responses become labels.

    Pad positions in the labels are masked with -100 so the loss
    function ignores them (the original trained on pad tokens).
    """
    inputs = examples["prompt"]    # prompt text
    targets = examples["response"] # response text

    model_inputs = tokenizer(inputs, padding="max_length", truncation=True, max_length=512)
    labels = tokenizer(targets, padding="max_length", truncation=True, max_length=512)

    pad_id = tokenizer.pad_token_id
    model_inputs["labels"] = [
        [(tok if tok != pad_id else -100) for tok in seq]
        for seq in labels["input_ids"]
    ]
    return model_inputs


# Apply preprocessing to the whole dataset in batches.
processed_dataset = dataset.map(preprocess_function, batched=True)

# save_to_disk writes an Arrow *directory*, not a JSON file — the original
# ".json" name was misleading. Load it back with datasets.load_from_disk.
processed_dataset.save_to_disk("processed_dataset")
fine_tuned_tax ADDED
File without changes
finetune_tinyllama.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer, DataCollatorForSeq2Seq
3
+ from datasets import load_dataset
4
+
5
+ # Model name
6
+ MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
7
+
8
+ # Load tokenizer and model
9
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
10
+ model = AutoModelForCausalLM.from_pretrained(
11
+ MODEL_NAME,
12
+ torch_dtype=torch.float16, # Use float16 for better efficiency
13
+ device_map="auto" # Use GPU if available
14
+ )
15
+
16
+ # Load dataset from JSON file
17
+ dataset = load_dataset("json", data_files="processed_dataset.json")
18
+
19
+ # Tokenization function
20
+ def tokenize_function(examples):
21
+ return tokenizer(examples["prompt"], examples["response"], padding="max_length", truncation=True)
22
+
23
+ # Apply tokenization
24
+ dataset = dataset.map(tokenize_function, batched=True)
25
+ dataset = dataset.remove_columns(["prompt", "response"]) # Keep only tokenized data
26
+
27
+ # Data collator (for batching and padding)
28
+ data_collator = DataCollatorForSeq2Seq(
29
+ tokenizer=tokenizer,
30
+ model=model,
31
+ padding=True,
32
+ return_tensors="pt"
33
+ )
34
+
35
+ # Training arguments
36
+ training_args = TrainingArguments(
37
+ output_dir="./results",
38
+ num_train_epochs=3,
39
+ per_device_train_batch_size=4,
40
+ per_device_eval_batch_size=4,
41
+ save_steps=10_000,
42
+ save_total_limit=2,
43
+ logging_dir="./logs",
44
+ logging_steps=200,
45
+ remove_unused_columns=False, # Ensure tokenized data isn't removed
46
+ fp16=True, # Enable mixed precision if using GPU
47
+ )
48
+
49
+ # Trainer setup
50
+ trainer = Trainer(
51
+ model=model,
52
+ args=training_args,
53
+ train_dataset=dataset["train"],
54
+ data_collator=data_collator,
55
+ )
56
+
57
+ # Start training
58
+ trainer.train()
processed_dataset.json ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [
2
+ {
3
+ "prompt": "Calculate tax for ₹10,00,000 at 30% rate.",
4
+ "response": "The tax is ₹3,00,000."
5
+ },
6
+ {
7
+ "prompt": "Explain Section 80C of the Income Tax Act.",
8
+ "response": "Section 80C allows deductions up to ₹1,50,000 on PPF, EPF, and life insurance."
9
+ },
10
+ {
11
+ "prompt": "What is the tax on ₹8,50,000 with a 20% slab?",
12
+ "response": "The tax is ₹1,70,000."
13
+ },
14
+ {
15
+ "prompt": "How does the new tax regime differ from the old tax regime?",
16
+ "response": "The new tax regime has lower tax rates but fewer deductions, while the old regime allows more exemptions."
17
+ },
18
+ {
19
+ "prompt": "What is the exemption limit under the new tax regime for FY 2023-24?",
20
+ "response": "The exemption limit is ₹3,00,000 under the new tax regime."
21
+ },
22
+ {
23
+ "prompt": "Is HRA exempt from income tax?",
24
+ "response": "Yes, House Rent Allowance (HRA) is exempt under Section 10(13A) based on salary, rent paid, and location."
25
+ },
26
+ {
27
+ "prompt": "How to save tax under Section 80D?",
28
+ "response": "Section 80D allows deductions on health insurance premiums up to ₹25,000 (₹50,000 for senior citizens)."
29
+ },
30
+ {
31
+ "prompt": "What is the capital gains tax on the sale of property?",
32
+ "response": "Long-term capital gains (LTCG) on property are taxed at 20% with indexation, while short-term gains are taxed as per the income slab."
33
+ },
34
+ {
35
+ "prompt": "Can I claim deductions on home loan interest?",
36
+ "response": "Yes, under Section 24(b), you can claim up to ₹2,00,000 per year on home loan interest."
37
+ },
38
+ {
39
+ "prompt": "What is the GST rate on restaurant bills?",
40
+ "response": "The GST rate on restaurant bills is 5% for non-AC restaurants and 18% for AC restaurants."
41
+ },
42
+ {
43
+ "prompt": "What is TDS and when is it deducted?",
44
+ "response": "Tax Deducted at Source (TDS) is deducted by the payer on salaries, rent, and interest payments as per prescribed rates."
45
+ },
46
+ {
47
+ "prompt": "How can NRIs save tax in India?",
48
+ "response": "NRIs can save tax through DTAA benefits, NRE accounts, and exemptions on certain investments."
49
+ },
50
+ {
51
+ "prompt": "What is the corporate tax rate in India?",
52
+ "response": "The corporate tax rate is 22% for domestic companies under the new regime and 30% under the old regime."
53
+ },
54
+ {
55
+ "prompt": "Are agricultural incomes taxable?",
56
+ "response": "No, agricultural income is exempt from tax under Section 10(1)."
57
+ },
58
+ {
59
+ "prompt": "What are the penalties for late ITR filing?",
60
+ "response": "A late fee of ₹5,000 applies if filed after the due date, and ₹10,000 for income above ₹5 lakh."
61
+ },
62
+ {
63
+ "prompt": "Explain Section 80G of the Income Tax Act.",
64
+ "response": "Section 80G allows deductions on donations made to eligible charities, ranging from 50% to 100% of the donation."
65
+ },
66
+ {
67
+ "prompt": "What is Advance Tax, and who needs to pay it?",
68
+ "response": "Advance Tax is payable if total tax liability exceeds ₹10,000 in a financial year and is paid in installments."
69
+ },
70
+ {
71
+ "prompt": "What is the basic exemption limit for senior citizens?",
72
+ "response": "The exemption limit for senior citizens (60-80 years) is ₹3,00,000 and ₹5,00,000 for super senior citizens."
73
+ },
74
+ {
75
+ "prompt": "How does tax loss harvesting work?",
76
+ "response": "Tax loss harvesting helps offset capital gains by selling loss-making stocks to reduce taxable income."
77
+ },
78
+ {
79
+ "prompt": "What is the standard deduction for salaried employees?",
80
+ "response": "A standard deduction of ₹50,000 is available for salaried and pensioned individuals."
81
+ }
82
+ ]
tax_train_data.json ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [
2
+ {
3
+ "prompt": "Calculate tax for ₹10,00,000 at 30% rate.",
4
+ "response": "The tax is ₹3,00,000."
5
+ },
6
+ {
7
+ "prompt": "Explain Section 80C of the Income Tax Act.",
8
+ "response": "Section 80C allows deductions up to ₹1,50,000 on PPF, EPF, and life insurance."
9
+ },
10
+ {
11
+ "prompt": "What is the tax on ₹8,50,000 with a 20% slab?",
12
+ "response": "The tax is ₹1,70,000."
13
+ },
14
+ {
15
+ "prompt": "How does the new tax regime differ from the old tax regime?",
16
+ "response": "The new tax regime has lower tax rates but fewer deductions, while the old regime allows more exemptions."
17
+ },
18
+ {
19
+ "prompt": "What is the exemption limit under the new tax regime for FY 2023-24?",
20
+ "response": "The exemption limit is ₹3,00,000 under the new tax regime."
21
+ },
22
+ {
23
+ "prompt": "Is HRA exempt from income tax?",
24
+ "response": "Yes, House Rent Allowance (HRA) is exempt under Section 10(13A) based on salary, rent paid, and location."
25
+ },
26
+ {
27
+ "prompt": "How to save tax under Section 80D?",
28
+ "response": "Section 80D allows deductions on health insurance premiums up to ₹25,000 (₹50,000 for senior citizens)."
29
+ },
30
+ {
31
+ "prompt": "What is the capital gains tax on the sale of property?",
32
+ "response": "Long-term capital gains (LTCG) on property are taxed at 20% with indexation, while short-term gains are taxed as per the income slab."
33
+ },
34
+ {
35
+ "prompt": "Can I claim deductions on home loan interest?",
36
+ "response": "Yes, under Section 24(b), you can claim up to ₹2,00,000 per year on home loan interest."
37
+ },
38
+ {
39
+ "prompt": "What is the GST rate on restaurant bills?",
40
+ "response": "The GST rate on restaurant bills is 5% for non-AC restaurants and 18% for AC restaurants."
41
+ },
42
+ {
43
+ "prompt": "What is TDS and when is it deducted?",
44
+ "response": "Tax Deducted at Source (TDS) is deducted by the payer on salaries, rent, and interest payments as per prescribed rates."
45
+ },
46
+ {
47
+ "prompt": "How can NRIs save tax in India?",
48
+ "response": "NRIs can save tax through DTAA benefits, NRE accounts, and exemptions on certain investments."
49
+ },
50
+ {
51
+ "prompt": "What is the corporate tax rate in India?",
52
+ "response": "The corporate tax rate is 22% for domestic companies under the new regime and 30% under the old regime."
53
+ },
54
+ {
55
+ "prompt": "Are agricultural incomes taxable?",
56
+ "response": "No, agricultural income is exempt from tax under Section 10(1)."
57
+ },
58
+ {
59
+ "prompt": "What are the penalties for late ITR filing?",
60
+ "response": "A late fee of ₹5,000 applies if filed after the due date, and ₹10,000 for income above ₹5 lakh."
61
+ },
62
+ {
63
+ "prompt": "Explain Section 80G of the Income Tax Act.",
64
+ "response": "Section 80G allows deductions on donations made to eligible charities, ranging from 50% to 100% of the donation."
65
+ },
66
+ {
67
+ "prompt": "What is Advance Tax, and who needs to pay it?",
68
+ "response": "Advance Tax is payable if total tax liability exceeds ₹10,000 in a financial year and is paid in installments."
69
+ },
70
+ {
71
+ "prompt": "What is the basic exemption limit for senior citizens?",
72
+ "response": "The exemption limit for senior citizens (60-80 years) is ₹3,00,000 and ₹5,00,000 for super senior citizens."
73
+ },
74
+ {
75
+ "prompt": "How does tax loss harvesting work?",
76
+ "response": "Tax loss harvesting helps offset capital gains by selling loss-making stocks to reduce taxable income."
77
+ },
78
+ {
79
+ "prompt": "What is the standard deduction for salaried employees?",
80
+ "response": "A standard deduction of ₹50,000 is available for salaried and pensioned individuals."
81
+ }
82
+ ]
taxagent.py CHANGED
@@ -1,91 +1,82 @@
1
  import streamlit as st
2
  import fitz # PyMuPDF for PDF extraction
3
- from langchain_community.llms import Ollama
4
- from langchain.chains import LLMChain
5
- from langchain.prompts import PromptTemplate
6
- from langchain.memory import ConversationBufferMemory
7
  from langchain.text_splitter import CharacterTextSplitter
8
  from langchain.vectorstores import FAISS
9
  from langchain.embeddings import OllamaEmbeddings
10
- import hashlib
11
- import numpy as np
 
 
 
 
 
 
 
 
 
 
 
12
 
13
  # ========================== SESSION STATE INITIALIZATION ========================== #
14
 
15
- if "memory" not in st.session_state:
16
- st.session_state.memory = ConversationBufferMemory()
17
- if "chat_history" not in st.session_state:
18
- st.session_state.chat_history = []
19
  if "legal_knowledge_base" not in st.session_state:
20
  st.session_state.legal_knowledge_base = ""
21
- if "user_query" not in st.session_state:
22
- st.session_state.user_query = ""
23
- if "answer" not in st.session_state:
24
- st.session_state.answer = ""
25
  if "vector_db" not in st.session_state:
26
  st.session_state.vector_db = None
27
  if "summary" not in st.session_state:
28
  st.session_state.summary = ""
29
- if "doc_hash" not in st.session_state:
30
- st.session_state.doc_hash = ""
31
 
32
  # ========================== HELPER FUNCTIONS ========================== #
33
 
34
  def compute_file_hash(file):
35
- """Computes SHA-256 hash of the uploaded file to check for changes."""
36
  hasher = hashlib.sha256()
37
  hasher.update(file.read())
38
- file.seek(0) # Reset file pointer after reading
39
  return hasher.hexdigest()
40
 
41
  def extract_text_from_pdf(pdf_file):
42
- """Extracts text from a PDF file using PyMuPDF (fitz)."""
43
- try:
44
- doc = fitz.open(stream=pdf_file.read(), filetype="pdf")
45
- pdf_file.seek(0) # Reset file pointer
46
- text = "\n".join([page.get_text("text") for page in doc])
47
- return text.strip() if text.strip() else "No extractable text found in PDF."
48
- except Exception as e:
49
- return f"Error reading PDF: {e}"
50
 
51
  def summarize_text(text):
52
- """Summarizes the extracted legal document using AI."""
53
- llm = Ollama(model="llama3:8b")
54
- prompt = PromptTemplate(
55
- input_variables=["text"],
56
- template="Summarize this tax policy document concisely:\n{text}"
57
- )
58
- chain = LLMChain(llm=llm, prompt=prompt)
59
- summary = chain.run(text=text)
60
  return summary
61
 
62
  def create_vector_db():
63
- """Converts the extracted legal document into searchable vector embeddings."""
64
  text = st.session_state.legal_knowledge_base
65
  if not text:
66
  return None
67
-
68
  text_splitter = CharacterTextSplitter(separator="\n", chunk_size=1000, chunk_overlap=150)
69
  texts = text_splitter.split_text(text)
70
- embeddings = OllamaEmbeddings(model="llama3")
71
  return FAISS.from_texts(texts, embeddings)
72
 
73
  def retrieve_relevant_text(query, vector_db):
74
- """Fetches relevant sections from the document based on the user's query."""
75
  if not vector_db:
76
- return "No legal document uploaded."
77
 
78
  docs = vector_db.similarity_search(query, k=5)
79
  retrieved_text = "\n".join([doc.page_content for doc in docs])
80
  return retrieved_text
81
 
82
- # ========================== AI TAX COMPUTATION & REASONING ========================== #
83
-
84
  def compute_tax_details(query):
85
- """Processes user queries related to tax calculations."""
86
  import re
87
 
88
- # Extract income & tax rate from query
89
  income_match = re.search(r"₹?(\d[\d,]*)", query.replace(",", ""))
90
  tax_rate_match = re.search(r"(\d+)%", query)
91
 
@@ -94,77 +85,70 @@ def compute_tax_details(query):
94
  tax_rate = float(tax_rate_match.group(1))
95
 
96
  computed_tax = round(income * (tax_rate / 100), 2)
97
- return f"Based on an income of ₹{income:,.2f} and a tax rate of {tax_rate}%, the calculated tax is **₹{computed_tax:,.2f}.**"
98
 
99
  return None
100
 
101
  def answer_user_query(query):
102
- """Answers user queries using retrieved legal text & tax calculations."""
103
  tax_computation_result = compute_tax_details(query)
104
 
105
  if tax_computation_result:
106
  st.session_state.answer = tax_computation_result
107
- st.session_state.chat_history.append({"query": query, "response": st.session_state.answer})
108
  return
109
 
110
  if not st.session_state.vector_db:
111
  st.error("Please upload a document first.")
112
  return
113
-
114
- llm = Ollama(model="llama3:8b")
115
  retrieved_text = retrieve_relevant_text(query, st.session_state.vector_db)
116
- combined_context = f"Laws:\n{retrieved_text}\n\nUser Query:\n{query}"
 
117
 
118
- prompt_template = PromptTemplate(
119
- input_variables=["input_text"],
120
- template="""
121
- You are an AI legal expert specializing in tax and finance. Answer the user's query using legal context & real-world tax computation.
122
 
123
- Context:
124
- {input_text}
125
- """
126
- )
127
-
128
- chain = LLMChain(llm=llm, prompt=prompt_template, memory=st.session_state.memory)
129
- st.session_state.answer = chain.run(input_text=combined_context)
130
- st.session_state.chat_history.append({"query": query, "response": st.session_state.answer})
131
 
132
- # ========================== MAIN STREAMLIT APP ========================== #
133
 
134
  def main():
135
  st.title("📜 AI Legal Tax Assistant")
136
 
137
- uploaded_file = st.file_uploader("📄 Upload Policy PDF", type=["pdf"])
138
 
139
  if uploaded_file:
140
- file_hash = compute_file_hash(uploaded_file)
141
-
142
- if file_hash != st.session_state.doc_hash:
143
- st.session_state.doc_hash = file_hash
144
- with st.spinner("Extracting text..."):
145
- extracted_text = extract_text_from_pdf(uploaded_file)
146
- st.session_state.legal_knowledge_base = extracted_text
147
- st.success("Policy Document Uploaded & Stored!")
148
-
149
- with st.spinner("Generating summary..."):
150
- st.session_state.summary = summarize_text(extracted_text)
151
- st.subheader("📄 Document Summary:")
152
- st.text_area("", st.session_state.summary, height=250)
153
-
154
- with st.spinner("Indexing document for Q&A..."):
155
- st.session_state.vector_db = create_vector_db()
156
- st.success("Document indexed! Now you can ask questions.")
157
 
158
  st.subheader("💬 Ask Questions:")
159
- st.session_state.user_query = st.text_input("Enter your question:")
160
 
161
- if st.button("Ask") and st.session_state.user_query.strip():
162
- with st.spinner("Thinking..."):
163
- answer_user_query(st.session_state.user_query)
164
 
165
  if st.session_state.answer:
166
  st.markdown("### 🤖 AI Response:")
167
  st.success(st.session_state.answer)
168
 
169
  if __name__ == "__main__":
170
- main()
 
1
  import streamlit as st
2
  import fitz # PyMuPDF for PDF extraction
3
+ import torch
4
+ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
5
+ import hashlib
 
6
  from langchain.text_splitter import CharacterTextSplitter
7
  from langchain.vectorstores import FAISS
8
  from langchain.embeddings import OllamaEmbeddings
9
+
10
+ # ========================== LOAD FINE-TUNED MODEL ========================== #
11
+
12
+ MODEL_PATH = "./fine_tuned_tinyllama_tax" # Change to your actual model path
13
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
14
+
15
+ model = AutoModelForCausalLM.from_pretrained(
16
+ MODEL_PATH,
17
+ torch_dtype=torch.float16,
18
+ device_map="auto"
19
+ )
20
+
21
+ tax_llm = pipeline("text-generation", model=model, tokenizer=tokenizer)
22
 
23
  # ========================== SESSION STATE INITIALIZATION ========================== #
24
 
 
 
 
 
25
  if "legal_knowledge_base" not in st.session_state:
26
  st.session_state.legal_knowledge_base = ""
 
 
 
 
27
  if "vector_db" not in st.session_state:
28
  st.session_state.vector_db = None
29
  if "summary" not in st.session_state:
30
  st.session_state.summary = ""
31
+ if "answer" not in st.session_state:
32
+ st.session_state.answer = ""
33
 
34
  # ========================== HELPER FUNCTIONS ========================== #
35
 
36
  def compute_file_hash(file):
37
+ """Computes SHA-256 hash of the uploaded file to track changes."""
38
  hasher = hashlib.sha256()
39
  hasher.update(file.read())
40
+ file.seek(0) # Reset file pointer
41
  return hasher.hexdigest()
42
 
43
  def extract_text_from_pdf(pdf_file):
44
+ """Extracts text from a PDF using PyMuPDF (fitz)."""
45
+ doc = fitz.open(stream=pdf_file.read(), filetype="pdf")
46
+ pdf_file.seek(0) # Reset pointer
47
+ text = "\n".join([page.get_text("text") for page in doc])
48
+ return text.strip() if text.strip() else "No extractable text found in PDF."
 
 
 
49
 
50
  def summarize_text(text):
51
+ """Summarizes tax policy documents using fine-tuned AI."""
52
+ prompt = f"Summarize this tax policy document concisely:\n{text}"
53
+ summary = tax_llm(prompt, max_length=200, do_sample=True)[0]["generated_text"]
 
 
 
 
 
54
  return summary
55
 
56
  def create_vector_db():
57
+ """Creates a searchable vector database from extracted legal documents."""
58
  text = st.session_state.legal_knowledge_base
59
  if not text:
60
  return None
61
+
62
  text_splitter = CharacterTextSplitter(separator="\n", chunk_size=1000, chunk_overlap=150)
63
  texts = text_splitter.split_text(text)
64
+ embeddings = OllamaEmbeddings(model="llama3:8b")
65
  return FAISS.from_texts(texts, embeddings)
66
 
67
  def retrieve_relevant_text(query, vector_db):
68
+ """Fetches relevant legal sections from the document."""
69
  if not vector_db:
70
+ return "No document uploaded."
71
 
72
  docs = vector_db.similarity_search(query, k=5)
73
  retrieved_text = "\n".join([doc.page_content for doc in docs])
74
  return retrieved_text
75
 
 
 
76
  def compute_tax_details(query):
77
+ """Extracts income & tax rate and calculates tax."""
78
  import re
79
 
 
80
  income_match = re.search(r"₹?(\d[\d,]*)", query.replace(",", ""))
81
  tax_rate_match = re.search(r"(\d+)%", query)
82
 
 
85
  tax_rate = float(tax_rate_match.group(1))
86
 
87
  computed_tax = round(income * (tax_rate / 100), 2)
88
+ return f"Based on an income of ₹{income:,.2f} and a tax rate of {tax_rate}%, the tax is **₹{computed_tax:,.2f}.**"
89
 
90
  return None
91
 
92
  def answer_user_query(query):
93
+ """Answers tax-related queries using the fine-tuned model."""
94
  tax_computation_result = compute_tax_details(query)
95
 
96
  if tax_computation_result:
97
  st.session_state.answer = tax_computation_result
 
98
  return
99
 
100
  if not st.session_state.vector_db:
101
  st.error("Please upload a document first.")
102
  return
103
+
 
104
  retrieved_text = retrieve_relevant_text(query, st.session_state.vector_db)
105
+ prompt = f"""
106
+ You are an AI tax expert. Use legal knowledge and tax calculations to answer.
107
 
108
+ Context:
109
+ {retrieved_text}
 
 
110
 
111
+ User Query:
112
+ {query}
113
+
114
+ Response:
115
+ """
116
+
117
+ response = tax_llm(prompt, max_length=300, do_sample=True)[0]["generated_text"]
118
+ st.session_state.answer = response
119
 
120
+ # ========================== STREAMLIT UI ========================== #
121
 
122
  def main():
123
  st.title("📜 AI Legal Tax Assistant")
124
 
125
+ uploaded_file = st.file_uploader("📄 Upload Tax Policy PDF", type=["pdf"])
126
 
127
  if uploaded_file:
128
+ with st.spinner("Extracting text..."):
129
+ extracted_text = extract_text_from_pdf(uploaded_file)
130
+ st.session_state.legal_knowledge_base = extracted_text
131
+ st.success("Document Uploaded!")
132
+
133
+ with st.spinner("Generating summary..."):
134
+ st.session_state.summary = summarize_text(extracted_text)
135
+ st.subheader("📄 Document Summary:")
136
+ st.text_area("", st.session_state.summary, height=250)
137
+
138
+ with st.spinner("Indexing document..."):
139
+ st.session_state.vector_db = create_vector_db()
140
+ st.success("Document indexed! Ask questions now.")
 
 
 
 
141
 
142
  st.subheader("💬 Ask Questions:")
143
+ user_query = st.text_input("Enter your question:")
144
 
145
+ if st.button("Ask") and user_query.strip():
146
+ with st.spinner("Processing..."):
147
+ answer_user_query(user_query)
148
 
149
  if st.session_state.answer:
150
  st.markdown("### 🤖 AI Response:")
151
  st.success(st.session_state.answer)
152
 
153
  if __name__ == "__main__":
154
+ main()