ammoncoder123 committed on
Commit
fb13fa4
·
verified ·
1 Parent(s): c0db93e

Update chatbot.py

Browse files
Files changed (1) hide show
  1. chatbot.py +69 -10
chatbot.py CHANGED
@@ -1,32 +1,91 @@
1
  import streamlit as st
2
- from transformers import pipeline
 
3
 
 
4
  @st.cache_resource
5
  def load_model():
6
- # Use a public 7B model (excellent quality)
7
- return pipeline("text-generation",model_id = "ammoncoder123/IPTchatbotModel1-1.7B")
8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  pipe = load_model()
10
 
11
- st.title("My 1.7B Smart Chatbot")
 
12
 
 
 
 
 
13
  if "messages" not in st.session_state:
14
  st.session_state.messages = []
15
 
16
- for msg in st.session_state.messages:
17
- with st.chat_message(msg["role"]):
18
- st.markdown(msg["content"])
 
19
 
20
- if prompt := st.chat_input("Ask about IPT/ICT..."):
 
 
21
  st.session_state.messages.append({"role": "user", "content": prompt})
22
  with st.chat_message("user"):
23
  st.markdown(prompt)
24
 
 
25
  with st.chat_message("assistant"):
26
  with st.spinner("Thinking..."):
27
- response = pipe(prompt, max_new_tokens=300)[0]['generated_text']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  st.markdown(response)
29
 
 
30
  st.session_state.messages.append({"role": "assistant", "content": response})
31
 
32
- st.info("This is a 1.7B model demo — answers are generally accurate but verify important facts.")
 
 
 
 
1
  import streamlit as st
2
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, BitsAndBytesConfig
3
+ import torch
4
 
5
# ================= CACHE THE MODEL =================
@st.cache_resource
def load_model():
    """Load the fine-tuned chat model once and cache it for the app's lifetime.

    Returns:
        A transformers text-generation pipeline wrapping the model and its
        tokenizer, with default sampling settings baked in.

    On a CUDA host the model is loaded with 4-bit (bitsandbytes)
    quantization so the 1.7B weights fit in a small GPU; on a CPU-only host
    it falls back to a plain full-precision load so the app still starts
    instead of crashing on the bitsandbytes/GPU requirement.
    """
    model_id = "ammoncoder123/IPTchatbotModel1-1.7B"  # fine-tuned model repo

    tokenizer = AutoTokenizer.from_pretrained(model_id)

    if torch.cuda.is_available():
        # 4-bit quantization for memory efficiency (needed for 1.7B on GPU).
        # NOTE: bnb_4bit_compute_dtype already fixes the compute dtype, so we
        # deliberately do NOT also pass torch_dtype — passing both makes
        # transformers warn and ignore one of them.
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
        )
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            quantization_config=quantization_config,
            device_map="auto",       # automatically places layers on the GPU(s)
            trust_remote_code=True,  # sometimes needed for custom models
        )
    else:
        # No CUDA: bitsandbytes 4-bit loading is unavailable, load on CPU.
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            trust_remote_code=True,
        )

    return pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=300,
        temperature=0.7,
        do_sample=True,
        top_p=0.9,
    )
34
+
35
# Instantiate the cached pipeline (first run downloads/loads the model).
pipe = load_model()

# ==================== CHAT INTERFACE ====================
st.title("IPT Chatbot (1.7B Fine-Tuned Model)")

# Up-front disclaimer about model quality.
st.info("⚠️ This is a small fine-tuned model (1.7B parameters). Answers may contain inaccuracies. Always verify important information.")

# Make sure a message list exists before anything reads it.
if "messages" not in st.session_state:
    st.session_state.messages = []

# Replay the conversation so far, one chat bubble per stored turn.
for turn in st.session_state.messages:
    with st.chat_message(turn["role"]):
        st.markdown(turn["content"])
52
 
53
# User input
if prompt := st.chat_input("Ask me about IPT, ICT, or anything else..."):
    # Record and echo the user's message.
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)

    # Generate response
    with st.chat_message("assistant"):
        with st.spinner("Thinking..."):
            # Pass the WHOLE history (chat format) so the Instruct model sees
            # prior turns — the old code sent only the latest prompt and
            # silently dropped the context it was already storing.
            chat_messages = st.session_state.messages

            outputs = pipe(
                chat_messages,
                max_new_tokens=300,
                temperature=0.7,
                do_sample=True,
                top_p=0.9,
            )

            # With chat-format (list-of-messages) input, the pipeline's
            # `generated_text` is the full message LIST (history plus the new
            # assistant turn), not a plain string — rendering/storing it raw
            # was the bug here. Extract the assistant's new turn instead.
            generated = outputs[0]["generated_text"]
            if isinstance(generated, list):
                response = generated[-1]["content"]
            else:
                # Plain-string output path: strip an echoed prompt if present.
                response = generated
                if response.startswith(prompt):
                    response = response[len(prompt):].strip()

            st.markdown(response)

    # Save assistant response (always a string now, so history renders cleanly).
    st.session_state.messages.append({"role": "assistant", "content": response})
87
 
88
# Optional control: wipe the stored history and restart the script run
# so the emptied conversation is rendered immediately.
if st.button("Clear Conversation"):
    st.session_state["messages"] = []
    st.rerun()