"""Streamlit chat front-end for a PEFT fine-tuned Qwen2.5 model.

The model is fine-tuned to answer questions about Section 8.1 of the Income
Tax Assessment Act 1997 (General Deductions). Streamlit re-executes this whole
script on every user interaction; the model is loaded once and reused via
``st.cache_resource``.
"""

import streamlit as st
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

st.set_page_config(
    page_title="Section 8.1 Legal Assistant",
    page_icon="⚖️",
    layout="centered",
)

st.title("⚖️ Section 8.1 Legal Assistant")
st.markdown("**Reinforcement Fine-Tuned Model for ITAA 1997 - Section 8.1 (General Deductions)**")
st.markdown("---")

# System prompt prepended to every question.
# NOTE(review): whitespace here was reconstructed from a collapsed source;
# presumably it should match the prompt used during fine-tuning exactly -
# confirm against the training data.
SYSTEM_PROMPT = (
    "You ONLY answer questions about Section 8.1 of the Income Tax Assessment "
    "Act 1997 (General Deductions). If a question is about any other section, "
    "topic, or contains wrong details about Section 8.1, refuse or correct it. "
    "Never add information not in Section 8.1."
)

# Hugging Face repo holding the PEFT adapter; the tokenizer is loaded from it too.
MODEL_ID = "muhammadjasim12/rainforcejasim"
# Base checkpoint the adapter is attached to.
BASE_MODEL_ID = "Qwen/Qwen2.5-7B-Instruct"


@st.cache_resource
def load_model():
    """Load the base model, attach the PEFT adapter, and cache the result.

    ``st.cache_resource`` keeps a single shared instance across reruns, so the
    expensive download/load only happens the first time.

    Returns:
        A ``(model, tokenizer)`` tuple; the model is switched to eval mode.
    """
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
    base_model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL_ID,
        torch_dtype=torch.float16,
        device_map="auto",  # let accelerate place layers on the available devices
        trust_remote_code=True,
    )
    model = PeftModel.from_pretrained(base_model, MODEL_ID)
    model.eval()
    return model, tokenizer


def ask(question, model, tokenizer, max_new_tokens=300):
    """Generate a single-turn answer to *question*.

    Only the system prompt and the current question are sent to the model;
    earlier chat turns are not included.

    Args:
        question: The user's question text.
        model: Causal LM returned by ``load_model``.
        tokenizer: Matching tokenizer returned by ``load_model``.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The decoded answer with the prompt and special tokens removed.
    """
    # Hand-built ChatML prompt (Qwen2.5 chat format) ending with an open
    # assistant turn so the model continues as the assistant.
    prompt = (
        f"<|im_start|>system\n{SYSTEM_PROMPT}<|im_end|>\n"
        f"<|im_start|>user\n{question}<|im_end|>\n"
        f"<|im_start|>assistant\n"
    )
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    prompt_len = inputs["input_ids"].shape[1]

    with torch.inference_mode():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,  # greedy decoding, no sampling
            pad_token_id=tokenizer.eos_token_id,
        )

    # generate() returns prompt + continuation; decode only the new tokens.
    return tokenizer.decode(
        outputs[0][prompt_len:], skip_special_tokens=True
    ).strip()


# Load model
with st.spinner("Loading model... (first time takes 2-3 minutes)"):
    model, tokenizer = load_model()
st.success("Model loaded!")
st.markdown("---")

# Chat interface
if "messages" not in st.session_state:
    st.session_state.messages = []

# Replay history: every interaction reruns the script, so earlier turns are
# redrawn from session state each time.
for msg in st.session_state.messages:
    with st.chat_message(msg["role"]):
        st.markdown(msg["content"])

# User input
user_input = st.chat_input("Ask a question about Section 8.1...")

if user_input:
    # Show user message
    st.session_state.messages.append({"role": "user", "content": user_input})
    with st.chat_message("user"):
        st.markdown(user_input)

    # Get model answer
    with st.chat_message("assistant"):
        with st.spinner("Thinking..."):
            answer = ask(user_input, model, tokenizer)
        st.markdown(answer)
    st.session_state.messages.append({"role": "assistant", "content": answer})

ABOUT_MARKDOWN = """
This model is **reinforcement fine-tuned (DPO + SFT)** exclusively on
**Section 8.1** of the Income Tax Assessment Act 1997 (General Deductions).

**It will:**
- Answer questions about Section 8.1
- Refuse questions about other sections
- Correct wrong details in questions

**It will NOT:**
- Answer questions outside Section 8.1
- Add information not in the section
- Make up dollar amounts or rules
"""

EXAMPLE_QUESTIONS_MARKDOWN = """
- What is Section 8.1 about?
- What does Section 8.1(1)(a) say?
- Can I deduct a capital expense?
- What does Section 8.2 say?
- Does Section 8.1 have four subsections?
- What is Division 7A?
"""

# Sidebar
with st.sidebar:
    st.header("About")
    st.markdown(ABOUT_MARKDOWN)

    st.markdown("---")
    # Derived from the constants above so the sidebar cannot drift from the
    # IDs actually loaded.
    st.markdown(f"**Model:** `{MODEL_ID}`")
    st.markdown(f"**Base:** {BASE_MODEL_ID.rsplit('/', 1)[-1]}")
    st.markdown("**Training:** DPO + SFT")

    st.markdown("---")
    st.header("Example Questions")
    st.markdown(EXAMPLE_QUESTIONS_MARKDOWN)

    if st.button("Clear Chat"):
        st.session_state.messages = []
        st.rerun()