File size: 5,532 Bytes
4a73579
 
2cb9c2c
 
 
4a73579
 
 
2cb9c2c
 
 
 
 
 
 
 
 
 
 
 
 
4a73579
 
 
 
 
 
 
 
 
2cb9c2c
 
4a73579
 
 
 
2cb9c2c
4a73579
 
2cb9c2c
4a73579
 
 
2cb9c2c
 
 
 
 
4a73579
 
2cb9c2c
 
 
4a73579
 
 
2cb9c2c
4a73579
 
 
2cb9c2c
4a73579
 
2cb9c2c
4a73579
 
 
2cb9c2c
4a73579
2cb9c2c
4a73579
 
 
 
 
 
2cb9c2c
4a73579
 
 
 
 
 
 
 
 
 
2cb9c2c
4a73579
 
 
 
2cb9c2c
4a73579
 
 
 
 
 
 
 
 
2cb9c2c
4a73579
2cb9c2c
 
4a73579
2cb9c2c
 
4a73579
2cb9c2c
 
 
 
 
 
 
 
4a73579
2cb9c2c
4a73579
 
 
 
2cb9c2c
4a73579
 
2cb9c2c
 
 
 
 
 
 
 
 
 
 
 
 
4a73579
 
2cb9c2c
4a73579
2cb9c2c
 
 
4a73579
 
 
 
 
 
2cb9c2c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
import streamlit as st
import fitz  # PyMuPDF for PDF extraction
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
import hashlib
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import FAISS
from langchain.embeddings import OllamaEmbeddings  

# ========================== LOAD FINE-TUNED MODEL ========================== #

MODEL_PATH = "./fine_tuned_tinyllama_tax"  # Change to your actual model path


@st.cache_resource
def _load_tax_model(model_path):
    """Load the tokenizer, causal-LM and text-generation pipeline for *model_path*.

    Streamlit re-executes this script from the top on every widget
    interaction; caching the result keeps the (large) model from being
    reloaded from disk on each rerun.

    Returns:
        A ``(tokenizer, model, pipeline)`` tuple.
    """
    tok = AutoTokenizer.from_pretrained(model_path)
    mdl = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.float16,
        device_map="auto",
    )
    return tok, mdl, pipeline("text-generation", model=mdl, tokenizer=tok)


# Module-level names kept for the rest of the file.
tokenizer, model, tax_llm = _load_tax_model(MODEL_PATH)

# ========================== SESSION STATE INITIALIZATION ========================== #

def _init_session_state():
    """Seed Streamlit session state with empty defaults, keeping existing values."""
    defaults = {
        "legal_knowledge_base": "",
        "vector_db": None,
        "summary": "",
        "answer": "",
    }
    for name, initial in defaults.items():
        if name in st.session_state:
            continue
        st.session_state[name] = initial


_init_session_state()

# ========================== HELPER FUNCTIONS ========================== #

def compute_file_hash(file, chunk_size=1 << 16):
    """Return the SHA-256 hex digest of the entire contents of *file*.

    The stream is rewound before hashing, so the digest covers the whole
    file regardless of where the read position was. It is rewound again
    afterwards so callers can read the file from the start.

    Args:
        file: A seekable binary file-like object (e.g. a Streamlit upload).
        chunk_size: Bytes read per iteration; avoids one huge in-memory read.

    Returns:
        The hexadecimal SHA-256 digest string.
    """
    hasher = hashlib.sha256()
    file.seek(0)
    for chunk in iter(lambda: file.read(chunk_size), b""):
        hasher.update(chunk)
    file.seek(0)  # Reset file pointer for subsequent readers
    return hasher.hexdigest()

