Spaces:
Paused
Paused
adding prompt
Browse files
model.py
CHANGED
|
@@ -132,42 +132,65 @@ class T5Model(LabelStudioMLBase):
|
|
| 132 |
choices = root.findall('.//Choice')
|
| 133 |
return [choice.get('value') for choice in choices]
|
| 134 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 135 |
def predict(self, tasks: List[Dict], context: Optional[Dict] = None, **kwargs) -> ModelResponse:
|
| 136 |
"""Generate predictions using T5 model"""
|
| 137 |
logger.info("Received prediction request")
|
| 138 |
logger.info(f"Tasks: {json.dumps(tasks, indent=2)}")
|
| 139 |
-
logger.info(f"Context: {json.dumps(context, indent=2) if context else None}")
|
| 140 |
-
logger.info(f"Additional kwargs: {kwargs}")
|
| 141 |
|
| 142 |
predictions = []
|
| 143 |
-
# Get
|
| 144 |
-
valid_choices = []
|
| 145 |
try:
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
valid_choices = [choice.get('value') for choice in choices]
|
| 150 |
logger.info(f"Valid choices: {valid_choices}")
|
| 151 |
except Exception as e:
|
| 152 |
logger.error(f"Error parsing choices: {str(e)}")
|
| 153 |
-
valid_choices = ["no_category"]
|
|
|
|
| 154 |
|
| 155 |
try:
|
| 156 |
for task in tasks:
|
| 157 |
-
logger.info(f"Processing task: {json.dumps(task, indent=2)}")
|
| 158 |
-
|
| 159 |
input_text = task['data'].get(self.to_name)
|
| 160 |
-
logger.info(f"Input text: {input_text}")
|
| 161 |
-
logger.info(f"Using to_name: {self.to_name}")
|
| 162 |
-
|
| 163 |
if not input_text:
|
| 164 |
logger.warning(f"No input text found using {self.to_name}")
|
| 165 |
-
logger.warning(f"Available fields in task data: {list(task['data'].keys())}")
|
| 166 |
continue
|
| 167 |
|
| 168 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 169 |
inputs = self.tokenizer(
|
| 170 |
-
|
| 171 |
return_tensors="pt",
|
| 172 |
max_length=self.max_length,
|
| 173 |
truncation=True,
|
|
|
|
| 132 |
choices = root.findall('.//Choice')
|
| 133 |
return [choice.get('value') for choice in choices]
|
| 134 |
|
| 135 |
+
def get_categories_with_hints(self, label_config):
|
| 136 |
+
"""Extract categories and their hints from label config"""
|
| 137 |
+
import xml.etree.ElementTree as ET
|
| 138 |
+
|
| 139 |
+
root = ET.fromstring(label_config)
|
| 140 |
+
choices = root.findall('.//Choice')
|
| 141 |
+
categories = []
|
| 142 |
+
for choice in choices:
|
| 143 |
+
categories.append({
|
| 144 |
+
'value': choice.get('value'),
|
| 145 |
+
'hint': choice.get('hint')
|
| 146 |
+
})
|
| 147 |
+
return categories
|
| 148 |
+
|
| 149 |
def predict(self, tasks: List[Dict], context: Optional[Dict] = None, **kwargs) -> ModelResponse:
|
| 150 |
"""Generate predictions using T5 model"""
|
| 151 |
logger.info("Received prediction request")
|
| 152 |
logger.info(f"Tasks: {json.dumps(tasks, indent=2)}")
|
|
|
|
|
|
|
| 153 |
|
| 154 |
predictions = []
|
| 155 |
+
# Get categories with their descriptions
|
|
|
|
| 156 |
try:
|
| 157 |
+
categories = self.get_categories_with_hints(self.label_config)
|
| 158 |
+
valid_choices = [cat['value'] for cat in categories]
|
| 159 |
+
category_descriptions = [f"{cat['value']}: {cat['hint']}" for cat in categories]
|
|
|
|
| 160 |
logger.info(f"Valid choices: {valid_choices}")
|
| 161 |
except Exception as e:
|
| 162 |
logger.error(f"Error parsing choices: {str(e)}")
|
| 163 |
+
valid_choices = ["no_category"]
|
| 164 |
+
category_descriptions = ["no_category: Default category when no others apply"]
|
| 165 |
|
| 166 |
try:
|
| 167 |
for task in tasks:
|
|
|
|
|
|
|
| 168 |
input_text = task['data'].get(self.to_name)
|
|
|
|
|
|
|
|
|
|
| 169 |
if not input_text:
|
| 170 |
logger.warning(f"No input text found using {self.to_name}")
|
|
|
|
| 171 |
continue
|
| 172 |
|
| 173 |
+
# Format prompt with input text and category descriptions
|
| 174 |
+
prompt = f"""Classify the following text into exactly one category.
|
| 175 |
+
|
| 176 |
+
Available categories with descriptions:
|
| 177 |
+
{chr(10).join(f"- {desc}" for desc in category_descriptions)}
|
| 178 |
+
|
| 179 |
+
Text to classify: {input_text}
|
| 180 |
+
|
| 181 |
+
Instructions:
|
| 182 |
+
1. Consider the text carefully
|
| 183 |
+
2. Choose the most appropriate category from the list
|
| 184 |
+
3. Return ONLY the category value (e.g. 'business_and_career', 'date', etc.)
|
| 185 |
+
4. Do not add any explanations or additional text
|
| 186 |
+
|
| 187 |
+
Category:"""
|
| 188 |
+
|
| 189 |
+
logger.info(f"Generated prompt: {prompt}")
|
| 190 |
+
|
| 191 |
+
# Generate prediction with prompt
|
| 192 |
inputs = self.tokenizer(
|
| 193 |
+
prompt,
|
| 194 |
return_tensors="pt",
|
| 195 |
max_length=self.max_length,
|
| 196 |
truncation=True,
|