DharavathSri committed on
Commit
e917c6f
·
verified ·
1 Parent(s): 8fc2b9d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +114 -60
app.py CHANGED
import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# Custom CSS for styling
st.markdown("""
<style>
.main {
    background-color: #f5f5f5;
}
.stTextInput>div>div>input {
    background-color: #ffffff;
    color: #000000;
}
.stButton>button {
    background-color: #4CAF50;
    color: white;
    border-radius: 5px;
    border: none;
    padding: 10px 24px;
}
.stButton>button:hover {
    background-color: #45a049;
}
.title {
    font-size: 2.5em;
    color: #2c3e50;
    text-align: center;
    margin-bottom: 0.5em;
}
.sidebar .sidebar-content {
    background-color: #2c3e50;
    color: white;
}
</style>
""", unsafe_allow_html=True)

# App Title
st.markdown('<p class="title">💬 Fine-Tuned LLM Chat</p>', unsafe_allow_html=True)

# Sidebar for settings
with st.sidebar:
    st.header("⚙️ Settings")
    model_name = st.selectbox(
        "Select Model",
        ["mistralai/Mistral-7B-v0.1", "meta-llama/Llama-2-7b-chat-hf"],
        help="Choose a pre-trained model to fine-tune."
    )
    temperature = st.slider(
        "Temperature",
        min_value=0.1,
        max_value=1.0,
        value=0.7,
        help="Controls randomness (lower = more deterministic)."
    )
    max_length = st.slider(
        "Max Response Length",
        min_value=50,
        max_value=500,
        value=150,
        help="Maximum number of tokens in the response."
    )


