muhammadjasim12 commited on
Commit
8f96aec
·
verified ·
1 Parent(s): 493843a

Upload randtext.py

Browse files
Files changed (1) hide show
  1. randtext.py +126 -0
randtext.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ from transformers import AutoModelForCausalLM, AutoTokenizer
4
+ from peft import PeftModel
5
+
6
+ st.set_page_config(
7
+ page_title="Section 8.1 Legal Assistant",
8
+ page_icon="⚖️",
9
+ layout="centered",
10
+ )
11
+
12
+ st.title("⚖️ Section 8.1 Legal Assistant")
13
+ st.markdown("**Reinforcement Fine-Tuned Model for ITAA 1997 - Section 8.1 (General Deductions)**")
14
+ st.markdown("---")
15
+
16
+ SYSTEM_PROMPT = """You ONLY answer questions about Section 8.1 of the Income Tax Assessment Act 1997 (General Deductions). If a question is about any other section, topic, or contains wrong details about Section 8.1, refuse or correct it. Never add information not in Section 8.1."""
17
+
18
+ MODEL_ID = "muhammadjasim12/rainforcejasim"
19
+
20
+
21
+ @st.cache_resource
22
+ def load_model():
23
+ """Load model once and cache it."""
24
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
25
+ base_model = AutoModelForCausalLM.from_pretrained(
26
+ "Qwen/Qwen2.5-7B-Instruct",
27
+ torch_dtype=torch.float16,
28
+ device_map="auto",
29
+ trust_remote_code=True,
30
+ )
31
+ model = PeftModel.from_pretrained(base_model, MODEL_ID)
32
+ model.eval()
33
+ return model, tokenizer
34
+
35
+
36
+ def ask(question, model, tokenizer):
37
+ prompt = (
38
+ f"<|im_start|>system\n{SYSTEM_PROMPT}<|im_end|>\n"
39
+ f"<|im_start|>user\n{question}<|im_end|>\n"
40
+ f"<|im_start|>assistant\n"
41
+ )
42
+ inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
43
+ with torch.no_grad():
44
+ outputs = model.generate(
45
+ **inputs,
46
+ max_new_tokens=300,
47
+ do_sample=False,
48
+ pad_token_id=tokenizer.eos_token_id,
49
+ )
50
+ return tokenizer.decode(
51
+ outputs[0][inputs["input_ids"].shape[1]:],
52
+ skip_special_tokens=True
53
+ ).strip()
54
+
55
+
56
+ # Load model
57
+ with st.spinner("Loading model... (first time takes 2-3 minutes)"):
58
+ model, tokenizer = load_model()
59
+
60
+ st.success("Model loaded!")
61
+ st.markdown("---")
62
+
63
+ # Chat interface
64
+ if "messages" not in st.session_state:
65
+ st.session_state.messages = []
66
+
67
+ # Display chat history
68
+ for msg in st.session_state.messages:
69
+ with st.chat_message(msg["role"]):
70
+ st.markdown(msg["content"])
71
+
72
+ # User input
73
+ user_input = st.chat_input("Ask a question about Section 8.1...")
74
+
75
+ if user_input:
76
+ # Show user message
77
+ st.session_state.messages.append({"role": "user", "content": user_input})
78
+ with st.chat_message("user"):
79
+ st.markdown(user_input)
80
+
81
+ # Get model answer
82
+ with st.chat_message("assistant"):
83
+ with st.spinner("Thinking..."):
84
+ answer = ask(user_input, model, tokenizer)
85
+ st.markdown(answer)
86
+
87
+ st.session_state.messages.append({"role": "assistant", "content": answer})
88
+
89
+ # Sidebar
90
+ with st.sidebar:
91
+ st.header("About")
92
+ st.markdown("""
93
+ This model is **reinforcement fine-tuned (DPO + SFT)**
94
+ exclusively on **Section 8.1** of the Income Tax
95
+ Assessment Act 1997 (General Deductions).
96
+
97
+ **It will:**
98
+ - Answer questions about Section 8.1
99
+ - Refuse questions about other sections
100
+ - Correct wrong details in questions
101
+
102
+ **It will NOT:**
103
+ - Answer questions outside Section 8.1
104
+ - Add information not in the section
105
+ - Make up dollar amounts or rules
106
+ """)
107
+
108
+ st.markdown("---")
109
+ st.markdown("**Model:** `muhammadjasim12/rainforcejasim`")
110
+ st.markdown("**Base:** Qwen2.5-7B-Instruct")
111
+ st.markdown("**Training:** DPO + SFT")
112
+
113
+ st.markdown("---")
114
+ st.header("Example Questions")
115
+ st.markdown("""
116
+ - What is Section 8.1 about?
117
+ - What does Section 8.1(1)(a) say?
118
+ - Can I deduct a capital expense?
119
+ - What does Section 8.2 say?
120
+ - Does Section 8.1 have four subsections?
121
+ - What is Division 7A?
122
+ """)
123
+
124
+ if st.button("Clear Chat"):
125
+ st.session_state.messages = []
126
+ st.rerun()