# NOTE(review): removed non-Python extraction artifacts (file-size header,
# blame hashes, line-number gutter) that would break parsing.
import os
import random
from dataset_viber import AnnotatorInterFace
from datasets import load_dataset
from huggingface_hub import InferenceClient
import time
# https://huggingface.co/models?inference=warm&pipeline_tag=text-generation&sort=trending
# Model(s) whose live responses will be pitted against the dataset's "chosen" answers.
MODEL_IDS = [
    "microsoft/Phi-3-mini-4k-instruct"
]
# One inference client per model. Requires the HF_TOKEN environment variable;
# raises KeyError at import time if it is unset.
CLIENTS = [InferenceClient(model_id, token=os.environ["HF_TOKEN"]) for model_id in MODEL_IDS]
# Preference dataset; its "chosen" conversations supply prompts and reference answers.
# NOTE: loaded at import time — this performs network I/O on first run.
dataset = load_dataset("argilla/distilabel-capybara-dpo-7k-binarized", split="train")
def get_response(messages):
    """Request a chat completion for ``messages`` from a randomly chosen client.

    Retries up to ``max_retries`` times, sleeping ``retry_delay`` seconds
    between attempts. Returns the completion text on success; re-raises the
    last error if every attempt fails.
    """
    max_retries = 3
    retry_delay = 3
    for attempt in range(max_retries):
        try:
            # Pick a client at random so load is spread across MODEL_IDS.
            client = random.choice(CLIENTS)
            message = client.chat_completion(
                messages=messages,
                stream=False,
                max_tokens=2000
            )
            return message.choices[0].message.content
        except Exception as e:  # network/API errors from the inference endpoint
            if attempt < max_retries - 1:
                print(f"An error occurred: {e}. Retrying in {retry_delay} seconds...")
                time.sleep(retry_delay)
            else:
                print(f"Max retries reached. Last error: {e}")
                raise
    # No trailing `return None`: the loop always returns a completion or
    # re-raises on the final attempt, so that line was unreachable dead code
    # (and its old comment claiming it was reached on failure was wrong).
def next_input(_prompt, _completion_a, _completion_b):
    """Sample a fresh prompt plus two completions in random order.

    One completion is the dataset's reference ("chosen") answer, the other is
    a live model response; shuffling hides which is which from the annotator.
    The three arguments from the previous round are ignored.
    """
    sample = dataset.shuffle()[0]
    history = sample["chosen"][:-1]
    reference_answer = sample["chosen"][-1]["content"]
    model_answer = get_response(history)
    candidates = [reference_answer, model_answer]
    random.shuffle(candidates)
    first, second = candidates
    # Original popped from the end twice; with a shuffled pair the order is
    # random either way, so unpacking reversed is trace-identical.
    return history, second, first
if __name__ == "__main__":
    # Build the pairwise-preference annotation UI and start serving it.
    # interactive=[False, True, True]: the prompt is fixed, both completions
    # remain editable by the annotator.
    AnnotatorInterFace.for_chat_generation_preference(
        fn_next_input=next_input,
        interactive=[False, True, True],
        dataset_name="dataset-viber-chat-generation-preference-inference-endpoints-battle",
    ).launch()