Spaces:
Build error
Build error
| import streamlit as st | |
| import torch | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| from peft import PeftModel | |
# Hugging Face repo id used when loading the tokenizer and the PEFT adapter.
MODEL_ID = "muhammadjasim12/rainforcejasim"

# System instruction placed in front of every user question. Written as
# adjacent string pieces purely for line length; the value is one line.
SYSTEM_PROMPT = (
    "You ONLY answer questions about Section 8.1 of the Income Tax "
    "Assessment Act 1997 (General Deductions). "
    "If a question is about any other section, topic, or contains wrong "
    "details about Section 8.1, refuse or correct it. "
    "Never add information not in Section 8.1."
)

# Page chrome. set_page_config stays the first Streamlit call in the script.
st.set_page_config(page_title="Section 8.1 Legal Assistant", page_icon="⚖️", layout="centered")

st.title("⚖️ Section 8.1 Legal Assistant")
st.markdown("**Reinforcement Fine-Tuned Model for ITAA 1997 - Section 8.1 (General Deductions)**")
st.markdown("---")
@st.cache_resource(show_spinner=False)
def load_model():
    """Load the fine-tuned model and its tokenizer, cached across reruns.

    Streamlit re-executes the script on every user interaction. Without
    ``st.cache_resource`` each chat message would reload the full base
    model and adapter; with it, the objects returned by the first call are
    reused by later calls. ``show_spinner=False`` because the call site
    already wraps this call in its own ``st.spinner``.

    Returns:
        tuple: ``(model, tokenizer)`` where ``model`` is the
        ``Qwen/Qwen2.5-7B-Instruct`` base wrapped by the PEFT adapter from
        ``MODEL_ID`` and switched to eval mode, and ``tokenizer`` is loaded
        from ``MODEL_ID``.
    """
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)

    # Half-precision weights; device_map="auto" leaves device placement to
    # the library.
    base_model = AutoModelForCausalLM.from_pretrained(
        "Qwen/Qwen2.5-7B-Instruct",
        torch_dtype=torch.float16,
        device_map="auto",
        trust_remote_code=True,
    )

    # Attach the fine-tuned adapter weights on top of the base model.
    model = PeftModel.from_pretrained(base_model, MODEL_ID)
    model.eval()
    return model, tokenizer
def ask(question, model, tokenizer):
    """Return the model's reply to one question as stripped plain text.

    The prompt is assembled by hand from three turns delimited with
    ``<|im_start|>`` / ``<|im_end|>`` markers: the system instruction, the
    user's question, and an opened assistant turn for the model to fill.

    Args:
        question: The user's question text, inserted verbatim into the prompt.
        model: Object providing ``.device`` and ``.generate(...)``.
        tokenizer: Callable tokenizer providing ``.decode`` and
            ``.eos_token_id``.

    Returns:
        The decoded text of the newly generated tokens only (the prompt is
        sliced off), with special tokens skipped and surrounding whitespace
        removed.
    """
    turns = [
        f"<|im_start|>system\n{SYSTEM_PROMPT}<|im_end|>\n",
        f"<|im_start|>user\n{question}<|im_end|>\n",
        "<|im_start|>assistant\n",
    ]
    prompt = "".join(turns)

    encoded = tokenizer(prompt, return_tensors="pt").to(model.device)
    # Number of prompt tokens; used below to drop the echoed prompt.
    prompt_len = encoded["input_ids"].shape[1]

    with torch.no_grad():
        generated = model.generate(
            **encoded,
            max_new_tokens=300,
            do_sample=False,  # sampling disabled
            pad_token_id=tokenizer.eos_token_id,
        )

    new_tokens = generated[0][prompt_len:]
    reply = tokenizer.decode(new_tokens, skip_special_tokens=True)
    return reply.strip()
# ---- Model ----------------------------------------------------------------
with st.spinner("Loading model... (first time takes 2-3 minutes)"):
    model, tokenizer = load_model()
st.success("Model loaded!")
st.markdown("---")

# ---- Conversation ---------------------------------------------------------
# The transcript is kept in session state as a list of
# {"role": ..., "content": ...} dicts.
if "messages" not in st.session_state:
    st.session_state["messages"] = []

# Replay the transcript so far.
for entry in st.session_state["messages"]:
    st.chat_message(entry["role"]).markdown(entry["content"])

# Handle a newly submitted question, if any.
user_input = st.chat_input("Ask a question about Section 8.1...")
if user_input:
    # Record the question first, then echo it.
    st.session_state["messages"].append({"role": "user", "content": user_input})
    st.chat_message("user").markdown(user_input)

    # Generate the reply under a spinner and render it in the assistant bubble.
    with st.chat_message("assistant"):
        with st.spinner("Thinking..."):
            answer = ask(user_input, model, tokenizer)
        st.markdown(answer)
    st.session_state["messages"].append({"role": "assistant", "content": answer})

# ---- Sidebar --------------------------------------------------------------
# Static sidebar copy, hoisted out of the layout code below.
ABOUT_MD = """
This model is **reinforcement fine-tuned (DPO + SFT)**
exclusively on **Section 8.1** of the Income Tax
Assessment Act 1997 (General Deductions).
**It will:**
- Answer questions about Section 8.1
- Refuse questions about other sections
- Correct wrong details in questions
**It will NOT:**
- Answer questions outside Section 8.1
- Add information not in the section
- Make up dollar amounts or rules
"""

EXAMPLES_MD = """
- What is Section 8.1 about?
- What does Section 8.1(1)(a) say?
- Can I deduct a capital expense?
- What does Section 8.2 say?
- Does Section 8.1 have four subsections?
- What is Division 7A?
"""

MODEL_FACTS = (
    "**Model:** `muhammadjasim12/rainforcejasim`",
    "**Base:** Qwen2.5-7B-Instruct",
    "**Training:** DPO + SFT",
)

with st.sidebar:
    st.header("About")
    st.markdown(ABOUT_MD)
    st.markdown("---")
    for fact in MODEL_FACTS:
        st.markdown(fact)
    st.markdown("---")
    st.header("Example Questions")
    st.markdown(EXAMPLES_MD)

    # Empty the transcript and rerun immediately so the cleared chat shows.
    if st.button("Clear Chat"):
        st.session_state["messages"] = []
        st.rerun()