ammoncoder123 committed on
Commit
c12ee4d
·
verified ·
1 Parent(s): fb13fa4

Update chatbot.py

Browse files
Files changed (1) hide show
  1. chatbot.py +15 -36
chatbot.py CHANGED
@@ -5,23 +5,26 @@ import torch
5
  # ================= CACHE THE MODEL =================
6
  @st.cache_resource
7
  def load_model():
8
- model_id = "ammoncoder123/IPTchatbotModel1-1.7B" # ← Your correct model repo
9
 
10
- # 4-bit quantization for memory efficiency (required for 1.7B on GPU)
 
 
 
11
  quantization_config = BitsAndBytesConfig(
12
  load_in_4bit=True,
13
  bnb_4bit_compute_dtype=torch.float16
14
  )
15
 
16
- tokenizer = AutoTokenizer.from_pretrained(model_id)
17
  model = AutoModelForCausalLM.from_pretrained(
18
  model_id,
19
  quantization_config=quantization_config,
20
- device_map="auto", # Automatically uses GPU if available
21
  torch_dtype=torch.float16,
22
- trust_remote_code=True # Sometimes needed for custom models
23
  )
24
 
 
25
  return pipeline(
26
  "text-generation",
27
  model=model,
@@ -32,60 +35,36 @@ def load_model():
32
  top_p=0.9
33
  )
34
 
35
- # Load model once (this will run on first use)
36
  pipe = load_model()
37
 
38
  # ==================== CHAT INTERFACE ====================
39
- st.title("IPT Chatbot (1.7B Fine-Tuned Model)")
40
 
41
- # Show a disclaimer
42
- st.info("⚠️ This is a small fine-tuned model (1.7B parameters). Answers may contain inaccuracies. Always verify important information.")
43
 
44
- # Initialize chat history
45
  if "messages" not in st.session_state:
46
  st.session_state.messages = []
47
 
48
- # Display chat history
49
  for message in st.session_state.messages:
50
  with st.chat_message(message["role"]):
51
  st.markdown(message["content"])
52
 
53
- # User input
54
- if prompt := st.chat_input("Ask me about IPT, ICT, or anything else..."):
55
- # Add user message
56
  st.session_state.messages.append({"role": "user", "content": prompt})
57
  with st.chat_message("user"):
58
  st.markdown(prompt)
59
 
60
- # Generate response
61
  with st.chat_message("assistant"):
62
  with st.spinner("Thinking..."):
63
- # Use proper chat format for Instruct models
64
- chat_messages = [
65
- {"role": "user", "content": prompt}
66
- ]
67
-
68
- outputs = pipe(
69
- chat_messages,
70
- max_new_tokens=300,
71
- temperature=0.7,
72
- do_sample=True,
73
- top_p=0.9
74
- )
75
-
76
- # Extract generated text
77
  response = outputs[0]["generated_text"]
78
-
79
- # Clean up echoed prompt
80
- if isinstance(response, str) and response.startswith(prompt):
81
  response = response[len(prompt):].strip()
82
-
83
  st.markdown(response)
84
 
85
- # Save assistant response
86
  st.session_state.messages.append({"role": "assistant", "content": response})
87
 
88
- # Optional: Clear chat button
89
- if st.button("Clear Conversation"):
90
  st.session_state.messages = []
91
  st.rerun()
 
5
  # ================= CACHE THE MODEL =================
6
  @st.cache_resource
7
  def load_model():
8
+ model_id = "ammoncoder123/IPTchatbotModel1-1.7B" # ← YOUR REAL REPO (copy-paste this exactly!)
9
 
10
+ st.write("Loading tokenizer...")
11
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
12
+
13
+ st.write("Loading model (this may take a few minutes the first time)...")
14
  quantization_config = BitsAndBytesConfig(
15
  load_in_4bit=True,
16
  bnb_4bit_compute_dtype=torch.float16
17
  )
18
 
 
19
  model = AutoModelForCausalLM.from_pretrained(
20
  model_id,
21
  quantization_config=quantization_config,
22
+ device_map="auto", # GPU if available, else CPU
23
  torch_dtype=torch.float16,
24
+ trust_remote_code=True # Safe for most models
25
  )
26
 
27
+ st.write("Model loaded successfully!")
28
  return pipeline(
29
  "text-generation",
30
  model=model,
 
35
  top_p=0.9
36
  )
# ==================== CHAT INTERFACE ====================
pipe = load_model()  # cached via @st.cache_resource, so the model loads only once per session

st.title("My 1.7B Fine-Tuned IPT Chatbot")

st.info("⚠️ Small fine-tuned model (1.7B). Answers may vary — verify important info.")

# Persist the conversation across Streamlit reruns.
if "messages" not in st.session_state:
    st.session_state.messages = []

# Replay the stored history on every rerun so the chat stays visible.
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

if prompt := st.chat_input("Ask about IPT, ICT, or anything..."):
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)

    with st.chat_message("assistant"):
        with st.spinner("Thinking..."):
            # Chat-format input so the Instruct model's chat template is applied.
            chat_messages = [{"role": "user", "content": prompt}]
            outputs = pipe(
                chat_messages,
                max_new_tokens=300,
                temperature=0.7,
                do_sample=True,
                top_p=0.9,
            )
            response = outputs[0]["generated_text"]

            # BUG FIX: when the pipeline receives chat-format messages, the
            # "generated_text" field is the whole conversation as a list of
            # {"role": ..., "content": ...} dicts, NOT a plain string — so the
            # previous `response.startswith(prompt)` raised AttributeError.
            # Handle both shapes: pull the final assistant turn from a list,
            # or strip the echoed prompt from a raw string.
            if isinstance(response, list):
                response = response[-1].get("content", "") if response else ""
            elif isinstance(response, str) and response.startswith(prompt):
                response = response[len(prompt):].strip()

            st.markdown(response)

    st.session_state.messages.append({"role": "assistant", "content": response})

# Let the user reset the conversation; rerun refreshes the emptied history.
if st.button("Clear Chat"):
    st.session_state.messages = []
    st.rerun()