import hashlib

import fitz  # PyMuPDF for PDF extraction
import streamlit as st
import torch
from langchain.embeddings import OllamaEmbeddings
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import FAISS
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

# ========================== LOAD FINE-TUNED MODEL ========================== #
MODEL_PATH = "./fine_tuned_tinyllama_tax"  # Change to your actual model path


@st.cache_resource
def load_tax_llm(model_path=MODEL_PATH):
    """Load the fine-tuned tokenizer, model and text-generation pipeline.

    Streamlit re-executes this script on every widget interaction, so the
    load is wrapped in ``st.cache_resource`` to build the objects once and
    reuse them across reruns instead of reloading the weights each time.

    Args:
        model_path: Directory (or hub id) of the fine-tuned causal LM.

    Returns:
        A ``(tokenizer, model, pipeline)`` tuple.
    """
    loaded_tokenizer = AutoTokenizer.from_pretrained(model_path)
    # Half precision only when a GPU is present; fall back to float32 on CPU,
    # where float16 kernels may be missing or very slow.
    dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    loaded_model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=dtype,
        device_map="auto",
    )
    generator = pipeline(
        "text-generation",
        model=loaded_model,
        tokenizer=loaded_tokenizer,
    )
    return loaded_tokenizer, loaded_model, generator


tokenizer, model, tax_llm = load_tax_llm()

# ========================== SESSION STATE INITIALIZATION ========================== #
# Keys the rest of the app reads; each is created only if missing so values
# survive Streamlit reruns.
_SESSION_DEFAULTS = {
    "legal_knowledge_base": "",
    "vector_db": None,
    "summary": "",
    "answer": "",
}
for _key, _default in _SESSION_DEFAULTS.items():
    if _key not in st.session_state:
        st.session_state[_key] = _default


# ========================== HELPER FUNCTIONS ========================== #
def compute_file_hash(file):
    """Compute the SHA-256 hash of an uploaded file to track changes.

    Args:
        file: A seekable binary file-like object (e.g. a Streamlit upload).

    Returns:
        The hex digest of the file's full contents.
    """
    file.seek(0)  # Hash the whole file even if it was partially read before
    hasher = hashlib.sha256()
    hasher.update(file.read())
    file.seek(0)  # Reset file pointer for the next reader
    return hasher.hexdigest()


def extract_text_from_pdf(pdf_file):
    """Extract text from a PDF using PyMuPDF (fitz).

    Args:
        pdf_file: A seekable binary file-like object containing a PDF.

    Returns:
        The text of all pages joined by newlines, stripped of surrounding
        whitespace, or the message "No extractable text found in PDF." when
        the document yields no text.
    """
    pdf_file.seek(0)  # Read from the start regardless of earlier reads
    pdf_bytes = pdf_file.read()
    pdf_file.seek(0)  # Reset pointer
    # The context manager closes the document once the pages have been read.
    with fitz.open(stream=pdf_bytes, filetype="pdf") as doc:
        text = "\n".join(page.get_text("text") for page in doc)
    stripped = text.strip()
    return stripped if stripped else "No extractable text found in PDF."
# Rough cap on how much document text is sent to the summarizer so the prompt
# plus the generated tokens stay inside the model's context window.
# NOTE(review): 4000 characters is a conservative guess for a ~2k-token
# context (MODEL_PATH suggests TinyLlama) — confirm against the real model.
SUMMARY_INPUT_CHAR_LIMIT = 4000


def summarize_text(text):
    """Summarize a tax policy document with the fine-tuned model.

    Args:
        text: Full document text; only the first
            ``SUMMARY_INPUT_CHAR_LIMIT`` characters are sent to the model.

    Returns:
        The generated summary, without the prompt echoed back.
    """
    excerpt = text[:SUMMARY_INPUT_CHAR_LIMIT]
    prompt = f"Summarize this tax policy document concisely:\n{excerpt}"
    # max_new_tokens budgets the *generated* text only; max_length would also
    # count the prompt and leave no room to generate for long documents.
    outputs = tax_llm(
        prompt,
        max_new_tokens=200,
        do_sample=True,
        return_full_text=False,  # return only the continuation, not the prompt
    )
    return outputs[0]["generated_text"].strip()


def create_vector_db():
    """Create a searchable vector database from the extracted document text.

    Reads ``st.session_state.legal_knowledge_base``, splits it into
    overlapping chunks and embeds them into a FAISS index.

    Returns:
        A FAISS vector store, or ``None`` when there is no text to index.
    """
    text = st.session_state.legal_knowledge_base
    if not text:
        return None

    text_splitter = CharacterTextSplitter(
        separator="\n",
        chunk_size=1000,
        chunk_overlap=150,
    )
    texts = text_splitter.split_text(text)
    if not texts:
        # Whitespace-only input produces no chunks; nothing to index.
        return None

    embeddings = OllamaEmbeddings(model="llama3:8b")
    return FAISS.from_texts(texts, embeddings)


