iqasimz committed on
Commit
da2b5de
·
verified ·
1 Parent(s): 2f74bd4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +106 -85
app.py CHANGED
@@ -1,19 +1,16 @@
1
- # app.py
2
  import os
3
  import json
4
  import warnings
5
- import re
6
  import torch
7
  import gradio as gr
8
  import spaces
9
  from transformers import AutoTokenizer, AutoModelForCausalLM
10
- from transformers import StoppingCriteria, StoppingCriteriaList
11
 
12
  # ---------- CONFIG ----------
13
  os.environ.setdefault("GRADIO_SERVER_PORT", "7860")
14
- MODEL_PATH = "iqasimz/g1" # <- change to your repo or local dir
15
  MAX_NEW_TOKENS_DEFAULT = 300
16
- TEMPERATURE_DEFAULT = 0.2
17
  TOP_P_DEFAULT = 1.0
18
  # ---------------------------
19
 
@@ -46,12 +43,12 @@ def load_model_to_cpu(model_dir: str):
46
 
47
  def build_inference_prompt(paragraph: str) -> str:
48
  # Match your training format EXACTLY (Task + Rules + Paragraph in user turn)
49
- task_block = """Task: You are an expert argument analyst. Number the sentences in the paragraph and tag the role of each one.
50
- Rules:
51
- - Do NOT change the text of any sentence.
52
- - Keep the original order.
53
- - Output exactly N lines, one per sentence.
54
- - Each line must be: "<index> <original sentence> <role>", where role ∈ {claim, premise, none}.
55
  - Do not add any explanations or extra text after the Nth line.
56
  """
57
  # Chat-style formatting used during training
@@ -60,71 +57,111 @@ Rules:
60
  f"<|im_end|>\n<|im_start|>assistant\n"
61
  )
62
 
63
- # -------- Sentence counting for N --------
64
- SENT_SPLIT_RE = re.compile(r'(?<!\b[A-Z])(?<=[.!?])\s+(?=\S)')
 
 
65
 
66
- def count_sentences(paragraph: str) -> int:
67
- p = (paragraph or "").strip()
68
- if not p:
69
- return 0
70
- parts = [s.strip() for s in SENT_SPLIT_RE.split(p) if s.strip()]
71
- return max(1, len(parts))
72
-
73
- # -------- Stopping criteria to halt after N labeled lines --------
74
- class RoleLinesStop(StoppingCriteria):
75
  """
76
- Stop when we've generated N lines that look like:
77
- <index> <original sentence> <role>
78
- with role ∈ {claim, premise, none}.
79
- Also stops if the model begins line N+1 (e.g., "N+1 ").
80
  """
81
- def __init__(self, tokenizer, prompt_len: int, n_lines: int):
82
- self.tok = tokenizer
83
- self.prompt_len = prompt_len
84
- self.n_lines = n_lines
85
- self.role_line_re = re.compile(
86
- r'^\s*\d+\s+.+\s+(?:claim|premise|none)\s*$', re.IGNORECASE | re.MULTILINE
87
- )
88
- self.next_index_re = re.compile(rf'^\s*{n_lines+1}\s', re.MULTILINE) if n_lines >= 1 else None
89
-
90
- def __call__(self, input_ids, scores, **kwargs) -> bool:
91
- gen_ids = input_ids[0, self.prompt_len:]
92
- if gen_ids.numel() == 0:
93
- return False
94
- text = self.tok.decode(gen_ids, skip_special_tokens=True)
95
-
96
- # If we see the start of line N+1, stop immediately
97
- if self.next_index_re and self.next_index_re.search(text):
98
- return True
99
-
100
- # Count complete role-tagged lines
101
- complete_lines = self.role_line_re.findall(text)
102
- return len(complete_lines) >= self.n_lines
103
-
104
- def parse_numbered_lines(text: str):
105
  """
106
- Optional: parse lines like:
107
- 1 Some sentence. claim
108
- 2 Another sentence. premise
109
- into a list of dicts.
110
  """
111
  results = []
112
- for line in text.splitlines():
 
 
 
 
 
 
 
 
 
 
 
 
 
113
  line = line.strip()
114
  if not line or not line[0].isdigit():
115
  continue
 
116
  try:
117
- # index first
118
  space_after_idx = line.find(" ")
 
 
 
119
  idx = int(line[:space_after_idx])
120
  rest = line[space_after_idx + 1:].rstrip()
121
- # last space => role
122
- last_space = rest.rfind(" ")
123
- sent = rest[:last_space].strip()
124
- role = rest[last_space + 1:].strip().lower()
125
- results.append({"index": idx, "sentence": sent, "role": role})
126
- except Exception:
127
- pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
128
  return results
129
 
130
  @spaces.GPU(duration=120)
