Anisha Bhatnagar committed on
Commit
2194877
·
1 Parent(s): 8c133f5

Added structured response generation as OpenAI was truncating feature names

Browse files
Files changed (1) hide show
  1. utils/llm_feat_utils.py +24 -11
utils/llm_feat_utils.py CHANGED
@@ -32,19 +32,20 @@ def generate_feature_spans(client, text: str, features: list[str]) -> str:
32
  """
33
  Call to OpenAI to extract spans. Returns a JSON string.
34
  """
 
 
 
35
  prompt = f"""You are a linguistic specialist. Given a writing sample and a list of descriptive features, identify the exact text spans that demonstrate each feature.
36
 
37
  Important:
38
  - The headers like "Document 1:" etc are NOT part of the original text — ignore them.
39
  - For each feature, even if there is no match, return an empty list.
40
  - Only return exact phrases from the text.
 
41
 
42
- Respond in JSON format like:
43
- {{
44
- "feature1": ["span1", "span2"],
45
- "feature2": [],
46
-
47
- }}
48
 
49
  Text:
50
  \"\"\"{text}\"\"\"
@@ -52,9 +53,9 @@ def generate_feature_spans(client, text: str, features: list[str]) -> str:
52
  Style Features:
53
  {features}
54
  """
55
- print('==================>>>>>>>>>>')
56
- print(prompt)
57
- print('==================>>>>>>>>>>')
58
  response = client.chat.completions.create(
59
  model="gpt-4o",
60
  messages=[{"role":"user","content":prompt}]
@@ -71,8 +72,14 @@ def generate_feature_spans_with_retries(client, text: str, features: list[str])
71
  for attempt in range(MAX_ATTEMPTS):
72
  try:
73
  response_str = generate_feature_spans(client, text, features)
74
- print(response_str)
75
  result = json.loads(response_str)
 
 
 
 
 
 
76
  return result
77
  except (JSONDecodeError, ValueError) as e:
78
  print(f"Attempt {attempt+1} failed: {e}")
@@ -116,7 +123,13 @@ def generate_feature_spans_cached(client, text: str, features: list[str], role:
116
  if h in cache:
117
  # print(f"Found feature: {feat}")
118
  found_feats_count += 1
119
- result[feat] = cache[h]["spans"]
 
 
 
 
 
 
120
  else:
121
  # print(f"Missing feature: {feat}")
122
  missing_feats_count += 1
 
32
  """
33
  Call to OpenAI to extract spans. Returns a JSON string.
34
  """
35
+ # For some of the longer features, openai client was truncating the feature names, resulting in downstream errors.
36
+ # Adding structured JSON template to ensure all features are included properly.
37
+ features_json_template = {feature: [] for feature in features}
38
  prompt = f"""You are a linguistic specialist. Given a writing sample and a list of descriptive features, identify the exact text spans that demonstrate each feature.
39
 
40
  Important:
41
  - The headers like "Document 1:" etc are NOT part of the original text — ignore them.
42
  - For each feature, even if there is no match, return an empty list.
43
  - Only return exact phrases from the text.
44
+ - Use the EXACT feature names as JSON keys - do not paraphrase or shorten them.
45
 
46
+
47
+ Respond in this EXACT JSON format (use these exact keys, populate the lists with the extracted text spans):
48
+ {json.dumps(features_json_template, indent=2)}
 
 
 
49
 
50
  Text:
51
  \"\"\"{text}\"\"\"
 
53
  Style Features:
54
  {features}
55
  """
56
+ # print('==================>>>>>>>>>>')
57
+ # print(prompt)
58
+ # print('==================>>>>>>>>>>')
59
  response = client.chat.completions.create(
60
  model="gpt-4o",
61
  messages=[{"role":"user","content":prompt}]
 
72
  for attempt in range(MAX_ATTEMPTS):
73
  try:
74
  response_str = generate_feature_spans(client, text, features)
75
+ # print(response_str)
76
  result = json.loads(response_str)
77
+ # Additional check to ensure all requested features are present in the response correctly
78
+ if result.keys() != set(features):
79
+ print("Response keys do not match requested features. Retrying!")
80
+ response_str = generate_feature_spans(client, text, features)
81
+ # print(response_str)
82
+ result = json.loads(response_str)
83
  return result
84
  except (JSONDecodeError, ValueError) as e:
85
  print(f"Attempt {attempt+1} failed: {e}")
 
123
  if h in cache:
124
  # print(f"Found feature: {feat}")
125
  found_feats_count += 1
126
+ if cache[h]["spans"] is None:
127
+ print(f"Missing feature: {feat}")
128
+ missing_feats_count += 1
129
+ missing_feats.append(feat)
130
+ else:
131
+ result[feat] = cache[h]["spans"]
132
+
133
  else:
134
  # print(f"Missing feature: {feat}")
135
  missing_feats_count += 1