Didrik Nathaniel LLoyd Aasland Skjelbred committed
Commit f2bb5c6 · 1 Parent(s): 45e12ad
Files changed (1)
  1. app.py +7 -5
app.py CHANGED
@@ -6,19 +6,22 @@ model_name = "tiiuae/falcon-rw-1b"
 tokenizer = AutoTokenizer.from_pretrained(model_name)
 model = AutoModelForCausalLM.from_pretrained(model_name)
 
-generator = pipeline("text-generation", model=model, tokenizer=tokenizer, device=-1)  # -1 = CPU
+generator = pipeline("text-generation", model=model, tokenizer=tokenizer, device=-1)
 
 chat_history = []
+MAX_HISTORY = 10  # Optional: to limit memory growth
 
 def generate_reply(message):
     global chat_history
     chat_history.append(f"User: {message}")
     prompt = "\n".join(chat_history) + "\nBot:"
-
-    result = generator(prompt, max_new_tokens=100, do_sample=True)
-    reply = result[0]["generated_text"].split("Bot:")[-1].strip()
+
+    result = generator(prompt, max_new_tokens=100, do_sample=True, pad_token_id=tokenizer.eos_token_id)
+    generated = result[0]["generated_text"]
+    reply = generated[len(prompt):].split("User:")[0].strip()
 
     chat_history.append(f"Bot: {reply}")
+    chat_history[:] = chat_history[-MAX_HISTORY:]  # Trim history
     return reply
 
 with gr.Blocks() as demo:
@@ -27,6 +30,5 @@ with gr.Blocks() as demo:
 
     txt.submit(generate_reply, inputs=txt, outputs=out).api_name = "generate_reply"
 
-
 demo.queue()
 demo.launch(share=True, show_api=True, show_error=True)
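
For readability, here is a sketch of the complete app.py as it stands after this commit. The import block and the txt/out component definitions sit outside the two hunks above, so those lines are assumptions (marked as such in the comments); everything else mirrors the new side of the diff.

# Sketch of app.py after commit f2bb5c6. The imports and the txt/out
# component definitions are not visible in the diff, so they are assumptions;
# the body of generate_reply() matches the new hunk.
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

model_name = "tiiuae/falcon-rw-1b"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# device=-1 keeps generation on the CPU.
generator = pipeline("text-generation", model=model, tokenizer=tokenizer, device=-1)

chat_history = []
MAX_HISTORY = 10  # Optional: to limit memory growth

def generate_reply(message):
    global chat_history
    chat_history.append(f"User: {message}")
    prompt = "\n".join(chat_history) + "\nBot:"

    # pad_token_id silences the "pad_token_id not set" warning for models
    # that define only an EOS token.
    result = generator(prompt, max_new_tokens=100, do_sample=True, pad_token_id=tokenizer.eos_token_id)
    generated = result[0]["generated_text"]
    # Keep only the text produced after the prompt, cutting off any extra
    # "User:" turn the model continues with on its own.
    reply = generated[len(prompt):].split("User:")[0].strip()

    chat_history.append(f"Bot: {reply}")
    chat_history[:] = chat_history[-MAX_HISTORY:]  # Trim history
    return reply

with gr.Blocks() as demo:
    # Hypothetical UI components; the real definitions are outside the diff hunks.
    txt = gr.Textbox(label="Message")
    out = gr.Textbox(label="Reply")

    txt.submit(generate_reply, inputs=txt, outputs=out).api_name = "generate_reply"

demo.queue()
demo.launch(share=True, show_api=True, show_error=True)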
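
Since the app names the submit event "generate_reply" and launches with show_api=True, the endpoint could be exercised from Python with gradio_client. This is a hedged sketch: the URL is a placeholder for the real share link, and it assumes the api_name assignment is actually exposed as an endpoint.

# Hypothetical client-side call; replace the placeholder URL with the
# share URL printed by demo.launch(share=True).
from gradio_client import Client

client = Client("https://your-share-url.gradio.live")  # placeholder URL
print(client.predict("Hello!", api_name="/generate_reply"))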