iqasimz commited on
Commit
2f74bd4
·
verified ·
1 Parent(s): 716ad32

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +229 -0
app.py CHANGED
@@ -0,0 +1,229 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+ import os
3
+ import json
4
+ import warnings
5
+ import re
6
+ import torch
7
+ import gradio as gr
8
+ import spaces
9
+ from transformers import AutoTokenizer, AutoModelForCausalLM
10
+ from transformers import StoppingCriteria, StoppingCriteriaList
11
+
12
# ---------- CONFIG ----------
# Default Gradio port; setdefault so an externally-provided port wins.
os.environ.setdefault("GRADIO_SERVER_PORT", "7860")
MODEL_PATH = "iqasimz/g1"  # <- change to your repo or local dir
MAX_NEW_TOKENS_DEFAULT = 300   # UI default: generation length cap
TEMPERATURE_DEFAULT = 0.2      # UI default: sampling temperature (0 => greedy)
TOP_P_DEFAULT = 1.0            # UI default: nucleus-sampling cutoff (1.0 => disabled)
# ---------------------------

# Silence noisy torch warnings in the Space logs.
warnings.filterwarnings("ignore", module="torch")
# Process-wide cache: model_dir -> (tokenizer, model); each model is loaded once on CPU.
_model_cache = {}
22
+
23
+ def _ensure_pad_token(tokenizer):
24
+ if tokenizer.pad_token is None:
25
+ tokenizer.pad_token = tokenizer.eos_token
26
+ return tokenizer
27
+
28
def load_model_to_cpu(model_dir: str):
    """Load tokenizer+model once on CPU; moved to GPU per request via @spaces.GPU.

    Returns the cached ``(tokenizer, model)`` pair for *model_dir*, loading
    and caching it on first use.
    """
    cached = _model_cache.get(model_dir)
    if cached is not None:
        return cached

    tok = _ensure_pad_token(
        AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
    )
    mdl = AutoModelForCausalLM.from_pretrained(
        model_dir,
        trust_remote_code=True,
        torch_dtype=torch.float16,  # fp16 weights; compute happens once moved to GPU
        device_map=None,            # keep on CPU so the cached copy can be re-moved per request
    )
    mdl.eval()  # inference only: disable dropout etc.
    _model_cache[model_dir] = (tok, mdl)
    print(f"[cache] Loaded {model_dir} on CPU")
    return tok, mdl
46
+
47
def build_inference_prompt(paragraph: str) -> str:
    """Wrap *paragraph* in the exact chat template used during fine-tuning.

    The Task + Rules block and the <|im_start|>/<|im_end|> markers must match
    the training data verbatim, so the strings below are reproduced exactly.
    """
    task_block = (
        "Task: You are an expert argument analyst. Number the sentences in the paragraph and tag the role of each one.\n"
        "Rules:\n"
        "- Do NOT change the text of any sentence.\n"
        "- Keep the original order.\n"
        "- Output exactly N lines, one per sentence.\n"
        '- Each line must be: "<index> <original sentence> <role>", where role ∈ {claim, premise, none}.\n'
        "- Do not add any explanations or extra text after the Nth line.\n"
    )
    user_turn = f"<|im_start|>user\n{task_block}\nParagraph:\n{paragraph}<|im_end|>\n"
    # Open the assistant turn so the model continues with the tagged lines.
    return user_turn + "<|im_start|>assistant\n"
62
+
63
# -------- Sentence counting for N --------
# Split after '.', '!' or '?' followed by whitespace.  The negative lookbehind
# skips single-letter initials such as "J." so "J. Smith" stays one sentence.
# (The previous lookbehind `(?<!\b[A-Z])` inspected the same character as
# `(?<=[.!?])` — which is always punctuation, never [A-Z] — so it never fired.)
SENT_SPLIT_RE = re.compile(r'(?<!\b[A-Z]\.)(?<=[.!?])\s+(?=\S)')