@@ -139,26 +176,16 @@ def analyze(paragraph: str, max_new_tokens: int, temperature: float, top_p: floa
139
  prompt = build_inference_prompt(paragraph)
140
  inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
141
 
142
- # Compute target number of lines (N) and install stopping criteria
143
- n_lines = count_sentences(paragraph)
144
- stopper = RoleLinesStop(
145
- tokenizer=tokenizer,
146
- prompt_len=inputs["input_ids"].shape[1],
147
- n_lines=n_lines
148
- )
149
- stops = StoppingCriteriaList([stopper])
150
-
151
  with torch.inference_mode():
152
  output = model.generate(
153
  **inputs,
154
  max_new_tokens=int(max_new_tokens),
155
  temperature=float(temperature),
156
  top_p=float(top_p),
157
- do_sample=(float(temperature) > 0.0), # sampling only if temp > 0
158
  pad_token_id=tokenizer.pad_token_id,
159
  eos_token_id=tokenizer.eos_token_id,
160
  use_cache=True,
161
- stopping_criteria=stops,
162
  )
163
 
164
  full = tokenizer.decode(output[0], skip_special_tokens=False)
@@ -170,15 +197,8 @@ def analyze(paragraph: str, max_new_tokens: int, temperature: float, top_p: floa
170
  else:
171
  resp = full.strip()
172
 
173
- # Safety net: hard-trim to exactly N labeled lines if model leaked extras
174
- role_line_re = re.compile(r'^\s*\d+\s+.+\s+(?:claim|premise|none)\s*$', re.IGNORECASE | re.MULTILINE)
175
- matched = role_line_re.findall(resp)
176
- if matched:
177
- trimmed = "\n".join(matched[:n_lines]).strip()
178
- if trimmed:
179
- resp = trimmed
180
-
181
- parsed = parse_numbered_lines(resp)
182
  parsed_json = json.dumps(parsed, ensure_ascii=False, indent=2) if show_parsed else ""
183
  return resp, parsed_json
184
 
@@ -218,9 +238,10 @@ def launch_app():
218
  )
219
 
220
  gr.Markdown("### Tips")
221
- gr.Markdown("- Set `MODEL_PATH` at the top to your merged model repo or local path.\n"
222
  "- For deterministic outputs, set Temperature=0.0 and Top-p=1.0.\n"
223
- "- Output is forcibly stopped after exactly N lines.")
 
224
 
225
  return demo
226
 
 
 
1
  import os
2
  import json
3
  import warnings
 
4
  import torch
5
  import gradio as gr
6
  import spaces
7
  from transformers import AutoTokenizer, AutoModelForCausalLM
 
8
 
9
  # ---------- CONFIG ----------
10
  os.environ.setdefault("GRADIO_SERVER_PORT", "7860")
11
+ MODEL_PATH = "iqasimz/g2" # <- change to your repo or local dir
12
  MAX_NEW_TOKENS_DEFAULT = 300
13
+ TEMPERATURE_DEFAULT = 0
14
  TOP_P_DEFAULT = 1.0
15
  # ---------------------------
16
 
 
43
 
44
  def build_inference_prompt(paragraph: str) -> str:
45
  # Match your training format EXACTLY (Task + Rules + Paragraph in user turn)
46
+ task_block = """Task: You are an expert argument analyst. Identify the role of each sentence within the context of the paragraph/debate/article like a true linguistics and argument expert. Number the sentences in the paragraph and tag the role of each one.\n
47
+ Rules:\n
48
+ - Do NOT change the text of any sentence.\n
49
+ - Keep the original order.\n
50
+ - Output exactly N lines, one per sentence.\n
51
+ - Each line must be: "<index> <original sentence> <role>", where role ∈ {claim, premise, none}.\n
52
  - Do not add any explanations or extra text after the Nth line.
53
  """
54
  # Chat-style formatting used during training
 
57
  f"<|im_end|>\n<|im_start|>assistant\n"
58
  )
59
 
60
+ def get_last_five_words(text: str) -> str:
61
+ """Get the last 5 words from a text string."""
62
+ words = text.strip().split()
63
+ return " ".join(words[-5:]) if len(words) >= 5 else " ".join(words)
64
 
65
+ def extract_role_from_suffix(text_after_match: str) -> str:
 
 
 
 
 
 
 
 
66
  """
67
+ Extract role (claim, premise, none) from text after the 5-word match.
68
+ Handles cases like 'claimabcd' -> 'claim'
 
 
69
  """
70
+ text_after_match = text_after_match.strip()
71
+
72
+ # Look for the role words at the start of the remaining text
73
+ role_words = ['claim', 'premise', 'none']
74
+ for role in role_words:
75
+ if text_after_match.lower().startswith(role.lower()):
76
+ return role
77
+
78
+ # If no exact match, return the first word (fallback)
79
+ first_word = text_after_match.split()[0] if text_after_match.split() else ""
80
+ for role in role_words:
81
+ if first_word.lower().startswith(role.lower()):
82
+ return role
83
+
84
+ return "none" # default fallback
85
+
86
+ def parse_numbered_lines(text: str, original_paragraph: str):
 
 
 
 
 
 
 