# Load model (cached by Streamlit so reruns don't reload the weights).
@st.cache_resource
def load_model(model_name):
    """Load and cache the tokenizer + causal-LM for *model_name*.

    FIX: the original never moved the model off the CPU but later sent the
    inputs to "cuda", which raised a device-mismatch error (or crashed
    outright on CPU-only hosts). Pick the device once, here.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = AutoModelForCausalLM.from_pretrained(
        model_name, torch_dtype=torch.float16
    ).to(device)
    return tokenizer, model


tokenizer, model = load_model(model_name)

# Initialize chat history
if "messages" not in st.session_state:
    st.session_state.messages = []

# Display chat messages from previous reruns
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

# Chat input
if prompt := st.chat_input("Ask me anything..."):
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)

    # Generate response
    with st.chat_message("assistant"):
        with st.spinner("Thinking..."):
            # FIX: follow the model's device instead of hard-coding "cuda".
            inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
            outputs = model.generate(
                **inputs,
                # NOTE(review): these two kwargs were elided by the diff
                # context (hunk lines 94-95); they presumably wired up the
                # sidebar sliders — confirm against the full file history.
                max_length=max_length,
                temperature=temperature,
                do_sample=True
            )
            # FIX: generate() returns prompt + completion; decode only the
            # newly generated tokens so the prompt isn't echoed back.
            response = tokenizer.decode(
                outputs[0][inputs["input_ids"].shape[-1]:],
                skip_special_tokens=True
            )
        st.markdown(response)
        st.session_state.messages.append({"role": "assistant", "content": response})
import html

import streamlit as st
# NOTE(review): `pipeline` is imported but never used below — candidate for
# removal once confirmed nothing else depends on it.
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import torch

# ======================
# 🎨 STYLING & LAYOUT
# ======================
st.set_page_config(
    page_title="LLM Fine-Tuning Studio",
    page_icon="🧠",
    layout="wide"
)

# Custom CSS
st.markdown("""
<style>
/* Main container */
.main {
    background: linear-gradient(135deg, #f5f7fa 0%, #c3cfe2 100%);
}

/* Headers */
h1 {
    color: #2c3e50;
    text-align: center;
    font-family: 'Arial', sans-serif;
    border-bottom: 2px solid #4CAF50;
    padding-bottom: 10px;
}

/* Sidebar */
[data-testid="stSidebar"] {
    background: linear-gradient(195deg, #2c3e50 0%, #4CAF50 100%) !important;
    color: white;
}

/* Buttons */
.stButton>button {
    background: linear-gradient(to right, #4CAF50, #2E8B57);
    color: white;
    border: none;
    border-radius: 25px;
    padding: 10px 24px;
    font-weight: bold;
}

/* Chat bubbles */
.user-message {
    background: #e3f2fd;
    border-radius: 15px 15px 0 15px;
    padding: 12px;
    margin: 5px 0;
}

.bot-message {
    background: #4CAF50;
    color: white;
    border-radius: 15px 15px 15px 0;
    padding: 12px;
    margin: 5px 0;
}

/* Input box */
.stTextInput>div>div>input {
    border-radius: 20px !important;
    padding: 10px 15px !important;
}
</style>
""", unsafe_allow_html=True)


# ======================
# 🧠 MODEL LOADING
# ======================
@st.cache_resource
def load_model():
    """Load and cache the tokenizer + causal-LM (once per Streamlit process)."""
    model_name = "mistralai/Mistral-7B-v0.1"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16,
        device_map="auto"
    )
    return tokenizer, model


tokenizer, model = load_model()

# ======================
# 🎛️ SIDEBAR CONTROLS
# ======================
with st.sidebar:
    st.title("⚙️ Fine-Tuning Controls")

    st.subheader("Model Parameters")
    temperature = st.slider("Temperature", 0.1, 1.0, 0.7, 0.05)
    max_length = st.slider("Max Length", 50, 500, 150)

    st.subheader("Fine-Tuning Options")
    use_lora = st.checkbox("Use LoRA", True)
    quantize = st.selectbox("Quantization", ["None", "4-bit", "8-bit"])

    if st.button("🔄 Apply Changes"):
        st.toast("Settings updated!", icon="✅")

# ======================
# 💬 MAIN CHAT INTERFACE
# ======================
st.title("🧠 LLM Fine-Tuning Studio")
st.caption("Fine-tune and deploy state-of-the-art language models")

# Initialize chat history
if "messages" not in st.session_state:
    st.session_state.messages = [
        {"role": "assistant", "content": "Hello! I'm your fine-tuned AI assistant. How can I help you today?"}
    ]


def _render_bubble(role: str, content: str) -> None:
    """Render one chat message as a styled HTML bubble.

    SECURITY FIX: message text is interpolated into HTML rendered with
    unsafe_allow_html=True, so untrusted chat content must be escaped to
    prevent markup/script injection.
    """
    css_class = "bot-message" if role == "assistant" else "user-message"
    safe = html.escape(content)
    st.markdown(f'<div class="{css_class}">{safe}</div>', unsafe_allow_html=True)


# Display chat messages from previous reruns
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        _render_bubble(message["role"], message["content"])

# Chat input
if prompt := st.chat_input("Type your message..."):
    st.session_state.messages.append({"role": "user", "content": prompt})

    with st.chat_message("user"):
        _render_bubble("user", prompt)

    # Generate response
    with st.chat_message("assistant"):
        with st.spinner("🧠 Thinking..."):
            # FIX: the model is placed by device_map="auto", so follow
            # model.device instead of hard-coding "cuda" (which crashed on
            # CPU-only hosts).
            inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
            outputs = model.generate(
                **inputs,
                # NOTE(review): these two kwargs were elided by the diff
                # context (hunk lines 138-139); they presumably wired up the
                # sidebar sliders — confirm against the full file history.
                max_length=max_length,
                temperature=temperature,
                do_sample=True
            )
            # FIX: generate() returns prompt + completion; decode only the
            # newly generated tokens so the prompt isn't echoed back.
            response = tokenizer.decode(
                outputs[0][inputs["input_ids"].shape[-1]:],
                skip_special_tokens=True
            )

        _render_bubble("assistant", response)
        st.session_state.messages.append({"role": "assistant", "content": response})

# ======================
# 📊 FINE-TUNE STATUS
# ======================
# NOTE(review): these metrics are hard-coded placeholders, not live training
# stats — wire them to a real training loop or remove before release.
st.sidebar.markdown("---")
st.sidebar.subheader("Training Metrics")
st.sidebar.metric("Loss", "0.45", delta="-0.02")
st.sidebar.metric("Accuracy", "87%", delta="+2%")
st.sidebar.progress(75)