def count_sentences(paragraph: str) -> int:
    """Return the number of sentences in *paragraph*.

    Returns 0 for empty/None input; otherwise at least 1, even if the text
    contains no terminal punctuation.
    """
    p = (paragraph or "").strip()
    if not p:
        return 0
    parts = [s.strip() for s in SENT_SPLIT_RE.split(p) if s.strip()]
    return max(1, len(parts))
72
+
73
# -------- Stopping criteria to halt after N labeled lines --------
class RoleLinesStop(StoppingCriteria):
    """Halt generation once N role-tagged lines have been produced.

    A role-tagged line looks like ``<index> <original sentence> <role>`` with
    role in {claim, premise, none}.  Generation also stops the moment the
    model begins line N+1 (a leading "N+1 " marker), so nothing leaks past
    the Nth line.
    """

    def __init__(self, tokenizer, prompt_len: int, n_lines: int):
        self.tok = tokenizer
        self.prompt_len = prompt_len  # number of prompt tokens, skipped before decoding
        self.n_lines = n_lines        # target count of labeled lines
        # One complete labeled line, matched per line of the decoded text.
        self.role_line_re = re.compile(
            r'^\s*\d+\s+.+\s+(?:claim|premise|none)\s*$', re.IGNORECASE | re.MULTILINE
        )
        # Start-of-line marker for index N+1; disabled when N < 1.
        self.next_index_re = re.compile(rf'^\s*{n_lines + 1}\s', re.MULTILINE) if n_lines >= 1 else None

    def __call__(self, input_ids, scores, **kwargs) -> bool:
        generated = input_ids[0, self.prompt_len:]
        if generated.numel() == 0:
            # Nothing generated yet, keep going.
            return False
        decoded = self.tok.decode(generated, skip_special_tokens=True)

        # The model has started line N+1 -> cut it off immediately.
        if self.next_index_re is not None and self.next_index_re.search(decoded):
            return True

        # Otherwise stop once N complete role-tagged lines exist.
        return len(self.role_line_re.findall(decoded)) >= self.n_lines
103
+
104
# Matches "<index> <sentence> <role>": index first, role is the final
# whitespace-separated token, sentence is everything in between.
_NUMBERED_LINE_RE = re.compile(r'^(\d+)\s+(.+?)\s+(\S+)\s*$')

def parse_numbered_lines(text: str):
    """Parse model output lines of the form ``<index> <sentence> <role>``.

    Example::

        1 Some sentence. claim
        2 Another sentence. premise

    Returns a list of ``{"index": int, "sentence": str, "role": str}`` dicts
    in input order, with the role lowercased.  Lines missing any of the three
    parts (e.g. a bare number, or a sentence with no trailing role token) are
    skipped.  The previous ``find``/``rfind`` slicing returned -1 for missing
    separators and silently produced corrupt entries in those cases.
    """
    results = []
    for raw in (text or "").splitlines():
        m = _NUMBERED_LINE_RE.match(raw.strip())
        if m:
            idx, sent, role = m.groups()
            results.append({"index": int(idx), "sentence": sent, "role": role.lower()})
    return results
