Ram7379 commited on
Commit
0c3f035
·
verified ·
1 Parent(s): 22e295b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -23
app.py CHANGED
@@ -15,37 +15,33 @@ chat_history_ids = None
15
 
16
  def chat(user_input, history):
17
  global chat_history_ids
18
-
19
- # Encode input
20
- # Add personality + context prompt
21
- prompt = "You are a helpful, friendly, and intelligent assistant. Give clear, meaningful, and human-like responses.\nUser: " + user_input + "\nBot:"
22
-
23
- new_input_ids = tokenizer.encode(prompt + tokenizer.eos_token, return_tensors='pt')
24
-
25
- # Append to chat history
26
  if chat_history_ids is not None:
27
  input_ids = torch.cat([chat_history_ids, new_input_ids], dim=-1)
28
  else:
29
  input_ids = new_input_ids
30
-
31
- # Generate response
32
- chat_history_ids = model.generate(
33
- input_ids,
34
- max_length=1000,
35
- pad_token_id=tokenizer.eos_token_id,
36
- do_sample=True,
37
- top_k=50,
38
- top_p=0.95,
39
- temperature=0.8,
40
- repetition_penalty=1.2
41
- )
42
-
43
- # Decode response
44
  response = tokenizer.decode(
45
  chat_history_ids[:, input_ids.shape[-1]:][0],
46
  skip_special_tokens=True
47
  )
48
-
49
  return response
50
 
51
  # Gradio UI
 
15
 
16
  def chat(user_input, history):
17
  global chat_history_ids
18
+
19
+ # Add personality prompt
20
+ prompt = "You are a helpful, friendly, and intelligent assistant. Give clear, meaningful, and human-like responses.\nUser: " + user_input + "\nBot:"
21
+
22
+ new_input_ids = tokenizer.encode(prompt + tokenizer.eos_token, return_tensors='pt').to(device)
23
+
 
 
24
  if chat_history_ids is not None:
25
  input_ids = torch.cat([chat_history_ids, new_input_ids], dim=-1)
26
  else:
27
  input_ids = new_input_ids
28
+
29
+ chat_history_ids = model.generate(
30
+ input_ids,
31
+ max_length=1000,
32
+ pad_token_id=tokenizer.eos_token_id,
33
+ do_sample=True,
34
+ top_k=50,
35
+ top_p=0.95,
36
+ temperature=0.8,
37
+ repetition_penalty=1.2
38
+ )
39
+
 
 
40
  response = tokenizer.decode(
41
  chat_history_ids[:, input_ids.shape[-1]:][0],
42
  skip_special_tokens=True
43
  )
44
+
45
  return response
46
 
47
  # Gradio UI