TalkUHulk committed on
Commit
13a6781
·
verified ·
1 Parent(s): 432d085

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -8
app.py CHANGED
@@ -1,6 +1,7 @@
1
  import gradio as gr
2
  from transformers import AutoProcessor, Idefics3ForConditionalGeneration, TextIteratorStreamer
3
  from threading import Thread
 
4
  import re
5
  import time
6
  from PIL import Image
@@ -24,13 +25,13 @@ def model_inference(
24
  input_dict, history, decoding_strategy, temperature, max_new_tokens,
25
  repetition_penalty, top_p
26
  ):
27
- \
28
  text = input_dict["text"]
29
 
30
  if len(input_dict["files"]) > 1:
31
- images = [Image.open(image).convert("RGB") for image in input_dict["files"]]
32
  elif len(input_dict["files"]) == 1:
33
- images = [Image.open(input_dict["files"][0]).convert("RGB")]
34
  else:
35
  images = []
36
 
@@ -78,8 +79,9 @@ def model_inference(
78
  # start token id = argmax last logit
79
  start_token_id = int(np.argmax(prefill_out["logits"][:, -1, :], axis=-1)[0])
80
 
81
-
82
- generated_text = ""
 
83
  generation_args = {
84
  "llm_session" : llm_session,
85
  "embed_tokens_session": embed_tokens_session,
@@ -89,14 +91,32 @@ def model_inference(
89
  "freqs_cos": freqs_cos,
90
  "freqs_sin": freqs_sin,
91
  "attention_mask": attention_mask.numpy(),
92
- "max_new_tokens": 128,
93
  "eos_token_id": 2,
94
- "start_pos": seqlen
 
95
  }
96
-
97
 
 
98
  thread = Thread(target=generate_autoregressive, kwargs=generation_args)
99
  thread.start()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
 
101
 
102
  examples = [
 
1
  import gradio as gr
2
  from transformers import AutoProcessor, Idefics3ForConditionalGeneration, TextIteratorStreamer
3
  from threading import Thread
4
+ from queue import Queue
5
  import re
6
  import time
7
  from PIL import Image
 
25
  input_dict, history, decoding_strategy, temperature, max_new_tokens,
26
  repetition_penalty, top_p
27
  ):
28
+ print(input_dict)
29
  text = input_dict["text"]
30
 
31
  if len(input_dict["files"]) > 1:
32
+ images = [Image.open(image).convert("RGB") for image in input_dict["files"]["path"]]
33
  elif len(input_dict["files"]) == 1:
34
+ images = [Image.open(input_dict["files"][0]["path"]).convert("RGB")]
35
  else:
36
  images = []
37
 
 
79
  # start token id = argmax last logit
80
  start_token_id = int(np.argmax(prefill_out["logits"][:, -1, :], axis=-1)[0])
81
 
82
+ # 创建输出队列用于线程间通信
83
+ output_queue = Queue()
84
+
85
  generation_args = {
86
  "llm_session" : llm_session,
87
  "embed_tokens_session": embed_tokens_session,
 
91
  "freqs_cos": freqs_cos,
92
  "freqs_sin": freqs_sin,
93
  "attention_mask": attention_mask.numpy(),
94
+ "max_new_tokens": max_new_tokens,
95
  "eos_token_id": 2,
96
+ "start_pos": seqlen,
97
+ "output_queue": output_queue
98
  }
 
99
 
100
+ # 在后台线程启动生成
101
  thread = Thread(target=generate_autoregressive, kwargs=generation_args)
102
  thread.start()
103
+
104
+ # 从队列中读取生成的文本并 yield
105
+ yield "..."
106
+ buffer = ""
107
+
108
+ while True:
109
+ text_chunk = output_queue.get() # 阻塞等待队列中的数据
110
+
111
+ if text_chunk is None: # 生成完成信号
112
+ break
113
+
114
+ buffer += text_chunk
115
+ time.sleep(0.01)
116
+ yield buffer
117
+
118
+ # 等待线程完成
119
+ thread.join()
120
 
121
 
122
  examples = [