b2u committed on
Commit
4b117ce
·
1 Parent(s): 51dd92c

adding prompt

Browse files
Files changed (1) hide show
  1. model.py +40 -17
model.py CHANGED
@@ -132,42 +132,65 @@ class T5Model(LabelStudioMLBase):
132
  choices = root.findall('.//Choice')
133
  return [choice.get('value') for choice in choices]
134
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
135
  def predict(self, tasks: List[Dict], context: Optional[Dict] = None, **kwargs) -> ModelResponse:
136
  """Generate predictions using T5 model"""
137
  logger.info("Received prediction request")
138
  logger.info(f"Tasks: {json.dumps(tasks, indent=2)}")
139
- logger.info(f"Context: {json.dumps(context, indent=2) if context else None}")
140
- logger.info(f"Additional kwargs: {kwargs}")
141
 
142
  predictions = []
143
- # Get valid choices during initialization instead of from kwargs
144
- valid_choices = []
145
  try:
146
- import xml.etree.ElementTree as ET
147
- root = ET.fromstring(self.label_config)
148
- choices = root.findall('.//Choice')
149
- valid_choices = [choice.get('value') for choice in choices]
150
  logger.info(f"Valid choices: {valid_choices}")
151
  except Exception as e:
152
  logger.error(f"Error parsing choices: {str(e)}")
153
- valid_choices = ["no_category"] # fallback
 
154
 
155
  try:
156
  for task in tasks:
157
- logger.info(f"Processing task: {json.dumps(task, indent=2)}")
158
-
159
  input_text = task['data'].get(self.to_name)
160
- logger.info(f"Input text: {input_text}")
161
- logger.info(f"Using to_name: {self.to_name}")
162
-
163
  if not input_text:
164
  logger.warning(f"No input text found using {self.to_name}")
165
- logger.warning(f"Available fields in task data: {list(task['data'].keys())}")
166
  continue
167
 
168
- # Generate prediction
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
169
  inputs = self.tokenizer(
170
- input_text,
171
  return_tensors="pt",
172
  max_length=self.max_length,
173
  truncation=True,
 
132
  choices = root.findall('.//Choice')
133
  return [choice.get('value') for choice in choices]
134
 
135
def get_categories_with_hints(self, label_config):
    """Parse the Label Studio label config XML and collect every <Choice>.

    Returns a list of dicts, one per Choice element, each holding the
    element's 'value' and 'hint' attributes (either may be None when the
    attribute is absent on the XML element).
    """
    import xml.etree.ElementTree as ET

    config_root = ET.fromstring(label_config)
    # Find all <Choice> nodes anywhere under the config root and keep
    # their value/hint attribute pairs.
    return [
        {'value': node.get('value'), 'hint': node.get('hint')}
        for node in config_root.findall('.//Choice')
    ]
148
+
149
  def predict(self, tasks: List[Dict], context: Optional[Dict] = None, **kwargs) -> ModelResponse:
150
  """Generate predictions using T5 model"""
151
  logger.info("Received prediction request")
152
  logger.info(f"Tasks: {json.dumps(tasks, indent=2)}")
 
 
153
 
154
  predictions = []
155
+ # Get categories with their descriptions
 
156
  try:
157
+ categories = self.get_categories_with_hints(self.label_config)
158
+ valid_choices = [cat['value'] for cat in categories]
159
+ category_descriptions = [f"{cat['value']}: {cat['hint']}" for cat in categories]
 
160
  logger.info(f"Valid choices: {valid_choices}")
161
  except Exception as e:
162
  logger.error(f"Error parsing choices: {str(e)}")
163
+ valid_choices = ["no_category"]
164
+ category_descriptions = ["no_category: Default category when no others apply"]
165
 
166
  try:
167
  for task in tasks:
 
 
168
  input_text = task['data'].get(self.to_name)
 
 
 
169
  if not input_text:
170
  logger.warning(f"No input text found using {self.to_name}")
 
171
  continue
172
 
173
+ # Format prompt with input text and category descriptions
174
+ prompt = f"""Classify the following text into exactly one category.
175
+
176
+ Available categories with descriptions:
177
+ {chr(10).join(f"- {desc}" for desc in category_descriptions)}
178
+
179
+ Text to classify: {input_text}
180
+
181
+ Instructions:
182
+ 1. Consider the text carefully
183
+ 2. Choose the most appropriate category from the list
184
+ 3. Return ONLY the category value (e.g. 'business_and_career', 'date', etc.)
185
+ 4. Do not add any explanations or additional text
186
+
187
+ Category:"""
188
+
189
+ logger.info(f"Generated prompt: {prompt}")
190
+
191
+ # Generate prediction with prompt
192
  inputs = self.tokenizer(
193
+ prompt,
194
  return_tensors="pt",
195
  max_length=self.max_length,
196
  truncation=True,