section8-1-assistant / randtext.py
muhammadjasim12's picture
Upload randtext.py
8f96aec verified
import streamlit as st
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
# Repository id handed to the ``from_pretrained`` loaders for the fine-tuned
# weights and tokenizer.
MODEL_ID = "muhammadjasim12/rainforcejasim"

# System turn placed at the start of every prompt; it confines the model to
# Section 8.1 and tells it to refuse or correct anything else.
SYSTEM_PROMPT = (
    "You ONLY answer questions about Section 8.1 of the Income Tax "
    "Assessment Act 1997 (General Deductions). If a question is about "
    "any other section, topic, or contains wrong details about "
    "Section 8.1, refuse or correct it. Never add information not in "
    "Section 8.1."
)

# Page chrome: browser tab settings first, then the visible heading.
st.set_page_config(page_title="Section 8.1 Legal Assistant", page_icon="⚖️", layout="centered")
st.title("⚖️ Section 8.1 Legal Assistant")
st.markdown(
    "**Reinforcement Fine-Tuned Model for ITAA 1997 - Section 8.1 (General Deductions)**"
)
st.markdown("---")
@st.cache_resource
def load_model():
    """Build the tokenizer and the adapter-wrapped model.

    The function is decorated with ``st.cache_resource``, so the loading work
    is meant to happen once and be reused afterwards.

    Returns:
        A ``(model, tokenizer)`` tuple. ``model`` is the
        ``Qwen/Qwen2.5-7B-Instruct`` base with the ``MODEL_ID`` PEFT adapter
        applied and switched to eval mode; ``tokenizer`` is loaded from
        ``MODEL_ID``.
    """
    tok = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)

    # Half-precision weights, with device placement left to the library.
    backbone_options = {
        "torch_dtype": torch.float16,
        "device_map": "auto",
        "trust_remote_code": True,
    }
    backbone = AutoModelForCausalLM.from_pretrained(
        "Qwen/Qwen2.5-7B-Instruct", **backbone_options
    )

    # Layer the fine-tuned adapter on top of the base weights.
    adapted = PeftModel.from_pretrained(backbone, MODEL_ID)
    adapted.eval()
    return adapted, tok
# ChatML control markers that delimit turns in the prompt built by ``ask``.
_CHAT_CONTROL_MARKERS = ("<|im_start|>", "<|im_end|>")


def _strip_chat_markers(text):
    """Return *text* with every ChatML control marker removed."""
    previous = None
    # Repeat until stable so that deleting one marker cannot splice a new one
    # together from the characters on either side of it.
    while previous != text:
        previous = text
        for marker in _CHAT_CONTROL_MARKERS:
            text = text.replace(marker, "")
    return text


def ask(question, model, tokenizer):
    """Generate the assistant's reply to a single user question.

    The prompt is a ChatML conversation made of the fixed ``SYSTEM_PROMPT``,
    one user turn holding ``question``, and an opened assistant turn.
    Decoding is greedy (``do_sample=False``) and capped at 300 new tokens.

    Args:
        question: Raw text typed by the user. It is untrusted: any ChatML
            control markers in it are removed before it is placed in the
            prompt, so the user cannot end their own turn and start a
            forged system or assistant turn.
        model: Causal language model exposing ``generate`` and ``device``.
        tokenizer: Tokenizer matching ``model``.

    Returns:
        The newly generated text only (prompt tokens excluded), with special
        tokens dropped and surrounding whitespace stripped.
    """
    safe_question = _strip_chat_markers(question)
    prompt = (
        f"<|im_start|>system\n{SYSTEM_PROMPT}<|im_end|>\n"
        f"<|im_start|>user\n{safe_question}<|im_end|>\n"
        f"<|im_start|>assistant\n"
    )
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=300,
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id,
        )
    # generate() returns prompt + continuation; keep only the continuation.
    prompt_length = inputs["input_ids"].shape[1]
    return tokenizer.decode(
        outputs[0][prompt_length:],
        skip_special_tokens=True,
    ).strip()
# --- Model start-up ---------------------------------------------------------
# Show a spinner while load_model() runs, then confirm and draw a divider.
loading_notice = st.spinner("Loading model... (first time takes 2-3 minutes)")
with loading_notice:
    model, tokenizer = load_model()
st.success("Model loaded!")
st.markdown("---")

# --- Chat interface ---------------------------------------------------------
# The conversation is kept in session state as a list of
# {"role": ..., "content": ...} dicts.
if "messages" not in st.session_state:
    st.session_state["messages"] = []

# Replay everything said so far.
for past in st.session_state["messages"]:
    st.chat_message(past["role"]).markdown(past["content"])

# Handle a newly submitted question, if there is one.
user_input = st.chat_input("Ask a question about Section 8.1...")
if user_input:
    history = st.session_state["messages"]

    # Record and echo the user's turn.
    history.append({"role": "user", "content": user_input})
    st.chat_message("user").markdown(user_input)

    # Generate, display and record the assistant's turn. Only the current
    # question is passed to ask(); earlier turns are not sent to the model.
    with st.chat_message("assistant"):
        with st.spinner("Thinking..."):
            answer = ask(user_input, model, tokenizer)
        st.markdown(answer)
    history.append({"role": "assistant", "content": answer})
# Sidebar: static "About" text, model details, example questions, and a
# control that resets the conversation.
# NOTE(review): this file's indentation was lost, so the nesting below cannot
# be confirmed from here — in particular whether the "Clear Chat" button
# belongs inside `with st.sidebar:` or on the main page. Verify before
# reformatting.
with st.sidebar:
st.header("About")
st.markdown("""
This model is **reinforcement fine-tuned (DPO + SFT)**
exclusively on **Section 8.1** of the Income Tax
Assessment Act 1997 (General Deductions).
**It will:**
- Answer questions about Section 8.1
- Refuse questions about other sections
- Correct wrong details in questions
**It will NOT:**
- Answer questions outside Section 8.1
- Add information not in the section
- Make up dollar amounts or rules
""")
st.markdown("---")
st.markdown("**Model:** `muhammadjasim12/rainforcejasim`")
st.markdown("**Base:** Qwen2.5-7B-Instruct")
st.markdown("**Training:** DPO + SFT")
st.markdown("---")
# The examples mix in-scope questions with ones about other provisions
# (Section 8.2, Division 7A) — presumably to demonstrate refusals; confirm.
st.header("Example Questions")
st.markdown("""
- What is Section 8.1 about?
- What does Section 8.1(1)(a) say?
- Can I deduct a capital expense?
- What does Section 8.2 say?
- Does Section 8.1 have four subsections?
- What is Division 7A?
""")
# Reset: replace the stored history with an empty list, then call st.rerun()
# so the page is drawn again without the old messages.
if st.button("Clear Chat"):
st.session_state.messages = []
st.rerun()