import streamlit as st
import fitz # PyMuPDF for PDF extraction
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
import hashlib
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import FAISS
from langchain.embeddings import OllamaEmbeddings
# ========================== LOAD FINE-TUNED MODEL ========================== #
MODEL_PATH = "./fine_tuned_tinyllama_tax"  # Change to your actual model path

@st.cache_resource(show_spinner=False)
def _load_tax_pipeline(model_path):
    """Load the fine-tuned tokenizer/model and build a text-generation pipeline.

    Streamlit re-executes this script on every user interaction; without
    ``st.cache_resource`` the multi-GB model would be reloaded from disk on
    each rerun.  Caching loads it once per server process.
    """
    tok = AutoTokenizer.from_pretrained(model_path)
    mdl = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.float16,  # halves memory; assumes a fp16-capable device — TODO confirm
        device_map="auto",
    )
    return tok, mdl, pipeline("text-generation", model=mdl, tokenizer=tok)

# Preserve the original module-level names used by the rest of the file.
tokenizer, model, tax_llm = _load_tax_pipeline(MODEL_PATH)
# ========================== SESSION STATE INITIALIZATION ========================== #
# Seed every session key with its default exactly once per browser session.
_SESSION_DEFAULTS = {
    "legal_knowledge_base": "",
    "vector_db": None,
    "summary": "",
    "answer": "",
}
for _key, _default in _SESSION_DEFAULTS.items():
    if _key not in st.session_state:
        st.session_state[_key] = _default
# ========================== HELPER FUNCTIONS ========================== #
def compute_file_hash(file):
    """Return the SHA-256 hex digest of *file*'s full contents.

    Used to detect whether an uploaded file changed between reruns.  The
    read pointer is rewound afterwards so later readers see the whole file.
    """
    digest = hashlib.sha256(file.read()).hexdigest()
    file.seek(0)  # rewind for subsequent consumers
    return digest
def extract_text_from_pdf(pdf_file):
    """Extract all page text from an uploaded PDF using PyMuPDF (fitz).

    Fix: the original never closed the fitz document, leaking the native
    handle on every upload; the ``with`` block guarantees closure.
    Returns a placeholder message when the PDF has no extractable text
    (e.g. scanned images).
    """
    with fitz.open(stream=pdf_file.read(), filetype="pdf") as doc:
        text = "\n".join(page.get_text("text") for page in doc)
    pdf_file.seek(0)  # rewind so the file can be read again (e.g. for hashing)
    stripped = text.strip()
    return stripped if stripped else "No extractable text found in PDF."
def summarize_text(text):
    """Summarize a tax policy document with the fine-tuned model.

    Fixes over the original:
    - ``max_new_tokens`` replaces ``max_length``: the latter counts prompt
      tokens too, so any long document left zero room for generated output.
    - ``return_full_text=False`` strips the echoed prompt — previously the
      "summary" began with the instruction plus the entire document.
    - The document is truncated before prompting to stay within the model's
      context window.  # assumes ~4 chars/token for a 2k-token window — TODO confirm
    """
    prompt = f"Summarize this tax policy document concisely:\n{text[:6000]}"
    result = tax_llm(prompt, max_new_tokens=200, do_sample=True, return_full_text=False)
    return result[0]["generated_text"].strip()
def create_vector_db():
    """Build a FAISS vector index over the extracted legal text.

    Reads ``st.session_state.legal_knowledge_base``; returns None when no
    document text is available yet.
    """
    raw_text = st.session_state.legal_knowledge_base
    if not raw_text:
        return None
    splitter = CharacterTextSplitter(separator="\n", chunk_size=1000, chunk_overlap=150)
    chunks = splitter.split_text(raw_text)
    embedder = OllamaEmbeddings(model="llama3:8b")
    return FAISS.from_texts(chunks, embedder)
def retrieve_relevant_text(query, vector_db):
    """Return the top-5 chunks most similar to *query*, newline-joined.

    Falls back to a fixed message when no vector index exists yet.
    """
    if not vector_db:
        return "No document uploaded."
    matches = vector_db.similarity_search(query, k=5)
    return "\n".join(match.page_content for match in matches)
def compute_tax_details(query):
    """Parse an income figure and a tax rate from *query* and compute the tax.

    Returns the formatted result string, or None when either number is
    missing so the caller can fall back to the LLM.

    Fix: the original income regex could capture the *rate*'s digits — e.g.
    in "Apply 10% tax on 500000" it matched 10 as the income.  The income
    pattern now rejects any number immediately followed by '%'.
    """
    import re
    normalized = query.replace(",", "")  # drop thousands separators up front
    # (?!\d|%): we are at the end of the number AND it is not a percentage.
    income_match = re.search(r"₹?\s*(\d+)(?!\d|%)", normalized)
    tax_rate_match = re.search(r"(\d+)%", normalized)
    if income_match and tax_rate_match:
        income = float(income_match.group(1))
        tax_rate = float(tax_rate_match.group(1))
        computed_tax = round(income * (tax_rate / 100), 2)
        return f"Based on an income of ₹{income:,.2f} and a tax rate of {tax_rate}%, the tax is **₹{computed_tax:,.2f}.**"
    return None
def answer_user_query(query):
    """Answer a tax question and store the result in ``st.session_state.answer``.

    Strategy: try a direct numeric tax computation first; otherwise retrieve
    relevant document chunks and ask the fine-tuned model.  Shows an error
    and returns early when no document has been indexed yet.
    """
    direct_result = compute_tax_details(query)
    if direct_result:
        st.session_state.answer = direct_result
        return
    if not st.session_state.vector_db:
        st.error("Please upload a document first.")
        return
    retrieved_text = retrieve_relevant_text(query, st.session_state.vector_db)
    prompt = f"""
You are an AI tax expert. Use legal knowledge and tax calculations to answer.
Context:
{retrieved_text}
User Query:
{query}
Response:
"""
    # Fixes: max_new_tokens budgets the *answer* length (max_length counted
    # the prompt too and a long context starved the reply), and
    # return_full_text=False strips the echoed prompt from the response.
    result = tax_llm(prompt, max_new_tokens=300, do_sample=True, return_full_text=False)
    st.session_state.answer = result[0]["generated_text"].strip()
# ========================== STREAMLIT UI ========================== #
def main():
    """Render the Streamlit UI: upload, summarize, index, then Q&A."""
    st.title("📜 AI Legal Tax Assistant")

    uploaded_file = st.file_uploader("📄 Upload Tax Policy PDF", type=["pdf"])
    if uploaded_file:
        # Extract the raw text and remember it for indexing.
        with st.spinner("Extracting text..."):
            extracted_text = extract_text_from_pdf(uploaded_file)
            st.session_state.legal_knowledge_base = extracted_text
        st.success("Document Uploaded!")

        # Produce and display a concise summary of the document.
        with st.spinner("Generating summary..."):
            st.session_state.summary = summarize_text(extracted_text)
        st.subheader("📄 Document Summary:")
        st.text_area("", st.session_state.summary, height=250)

        # Build the vector index used for retrieval-augmented answers.
        with st.spinner("Indexing document..."):
            st.session_state.vector_db = create_vector_db()
        st.success("Document indexed! Ask questions now.")

    st.subheader("💬 Ask Questions:")
    user_query = st.text_input("Enter your question:")
    if st.button("Ask") and user_query.strip():
        with st.spinner("Processing..."):
            answer_user_query(user_query)
    if st.session_state.answer:
        st.markdown("### 🤖 AI Response:")
        st.success(st.session_state.answer)

if __name__ == "__main__":
    main()