Spaces:
Runtime error
Runtime error
kz209
committed on
Commit
·
9dfac6e
1
Parent(s):
34ffea3
update
Browse files
- pages/arena.py +1 -1
- utils/model.py +2 -2
pages/arena.py
CHANGED
|
@@ -22,7 +22,7 @@ def create_arena():
|
|
| 22 |
submit_button = gr.Button("✨ Submit ✨")
|
| 23 |
|
| 24 |
with gr.Row():
|
| 25 |
-
columns = [gr.Textbox(label=f"
|
| 26 |
|
| 27 |
content_list = [prompt + '\n{' + datapoint + '}\n\nsummary:' for prompt in prompts]
|
| 28 |
model = get_model_batch_generation("Qwen/Qwen2-1.5B-Instruct")
|
|
|
|
| 22 |
submit_button = gr.Button("✨ Submit ✨")
|
| 23 |
|
| 24 |
with gr.Row():
|
| 25 |
+
columns = [gr.Textbox(label=f"Prompt {i+1}", lines=10) for i in range(len(prompts))]
|
| 26 |
|
| 27 |
content_list = [prompt + '\n{' + datapoint + '}\n\nsummary:' for prompt in prompts]
|
| 28 |
model = get_model_batch_generation("Qwen/Qwen2-1.5B-Instruct")
|
utils/model.py
CHANGED
|
@@ -55,7 +55,7 @@ class Model(torch.nn.Module):
|
|
| 55 |
def return_model(self):
|
| 56 |
return self.pipeline
|
| 57 |
|
| 58 |
-
def gen(self, content_list, temp=0.
|
| 59 |
# Convert list of texts to input IDs
|
| 60 |
input_ids = self.tokenizer(content_list, return_tensors="pt", padding=True, truncation=True).input_ids.to(self.model.device)
|
| 61 |
|
|
@@ -74,7 +74,7 @@ class Model(torch.nn.Module):
|
|
| 74 |
return_dict_in_generate=True,
|
| 75 |
output_scores=True,
|
| 76 |
streamer=streamer):
|
| 77 |
-
|
| 78 |
else:
|
| 79 |
outputs = self.model.generate(
|
| 80 |
input_ids,
|
|
|
|
| 55 |
def return_model(self):
|
| 56 |
return self.pipeline
|
| 57 |
|
| 58 |
+
def gen(self, content_list, temp=0.001, max_length=500, streaming=False):
|
| 59 |
# Convert list of texts to input IDs
|
| 60 |
input_ids = self.tokenizer(content_list, return_tensors="pt", padding=True, truncation=True).input_ids.to(self.model.device)
|
| 61 |
|
|
|
|
| 74 |
return_dict_in_generate=True,
|
| 75 |
output_scores=True,
|
| 76 |
streamer=streamer):
|
| 77 |
+
yield output # TextStreamer automatically handles the streaming, no need to manually handle the output
|
| 78 |
else:
|
| 79 |
outputs = self.model.generate(
|
| 80 |
input_ids,
|