def extract_text_from_pdf(pdf_file):
    """Extract the plain text of every page of a PDF using PyMuPDF (fitz).

    Args:
        pdf_file: A seekable binary file-like object holding PDF bytes. It is
            rewound before and after reading so other readers are unaffected.

    Returns:
        The page texts joined by newlines and stripped, or a placeholder
        message when the PDF contains no extractable text.
    """
    pdf_file.seek(0)  # The stream may already have been read (e.g. hashed)
    data = pdf_file.read()
    pdf_file.seek(0)  # Reset pointer

    # Close the document deterministically instead of leaking it.
    with fitz.open(stream=data, filetype="pdf") as doc:
        text = "\n".join(page.get_text("text") for page in doc)

    text = text.strip()
    return text if text else "No extractable text found in PDF."

def summarize_text(text, max_input_chars=4000, max_new_tokens=200):
    """Summarize a tax policy document with the fine-tuned model.

    Args:
        text: Full document text.
        max_input_chars: Only this many leading characters are put in the
            prompt, as a rough guard so long PDFs do not overflow the model's
            context window (TODO: tune for the actual model's limit).
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The generated summary only, without the echoed prompt.
    """
    excerpt = text[:max_input_chars]
    prompt = f"Summarize this tax policy document concisely:\n{excerpt}"
    # ``max_length`` would include the prompt tokens, leaving no budget for a
    # summary of any non-trivial document; ``max_new_tokens`` bounds only the
    # output. ``return_full_text=False`` stops the prompt being echoed back.
    outputs = tax_llm(
        prompt,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        return_full_text=False,
    )
    return outputs[0]["generated_text"].strip()

def create_vector_db():
    """Build a FAISS similarity index over the document text stored in session state.

    The text is split on newlines into ~1000-character chunks (150 overlap)
    and embedded with the Ollama ``llama3:8b`` model.

    Returns:
        The FAISS store, or None when no document text has been stored yet.
    """
    knowledge = st.session_state.legal_knowledge_base
    if knowledge:
        splitter = CharacterTextSplitter(separator="\n", chunk_size=1000, chunk_overlap=150)
        chunks = splitter.split_text(knowledge)
        return FAISS.from_texts(chunks, OllamaEmbeddings(model="llama3:8b"))
    return None

def retrieve_relevant_text(query, vector_db):
    """Return the contents of the 5 stored chunks most similar to *query*.

    Args:
        query: The user's question.
        vector_db: A store exposing ``similarity_search(query, k=...)``, or a
            falsy value when nothing has been indexed.

    Returns:
        The matching chunks' ``page_content`` joined by newlines, or a
        placeholder message when no store is available.
    """
    if vector_db:
        matches = vector_db.similarity_search(query, k=5)
        return "\n".join(match.page_content for match in matches)
    return "No document uploaded."

def compute_tax_details(query):
    """Compute a flat-rate tax when *query* states both an income and a percentage.

    Commas are stripped first, so groupings such as ``5,00,000`` parse as one
    number.

    - Rate: the first ``N%`` or ``N.N%`` figure.
    - Income: the first number tagged with ₹/Rs/INR. Failing that, the first
      standalone number other than the rate itself. Tokens such as ``80C``
      are not treated as numbers.

    Args:
        query: Free-text user question.

    Returns:
        A Markdown sentence with the computed tax. Returns None when the query
        lacks either a rate or a separate income figure, so the caller can
        fall back to document retrieval.
    """
    import re

    text = query.replace(",", "")

    rate_match = re.search(r"(\d+(?:\.\d+)?)\s*%", text)
    if not rate_match:
        return None

    # Blank out the rate so its digits are never mistaken for the income
    # (e.g. "10% tax on 500000" must use 500000, not 10).
    remainder = text[: rate_match.start()] + " " + text[rate_match.end():]

    income_match = re.search(
        r"(?:₹|\bRs\.?|\bINR)\s*(\d+(?:\.\d+)?)", remainder, re.IGNORECASE
    ) or re.search(r"(?<![\w.])(\d+(?:\.\d+)?)(?!\w)", remainder)
    if not income_match:
        return None

    income = float(income_match.group(1))
    tax_rate = float(rate_match.group(1))

    computed_tax = round(income * (tax_rate / 100), 2)
    return f"Based on an income of ₹{income:,.2f} and a tax rate of {tax_rate}%, the tax is **₹{computed_tax:,.2f}.**"

