davidberenstein1957 committed on
Commit
26647e2
·
verified ·
1 Parent(s): d6ff8e0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -9
app.py CHANGED
@@ -4,24 +4,40 @@ import random
4
  from dataset_viber import AnnotatorInterFace
5
  from datasets import load_dataset
6
  from huggingface_hub import InferenceClient
 
7
 
8
  # https://huggingface.co/models?inference=warm&pipeline_tag=text-generation&sort=trending
9
  MODEL_IDS = [
10
- "sci-m-wang/Phi-3-mini-4k-instruct-sa-v0.1",
11
  "microsoft/Phi-3-mini-4k-instruct"
12
  ]
13
  CLIENTS = [InferenceClient(model_id, token=os.environ["HF_TOKEN"]) for model_id in MODEL_IDS]
14
 
15
  dataset = load_dataset("argilla/magpie-ultra-v0.1", split="train")
16
 
17
def _get_response(messages):
    """Forward *messages* to one randomly selected inference client and return its reply text."""
    chosen_client = random.choice(CLIENTS)
    completion = chosen_client.chat_completion(
        messages=messages, stream=False, max_tokens=2000
    )
    return completion.choices[0].message.content
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
  def next_input(_prompt, _completion_a, _completion_b):
27
  new_dataset = dataset.shuffle()
 
4
  from dataset_viber import AnnotatorInterFace
5
  from datasets import load_dataset
6
  from huggingface_hub import InferenceClient
7
+ import time
8
 
9
  # https://huggingface.co/models?inference=warm&pipeline_tag=text-generation&sort=trending
10
  MODEL_IDS = [
11
+ "google/gemma-2b-it",
12
  "microsoft/Phi-3-mini-4k-instruct"
13
  ]
14
  CLIENTS = [InferenceClient(model_id, token=os.environ["HF_TOKEN"]) for model_id in MODEL_IDS]
15
 
16
  dataset = load_dataset("argilla/magpie-ultra-v0.1", split="train")
17
 
18
def get_response(messages):
    """Request a chat completion from a randomly chosen inference client, with retries.

    Makes up to ``max_retries`` attempts, sleeping ``retry_delay`` seconds between
    failures; the last failure is re-raised instead of being swallowed.

    Args:
        messages: Chat history passed through to ``InferenceClient.chat_completion``
            (OpenAI-style ``[{"role": ..., "content": ...}]`` dicts — see the
            huggingface_hub docs).

    Returns:
        str: The assistant message content of the first successful completion.

    Raises:
        Exception: Whatever ``chat_completion`` last raised, once all retries fail.
    """
    max_retries = 3
    retry_delay = 3

    for attempt in range(max_retries):
        try:
            # Pick an endpoint at random to spread load across the configured models.
            client = random.choice(CLIENTS)
            message = client.chat_completion(
                messages=messages,
                stream=False,
                max_tokens=2000,
            )
            return message.choices[0].message.content
        # Broad on purpose: InferenceClient can surface several transport/HTTP
        # error types, and every attempt except the last is best-effort.
        except Exception as e:
            if attempt < max_retries - 1:
                print(f"An error occurred: {e}. Retrying in {retry_delay} seconds...")
                time.sleep(retry_delay)
            else:
                print(f"Max retries reached. Last error: {e}")
                raise
    # No fallthrough return: the loop either returns a completion or re-raises,
    # so the original `return None` here was unreachable dead code and is removed.
41
 
42
  def next_input(_prompt, _completion_a, _completion_b):
43
  new_dataset = dataset.shuffle()