129
+
130
@spaces.GPU(duration=120)
def analyze(paragraph: str, max_new_tokens: int, temperature: float, top_p: float, show_parsed: bool) -> tuple[str, str]:
    """Tag each sentence of *paragraph* with a role (claim/premise/none).

    Runs under ZeroGPU: the decorator grants a GPU for up to 120s per call,
    and the CPU-cached model is moved to CUDA inside the call.

    Returns a ``(raw_output, parsed_json)`` pair; ``parsed_json`` is "" when
    *show_parsed* is False.
    """
    paragraph = (paragraph or "").strip()
    if not paragraph:
        # Friendly message instead of running the model on empty input.
        return "Please paste a paragraph.", ""

    tokenizer, model = load_model_to_cpu(MODEL_PATH)
    model = model.to("cuda")

    prompt = build_inference_prompt(paragraph)
    inputs = tokenizer(prompt, return_tensors="pt").to("cuda")

    # Compute target number of lines (N) and install stopping criteria
    n_lines = count_sentences(paragraph)
    stopper = RoleLinesStop(
        tokenizer=tokenizer,
        prompt_len=inputs["input_ids"].shape[1],
        n_lines=n_lines
    )
    stops = StoppingCriteriaList([stopper])

    with torch.inference_mode():
        output = model.generate(
            **inputs,
            max_new_tokens=int(max_new_tokens),
            temperature=float(temperature),
            top_p=float(top_p),
            do_sample=(float(temperature) > 0.0),  # sampling only if temp > 0; greedy otherwise
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            use_cache=True,
            stopping_criteria=stops,
        )

    # Keep special tokens so the assistant-turn markers below can be located.
    full = tokenizer.decode(output[0], skip_special_tokens=False)

    # Extract assistant segment
    if "<|im_start|>assistant\n" in full:
        resp = full.split("<|im_start|>assistant\n")[-1]
        resp = resp.split("<|im_end|>")[0].strip()
    else:
        # No chat markers found (unexpected) — fall back to the whole decode.
        resp = full.strip()

    # Safety net: hard-trim to exactly N labeled lines if model leaked extras
    role_line_re = re.compile(r'^\s*\d+\s+.+\s+(?:claim|premise|none)\s*$', re.IGNORECASE | re.MULTILINE)
    matched = role_line_re.findall(resp)
    if matched:
        trimmed = "\n".join(matched[:n_lines]).strip()
        if trimmed:
            resp = trimmed

    parsed = parse_numbered_lines(resp)
    parsed_json = json.dumps(parsed, ensure_ascii=False, indent=2) if show_parsed else ""
    return resp, parsed_json
184
+
185
def launch_app():
    """Build and return the Gradio Blocks UI for the argument-role tagger."""
    with gr.Blocks(title="Argument Role Tagger (DeepSeek 1.5B + LoRA merged)") as ui:
        gr.Markdown("## Argument Role Tagger")
        gr.Markdown(
            "Paste a paragraph. The model will number sentences and label each as **claim**, **premise**, or **none**."
        )

        with gr.Row():
            # Left column: input paragraph plus generation controls.
            with gr.Column(scale=2):
                paragraph_box = gr.Textbox(
                    label="Paragraph",
                    lines=10,
                    placeholder="Paste your paragraph…",
                    value=("Governments should subsidize solar panels to accelerate clean energy adoption. "
                           "Lowering installation costs would encourage more households to switch, reducing fossil fuel dependence. "
                           "In the long run, this shift could stabilize energy prices and reduce environmental damage.")
                )
                with gr.Row():
                    tokens_slider = gr.Slider(64, 1024, value=MAX_NEW_TOKENS_DEFAULT, step=16, label="Max new tokens")
                with gr.Row():
                    temperature_slider = gr.Slider(0.0, 1.0, value=TEMPERATURE_DEFAULT, step=0.05, label="Temperature")
                    top_p_slider = gr.Slider(0.5, 1.0, value=TOP_P_DEFAULT, step=0.05, label="Top-p")
                parsed_toggle = gr.Checkbox(value=True, label="Show parsed JSON")
                analyze_btn = gr.Button("Analyze", variant="primary")

            # Right column: raw model output and the optional parsed JSON view.
            with gr.Column(scale=3):
                raw_box = gr.Textbox(label="Model Output (raw)", lines=18, show_copy_button=True)
                json_box = gr.Code(label="Parsed JSON", language="json")

        # Wire the button to the GPU-backed analyze() handler.
        analyze_btn.click(
            analyze,
            inputs=[paragraph_box, tokens_slider, temperature_slider, top_p_slider, parsed_toggle],
            outputs=[raw_box, json_box],
        )

        gr.Markdown("### Tips")
        gr.Markdown("- Set `MODEL_PATH` at the top to your merged model repo or local path.\n"
                    "- For deterministic outputs, set Temperature=0.0 and Top-p=1.0.\n"
                    "- Output is forcibly stopped after exactly N lines.")

    return ui
226
+
227
if __name__ == "__main__":
    # Build the UI and expose a public share link.
    launch_app().launch(share=True)