def answer_user_query(query, max_new_tokens=300):
    """Answer a tax question and store the result in ``st.session_state.answer``.

    Queries that contain an explicit income and rate are answered by
    ``compute_tax_details``. Otherwise the most relevant document chunks are
    retrieved and passed to the fine-tuned model as context.

    Args:
        query: The user's question.
        max_new_tokens: Upper bound on the number of generated tokens.
    """
    tax_computation_result = compute_tax_details(query)

    if tax_computation_result:
        st.session_state.answer = tax_computation_result
        return

    if not st.session_state.vector_db:
        st.session_state.answer = ""  # Don't keep showing a previous answer
        st.error("Please upload a document first.")
        return

    retrieved_text = retrieve_relevant_text(query, st.session_state.vector_db)
    # Built without source indentation so no stray whitespace reaches the model.
    prompt = (
        "You are an AI tax expert. Use legal knowledge and tax calculations to answer.\n\n"
        f"Context:\n{retrieved_text}\n\n"
        f"User Query:\n{query}\n\n"
        "Response:\n"
    )

    # ``max_length`` would count the (long) retrieved context, leaving no room
    # to generate; ``max_new_tokens`` bounds only the answer. Without
    # ``return_full_text=False`` the whole prompt would be shown to the user.
    outputs = tax_llm(
        prompt,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        return_full_text=False,
    )
    st.session_state.answer = outputs[0]["generated_text"].strip()

# ========================== STREAMLIT UI ========================== #

def _process_document(uploaded_file):
    """Extract, summarize and index *uploaded_file*, storing the results in session state."""
    with st.spinner("Extracting text..."):
        extracted_text = extract_text_from_pdf(uploaded_file)
        st.session_state.legal_knowledge_base = extracted_text

    with st.spinner("Generating summary..."):
        st.session_state.summary = summarize_text(extracted_text)

    with st.spinner("Indexing document..."):
        # create_vector_db reads legal_knowledge_base set above.
        st.session_state.vector_db = create_vector_db()


def _reset_document_state():
    """Forget everything derived from a previously uploaded document."""
    st.session_state.legal_knowledge_base = ""
    st.session_state.vector_db = None
    st.session_state.summary = ""
    st.session_state.answer = ""
    st.session_state.pop("file_hash", None)


def main():
    """Render the assistant UI: PDF upload, summary, and question answering.

    The uploaded PDF is processed only when its content hash changes. Without
    this, every rerun (e.g. each "Ask" click) would repeat extraction,
    summarization and indexing.
    """
    st.title("📜 AI Legal Tax Assistant")

    uploaded_file = st.file_uploader("📄 Upload Tax Policy PDF", type=["pdf"])

    if uploaded_file:
        file_hash = compute_file_hash(uploaded_file)
        if st.session_state.get("file_hash") != file_hash:
            _process_document(uploaded_file)
            # Record the hash only after processing succeeds, so a failure retries.
            st.session_state.file_hash = file_hash
            st.session_state.answer = ""  # Previous answer referred to another document

        st.success("Document Uploaded!")
        st.subheader("📄 Document Summary:")
        st.text_area(
            "Document summary",
            st.session_state.summary,
            height=250,
            label_visibility="collapsed",
        )
        st.success("Document indexed! Ask questions now.")
    elif "file_hash" in st.session_state:
        # The file was removed from the uploader: drop its derived state so
        # questions are no longer answered from it.
        _reset_document_state()

    st.subheader("💬 Ask Questions:")
    user_query = st.text_input("Enter your question:")

    if st.button("Ask") and user_query.strip():
        with st.spinner("Processing..."):
            answer_user_query(user_query)

    if st.session_state.answer:
        st.markdown("### 🤖 AI Response:")
        st.success(st.session_state.answer)


if __name__ == "__main__":
    main()