87
  """
88
+ Enhanced parsing with improved stopping criteria:
89
+ 1. Find exact match of last 5 words from input paragraph
90
+ 2. Look for role word after a space following the match
91
+ 3. Stop parsing after finding the last sentence to avoid gibberish
92
  """
93
  results = []
94
+ lines = text.splitlines()
95
+
96
+ # Get sentences from original paragraph for reference
97
+ import re
98
+ sentences = re.split(r'[.!?]+', original_paragraph.strip())
99
+ sentences = [s.strip() for s in sentences if s.strip()]
100
+
101
+ if not sentences:
102
+ return results
103
+
104
+ # Get last 5 words of the original paragraph
105
+ last_five_words = get_last_five_words(original_paragraph)
106
+
107
+ for line in lines:
108
  line = line.strip()
109
  if not line or not line[0].isdigit():
110
  continue
111
+
112
  try:
113
+ # Parse index
114
  space_after_idx = line.find(" ")
115
+ if space_after_idx == -1:
116
+ continue
117
+
118
  idx = int(line[:space_after_idx])
119
  rest = line[space_after_idx + 1:].rstrip()
120
+
121
+ # Check if this line contains the last 5 words (indicating last sentence)
122
+ if last_five_words.lower() in rest.lower():
123
+ # Find the position of the last 5 words
124
+ match_pos = rest.lower().find(last_five_words.lower())
125
+ if match_pos != -1:
126
+ # Extract sentence (everything up to and including the match)
127
+ sentence_end = match_pos + len(last_five_words)
128
+ sent = rest[:sentence_end].strip()
129
+
130
+ # Look for role after the match
131
+ text_after_match = rest[sentence_end:].strip()
132
+ role = "none" # default
133
+
134
+ if text_after_match:
135
+ # Skip any immediate punctuation/spaces and look for role
136
+ text_after_match = text_after_match.lstrip(' .,!?')
137
+ role = extract_role_from_suffix(text_after_match)
138
+
139
+ results.append({"index": idx, "sentence": sent, "role": role})
140
+
141
+ # STOP parsing here - this is the last sentence
142
+ break
143
+ else:
144
+ # Regular parsing for non-last sentences
145
+ last_space = rest.rfind(" ")
146
+ if last_space == -1:
147
+ continue
148
+
149
+ sent = rest[:last_space].strip()
150
+ role_candidate = rest[last_space + 1:].strip().lower()
151
+
152
+ # Clean role (handle gibberish suffixes)
153
+ role = "none"
154
+ for valid_role in ['claim', 'premise', 'none']:
155
+ if role_candidate.startswith(valid_role):
156
+ role = valid_role
157
+ break
158
+
159
+ results.append({"index": idx, "sentence": sent, "role": role})
160
+
161
+ except Exception as e:
162
+ print(f"Error parsing line '{line}': {e}")
163
+ continue
164
+
165
  return results
166
 
167
  @spaces.GPU(duration=120)
 
176
  prompt = build_inference_prompt(paragraph)
177
  inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
178
 
 
 
 
 
 
 
 
 
 
179
  with torch.inference_mode():
180
  output = model.generate(
181
  **inputs,
182
  max_new_tokens=int(max_new_tokens),
183
  temperature=float(temperature),
184
  top_p=float(top_p),
185
+ do_sample=(temperature > 0.0 and top_p < 1.0),
186
  pad_token_id=tokenizer.pad_token_id,
187
  eos_token_id=tokenizer.eos_token_id,
188
  use_cache=True,
 
189
  )
190
 
191
  full = tokenizer.decode(output[0], skip_special_tokens=False)
 
197
  else:
198
  resp = full.strip()
199
 
200
+ # Updated parsing with original paragraph reference
201
+ parsed = parse_numbered_lines(resp, paragraph)
 
 
 
 
 
 
 
202
  parsed_json = json.dumps(parsed, ensure_ascii=False, indent=2) if show_parsed else ""
203
  return resp, parsed_json
204
 
 
238
  )
239
 
240
  gr.Markdown("### Tips")
241
+ gr.Markdown("- Set MODEL_PATH at the top to your merged model repo or local path.\n"
242
  "- For deterministic outputs, set Temperature=0.0 and Top-p=1.0.\n"
243
+ "- Your training format (chat tokens + Task/Rules) is preserved in the prompt.\n"
244
+ "- **Enhanced parsing**: Stops at last sentence using 5-word match to avoid gibberish.")
245
 
246
  return demo
247