Yixiao Wang (Computer Science) commited on
Commit
2376772
·
1 Parent(s): 82800a1
Files changed (2) hide show
  1. app.py +3 -0
  2. requirements.txt +1 -0
app.py CHANGED
@@ -10,6 +10,7 @@ from outlines import Generator
10
  from peft import PeftConfig, PeftModel
11
  from pydantic import BaseModel, ConfigDict
12
  from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer, BitsAndBytesConfig
 
13
 
14
  logging.basicConfig(level=logging.INFO)
15
  logger = logging.getLogger(__name__)
@@ -104,6 +105,7 @@ def format_prompt(story: str, question: str, grading_scheme: str, answer: str) -
104
  return full_prompt
105
 
106
 
 
107
  def label_single_response(story, question, criteria, response):
108
  prompt = format_prompt(story, question, criteria, response)
109
 
@@ -122,6 +124,7 @@ def label_single_response(story, question, criteria, response):
122
  return result.score
123
 
124
 
 
125
  def label_multi_responses(story, question, criteria, response_file):
126
  df = pd.read_csv(response_file.name)
127
  assert "response" in df.columns, "CSV must contain a 'response' column."
 
10
  from peft import PeftConfig, PeftModel
11
  from pydantic import BaseModel, ConfigDict
12
  from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer, BitsAndBytesConfig
13
+ import spaces
14
 
15
  logging.basicConfig(level=logging.INFO)
16
  logger = logging.getLogger(__name__)
 
105
  return full_prompt
106
 
107
 
108
+ @spaces.GPU
109
  def label_single_response(story, question, criteria, response):
110
  prompt = format_prompt(story, question, criteria, response)
111
 
 
124
  return result.score
125
 
126
 
127
+ @spaces.GPU
128
  def label_multi_responses(story, question, criteria, response_file):
129
  df = pd.read_csv(response_file.name)
130
  assert "response" in df.columns, "CSV must contain a 'response' column."
requirements.txt CHANGED
@@ -1,4 +1,5 @@
1
  huggingface_hub==0.25.2
 
2
  transformers
3
  gradio
4
  peft
 
1
  huggingface_hub==0.25.2
2
+ spaces
3
  transformers
4
  gradio
5
  peft