ranamhamoud commited on
Commit
1903e15
·
verified ·
1 Parent(s): 46868f3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -7
app.py CHANGED
@@ -94,13 +94,13 @@ def generate(
94
  conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
95
  conversation.append({"role": "user", "content": message})
96
 
97
- input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
98
  if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
99
  input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
100
  gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
101
  input_ids = input_ids.to(model.device)
102
-
103
- streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=False)
104
  generate_kwargs = dict(
105
  {"input_ids": input_ids},
106
  streamer=streamer,
@@ -117,10 +117,9 @@ def generate(
117
 
118
  outputs = []
119
  for text in streamer:
120
- processed_text = process_text(text)
121
- outputs.append(processed_text)
122
- output = "".join(outputs)
123
- yield output
124
 
125
  final_story = "".join(outputs)
126
  try:
 
94
  conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
95
  conversation.append({"role": "user", "content": message})
96
 
97
+ input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
98
  if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
99
  input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
100
  gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
101
  input_ids = input_ids.to(model.device)
102
+
103
+ streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
104
  generate_kwargs = dict(
105
  {"input_ids": input_ids},
106
  streamer=streamer,
 
117
 
118
  outputs = []
119
  for text in streamer:
120
+ outputs.append(text)
121
+ yield "".join(outputs)
122
+
 
123
 
124
  final_story = "".join(outputs)
125
  try: