celikn commited on
Commit
f591de0
·
verified ·
1 Parent(s): dda1c63

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -1
app.py CHANGED
@@ -82,12 +82,16 @@ def hf_generate(model_name, prompt, max_new_tokens=256, temperature=0.2):
82
  client = InferenceClient(model=model_name, token=HF_TOKEN)
83
  start = time.time()
84
  try:
85
- output = client.text_generation(prompt, max_new_tokens=max_new_tokens, temperature=temperature)
 
 
 
86
  latency = time.time() - start
87
  return output.strip(), latency
88
  except Exception as e:
89
  return f"ERROR: {e}", time.time() - start
90
 
 
91
  # ---------------- Benchmark Function ---------------- #
92
  def benchmark(config_text, dataset_text, task):
93
  cfg = yaml.safe_load(config_text)
 
82
  client = InferenceClient(model=model_name, token=HF_TOKEN)
83
  start = time.time()
84
  try:
85
+ if "flan" in model_name or "t5" in model_name:
86
+ output = client.text2text_generation(prompt, max_new_tokens=max_new_tokens)
87
+ else:
88
+ output = client.text_generation(prompt, max_new_tokens=max_new_tokens, temperature=temperature)
89
  latency = time.time() - start
90
  return output.strip(), latency
91
  except Exception as e:
92
  return f"ERROR: {e}", time.time() - start
93
 
94
+
95
  # ---------------- Benchmark Function ---------------- #
96
  def benchmark(config_text, dataset_text, task):
97
  cfg = yaml.safe_load(config_text)