pradeep4321 commited on
Commit
6d9b383
·
verified ·
1 Parent(s): f3c6eb8

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +22 -7
src/streamlit_app.py CHANGED
@@ -14,7 +14,7 @@ st.title("🤖 Simple AI Assistant")
14
  # ==============================
15
  @st.cache_resource
16
  def load_model():
17
- model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0" # 🔥 BEST FOR HF FREE
18
 
19
  tokenizer = AutoTokenizer.from_pretrained(model_name)
20
 
@@ -41,32 +41,44 @@ if "messages" not in st.session_state:
41
  # CLEAN TEXT
42
  # ==============================
43
  def clean_text(text):
44
- text = re.sub(r"[^\x00-\x7F]+", "", text)
45
- return text.strip()
 
 
 
 
 
46
 
47
  # ==============================
48
  # GENERATE RESPONSE
49
  # ==============================
50
  def generate_response(user_input):
51
 
52
- prompt = f"<|user|>\n{user_input}\n<|assistant|>\n"
 
 
 
 
 
 
53
 
54
  inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
55
 
56
  with torch.no_grad():
57
  outputs = model.generate(
58
  **inputs,
59
- max_new_tokens=150,
60
  do_sample=True,
61
  temperature=0.7,
62
  top_p=0.9,
63
  repetition_penalty=1.1,
 
64
  pad_token_id=tokenizer.eos_token_id
65
  )
66
 
67
  result = tokenizer.decode(outputs[0], skip_special_tokens=True)
68
 
69
- # Extract assistant part
70
  if "<|assistant|>" in result:
71
  result = result.split("<|assistant|>")[-1]
72
 
@@ -80,19 +92,22 @@ for msg in st.session_state.messages:
80
  st.markdown(msg["content"])
81
 
82
  # ==============================
83
- # INPUT
84
  # ==============================
85
  user_input = st.chat_input("Type your message...")
86
 
87
  if user_input:
 
88
  st.session_state.messages.append({"role": "user", "content": user_input})
89
 
90
  with st.chat_message("user"):
91
  st.markdown(user_input)
92
 
 
93
  with st.spinner("🤖 Thinking..."):
94
  response = generate_response(user_input)
95
 
 
96
  st.session_state.messages.append({"role": "assistant", "content": response})
97
 
98
  with st.chat_message("assistant"):
 
14
  # ==============================
15
  @st.cache_resource
16
  def load_model():
17
+ model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0" # Best for HF free
18
 
19
  tokenizer = AutoTokenizer.from_pretrained(model_name)
20
 
 
41
  # CLEAN TEXT
42
  # ==============================
43
  def clean_text(text):
44
+ text = re.sub(r"[^\x00-\x7F]+", "", text).strip()
45
+
46
+ # Ensure response completes nicely
47
+ if not text.endswith((".", "!", "?")):
48
+ text += "..."
49
+
50
+ return text
51
 
52
  # ==============================
53
  # GENERATE RESPONSE
54
  # ==============================
55
  def generate_response(user_input):
56
 
57
+ prompt = f"""
58
+ <|user|>
59
+ {user_input}
60
+
61
+ Give a clear and complete answer.
62
+ <|assistant|>
63
+ """
64
 
65
  inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
66
 
67
  with torch.no_grad():
68
  outputs = model.generate(
69
  **inputs,
70
+ max_new_tokens=300, # 🔥 prevents cut-off
71
  do_sample=True,
72
  temperature=0.7,
73
  top_p=0.9,
74
  repetition_penalty=1.1,
75
+ eos_token_id=tokenizer.eos_token_id,
76
  pad_token_id=tokenizer.eos_token_id
77
  )
78
 
79
  result = tokenizer.decode(outputs[0], skip_special_tokens=True)
80
 
81
+ # Extract assistant response
82
  if "<|assistant|>" in result:
83
  result = result.split("<|assistant|>")[-1]
84
 
 
92
  st.markdown(msg["content"])
93
 
94
  # ==============================
95
+ # INPUT BOX
96
  # ==============================
97
  user_input = st.chat_input("Type your message...")
98
 
99
  if user_input:
100
+ # Add user message
101
  st.session_state.messages.append({"role": "user", "content": user_input})
102
 
103
  with st.chat_message("user"):
104
  st.markdown(user_input)
105
 
106
+ # Generate response
107
  with st.spinner("🤖 Thinking..."):
108
  response = generate_response(user_input)
109
 
110
+ # Add assistant response
111
  st.session_state.messages.append({"role": "assistant", "content": response})
112
 
113
  with st.chat_message("assistant"):