def retrieve_relevant_text(query, vector_db):
    """Fetch the document sections most relevant to *query*.

    Args:
        query: The user's question.
        vector_db: Vector store returned by ``create_vector_db`` (or None).

    Returns:
        The top five matching chunks joined by newlines, or
        "No document uploaded." when no vector store is available.
    """
    if vector_db is None:
        return "No document uploaded."

    docs = vector_db.similarity_search(query, k=5)
    return "\n".join(doc.page_content for doc in docs)


def compute_tax_details(query):
    """Extract an income and a tax rate from *query* and compute the tax.

    The rate is the first percentage in the query (decimals allowed). The
    income is the first ₹-prefixed amount, or failing that the first
    standalone number that is not itself a percentage.

    Args:
        query: Free-text user question.

    Returns:
        A formatted result sentence, or ``None`` when the query does not
        contain both an income amount and a percentage rate.
    """
    import re

    percent_pattern = r"(\d+(?:\.\d+)?)\s*%"
    amount_pattern = r"(\d[\d,]*(?:\.\d+)?)(?!\w)"

    rate_match = re.search(percent_pattern, query)
    if rate_match is None:
        return None

    # Drop every percentage first so a rate such as "10%" that appears before
    # the amount can never be mistaken for the income.
    remainder = re.sub(percent_pattern, " ", query)

    # Prefer an explicitly rupee-marked amount; otherwise take the first
    # standalone number (not glued to letters, e.g. the "80" in "80C").
    income_match = re.search(r"₹\s*" + amount_pattern, remainder) or re.search(
        r"(?<![\w.,])" + amount_pattern, remainder
    )
    if income_match is None:
        return None

    income = float(income_match.group(1).replace(",", ""))
    tax_rate = float(rate_match.group(1))
    computed_tax = round(income * (tax_rate / 100), 2)
    return (
        f"Based on an income of ₹{income:,.2f} and a tax rate of {tax_rate}%, "
        f"the tax is **₹{computed_tax:,.2f}.**"
    )


def answer_user_query(query):
    """Answer a tax-related query and store it in ``st.session_state.answer``.

    A query containing both an amount and a percentage is answered by direct
    calculation. Anything else is answered by the fine-tuned model using
    context retrieved from the indexed document; if no document has been
    indexed, an error is shown and the stored answer is cleared.

    Args:
        query: The user's question.
    """
    tax_computation_result = compute_tax_details(query)
    if tax_computation_result:
        st.session_state.answer = tax_computation_result
        return

    vector_db = st.session_state.vector_db
    if vector_db is None:
        # Clear any earlier answer so it is not shown under the error.
        st.session_state.answer = ""
        st.error("Please upload a document first.")
        return

    retrieved_text = retrieve_relevant_text(query, vector_db)
    prompt = (
        "You are an AI tax expert. "
        "Use legal knowledge and tax calculations to answer.\n\n"
        f"Context:\n{retrieved_text}\n\n"
        f"User Query: {query}\n\n"
        "Response:"
    )
    # The retrieved context alone can exceed 300 tokens, so the budget must
    # apply to new tokens only, and the prompt must not be echoed to the user.
    outputs = tax_llm(
        prompt,
        max_new_tokens=300,
        do_sample=True,
        return_full_text=False,
    )
    st.session_state.answer = outputs[0]["generated_text"].strip()


# ========================== STREAMLIT UI ========================== #
def main():
    """Render the Streamlit UI: PDF upload, summary, and question answering."""
    st.title("📜 AI Legal Tax Assistant")

    uploaded_file = st.file_uploader("📄 Upload Tax Policy PDF", type=["pdf"])

    if uploaded_file:
        # Streamlit reruns the script on every interaction; only redo the
        # expensive extraction, summary and indexing when the file changes.
        file_hash = compute_file_hash(uploaded_file)
        if st.session_state.get("processed_file_hash") != file_hash:
            with st.spinner("Extracting text..."):
                extracted_text = extract_text_from_pdf(uploaded_file)
                st.session_state.legal_knowledge_base = extracted_text

            with st.spinner("Generating summary..."):
                st.session_state.summary = summarize_text(extracted_text)

            with st.spinner("Indexing document..."):
                st.session_state.vector_db = create_vector_db()

            # Any stored answer referred to the previous document.
            st.session_state.answer = ""
            # Recorded last so a failure above is retried on the next rerun.
            st.session_state.processed_file_hash = file_hash

        st.success("Document Uploaded!")

        st.subheader("📄 Document Summary:")
        st.text_area(
            "Summary",
            st.session_state.summary,
            height=250,
            label_visibility="collapsed",
        )

        if st.session_state.vector_db is not None:
            st.success("Document indexed! Ask questions now.")
        else:
            st.warning("The document could not be indexed.")

    # Shown even without an upload: pure tax calculations need no document.
    st.subheader("💬 Ask Questions:")
    user_query = st.text_input("Enter your question:")

    if st.button("Ask") and user_query.strip():
        with st.spinner("Processing..."):
            answer_user_query(user_query)

    if st.session_state.answer:
        st.markdown("### 🤖 AI Response:")
        st.success(st.session_state.answer)


if __name__ == "__main__":
    main()