File size: 10,083 Bytes
2f74bd4
 
 
 
 
 
 
 
 
 
da2b5de
2f74bd4
da2b5de
2f74bd4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
da2b5de
 
 
 
 
 
2f74bd4
 
 
 
 
 
 
 
da2b5de
 
 
1901d9d
2f74bd4
da2b5de
2f74bd4
da2b5de
 
2f74bd4
da2b5de
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2f74bd4
da2b5de
 
 
 
2f74bd4
 
da2b5de
 
 
 
 
 
 
 
 
 
 
 
 
 
2f74bd4
 
 
da2b5de
2f74bd4
da2b5de
2f74bd4
da2b5de
 
 
2f74bd4
 
da2b5de
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2f74bd4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
da2b5de
2f74bd4
 
 
 
 
 
 
 
 
 
 
 
 
 
da2b5de
 
2f74bd4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1901d9d
2f74bd4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
da2b5de
2f74bd4
da2b5de
 
2f74bd4
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
import json
import os
import re
import warnings

import gradio as gr
import spaces
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# ---------- CONFIG ----------
# Bind the default Hugging Face Spaces port unless the env already set one.
os.environ.setdefault("GRADIO_SERVER_PORT", "7860")
MODEL_PATH = "iqasimz/g2"  # <- change to your repo or local dir
MAX_NEW_TOKENS_DEFAULT = 300  # default for the "Max new tokens" slider
TEMPERATURE_DEFAULT = 0  # 0 => greedy decoding (deterministic output)
TOP_P_DEFAULT = 1.0  # nucleus sampling effectively disabled by default
# ---------------------------

# Keep the Space logs readable: torch emits noisy dtype/device warnings.
warnings.filterwarnings("ignore", module="torch")
# Process-wide cache: model_dir -> (tokenizer, model), loaded once on CPU.
_model_cache = {}

def _ensure_pad_token(tokenizer):
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    return tokenizer

def load_model_to_cpu(model_dir: str):
    """Return a cached (tokenizer, model) pair for *model_dir*, loading at most once.

    The pair is stored in ``_model_cache`` and kept on CPU; the request handler
    moves the model to GPU inside the @spaces.GPU window.
    """
    cached = _model_cache.get(model_dir)
    if cached is not None:
        return cached

    tokenizer = _ensure_pad_token(
        AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
    )

    model = AutoModelForCausalLM.from_pretrained(
        model_dir,
        trust_remote_code=True,
        torch_dtype=torch.float16,   # fp16 weights take effect once moved to GPU
        device_map=None,             # stay on CPU so the cached copy is device-agnostic
    )
    model.eval()

    _model_cache[model_dir] = (tokenizer, model)
    print(f"[cache] Loaded {model_dir} on CPU")
    return tokenizer, model

def build_inference_prompt(paragraph: str) -> str:
    """Wrap *paragraph* in the exact chat-formatted prompt used during training.

    NOTE(review): the "Task: ou are" typo below is reproduced verbatim on
    purpose — the prompt must byte-match the training format; confirm against
    the training data before "fixing" it.
    """
    task_block = """Task: ou are an expert argument analyst. Identify the role of each sentence within the context of the paragraph/debate/article like a true linguistics and argument expert.Number the sentences in the paragraph and tag the role of each one.\n
Rules:\n
- Do NOT change the text of any sentence.\n
- Keep the original order.\n
- Output exactly N lines, one per sentence.\n
- Each line must be: "<index> <original sentence> <role>", where role ∈ {claim, premise, none}.\n
- Do not add any explanations or extra text after the Nth line.
"""
    # Chat-style turns: a single user turn, then an open assistant turn for
    # the model to complete.
    prompt = (
        "<|im_start|>user\n"
        f"{task_block}\nParagraph:\n{paragraph}<|im_end|>\n"
        "<|im_start|>assistant\n"
    )
    return prompt

def get_last_five_words(text: str) -> str:
    """Return the last (up to) five whitespace-separated words of *text*.

    Shorter inputs come back whole; the result is single-space joined, so
    runs of whitespace in the input are collapsed.
    """
    # words[-5:] already equals the full list whenever len(words) < 5, so the
    # old `if len(words) >= 3` guard was dead code (both branches produced
    # the same value for every input length) and contradicted the docstring.
    return " ".join(text.strip().split()[-5:])

def extract_role_from_suffix(text_after_match: str) -> str:
    """Extract the role label ('claim', 'premise', or 'none') from model text.

    The model sometimes glues junk onto the label (e.g. 'claimabcd'), so a
    case-insensitive prefix match is used. Anything unrecognized maps to
    'none'.
    """
    candidate = text_after_match.strip().lower()
    for role in ("claim", "premise", "none"):
        if candidate.startswith(role):
            return role
    # Nothing recognizable at the start of the suffix. (The old first-word
    # fallback was unreachable: if the first word started with a role, the
    # whole string already did, so the first loop had returned.)
    return "none"

def parse_numbered_lines(text: str, original_paragraph: str):
    """Parse the model's numbered output into [{'index', 'sentence', 'role'}, ...].

    Strategy:
    1. Locate the last five words of the input paragraph inside a line to
       detect the final sentence.
    2. On that line, take the role from whatever follows the match.
    3. Stop parsing immediately after the final sentence so trailing model
       gibberish is never ingested.

    Lines that do not start with a digit, or cannot be split into
    '<index> <sentence> <role>', are skipped. Returns [] for an empty or
    punctuation-only paragraph.
    """
    results = []

    # Sentence split is used only to detect a paragraph with no real content.
    # (Cleanup: `import re` was previously function-local; now module-level.)
    sentences = [s.strip() for s in re.split(r'[.!?]+', original_paragraph.strip()) if s.strip()]
    if not sentences:
        return results

    # Hoist loop invariants: the anchor phrase and its lowercase form.
    last_five_words = get_last_five_words(original_paragraph)
    last_five_lower = last_five_words.lower()

    for line in text.splitlines():
        line = line.strip()
        if not line or not line[0].isdigit():
            continue

        try:
            # Leading "<index> " prefix.
            space_after_idx = line.find(" ")
            if space_after_idx == -1:
                continue

            idx = int(line[:space_after_idx])
            rest = line[space_after_idx + 1:].rstrip()

            # Single scan replaces the old `in` check followed by `find`.
            match_pos = rest.lower().find(last_five_lower)
            if match_pos != -1:
                # This line holds the paragraph's last sentence: keep text up
                # to (and including) the anchor phrase as the sentence.
                sentence_end = match_pos + len(last_five_words)
                sent = rest[:sentence_end].strip()

                role = "none"  # default when nothing follows the match
                text_after_match = rest[sentence_end:].strip()
                if text_after_match:
                    # Skip immediate punctuation, then read the role label.
                    text_after_match = text_after_match.lstrip(' .,!?')
                    role = extract_role_from_suffix(text_after_match)

                results.append({"index": idx, "sentence": sent, "role": role})

                # STOP here — anything after the last sentence is gibberish.
                break
            else:
                # Regular line: role is the final whitespace-separated token.
                last_space = rest.rfind(" ")
                if last_space == -1:
                    continue

                sent = rest[:last_space].strip()
                role_candidate = rest[last_space + 1:].strip().lower()

                # Tolerate gibberish suffixes glued onto the role word.
                role = "none"
                for valid_role in ('claim', 'premise', 'none'):
                    if role_candidate.startswith(valid_role):
                        role = valid_role
                        break

                results.append({"index": idx, "sentence": sent, "role": role})

        except Exception as e:
            # Deliberate best-effort: log the bad line and keep going rather
            # than abort the whole parse.
            print(f"Error parsing line '{line}': {e}")
            continue

    return results

@spaces.GPU(duration=120)
def analyze(paragraph: str, max_new_tokens: int, temperature: float, top_p: float, show_parsed: bool):
    """Tag *paragraph* sentences with roles; return (raw_text, parsed_json).

    parsed_json is "" when *show_parsed* is False. Runs on GPU for the
    duration of the @spaces.GPU window; the model itself is cached on CPU
    between requests by load_model_to_cpu.
    """
    paragraph = (paragraph or "").strip()
    if not paragraph:
        return "Please paste a paragraph.", ""

    tokenizer, model = load_model_to_cpu(MODEL_PATH)
    # NOTE: nn.Module.to mutates in place, so the cached entry now lives on
    # GPU for the rest of this window — presumably fine under ZeroGPU; verify.
    model = model.to("cuda")

    prompt = build_inference_prompt(paragraph)
    inputs = tokenizer(prompt, return_tensors="pt").to("cuda")

    # BUG FIX: sampling previously required BOTH temperature > 0 AND
    # top_p < 1.0, so the temperature slider silently did nothing at the
    # default top_p of 1.0 (contradicting the UI tip that only T=0.0 AND
    # top-p=1.0 is deterministic). Sample whenever temperature is positive;
    # temperature == 0 stays greedy.
    do_sample = float(temperature) > 0.0

    gen_kwargs = dict(
        max_new_tokens=int(max_new_tokens),
        do_sample=do_sample,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        use_cache=True,
    )
    if do_sample:
        # temperature/top_p only apply when sampling; passing them with
        # do_sample=False just triggers transformers warnings.
        gen_kwargs["temperature"] = float(temperature)
        gen_kwargs["top_p"] = float(top_p)

    with torch.inference_mode():
        output = model.generate(**inputs, **gen_kwargs)

    full = tokenizer.decode(output[0], skip_special_tokens=False)

    # Keep only the assistant turn of the decoded chat transcript.
    if "<|im_start|>assistant\n" in full:
        resp = full.split("<|im_start|>assistant\n")[-1]
        resp = resp.split("<|im_end|>")[0].strip()
    else:
        resp = full.strip()

    # Parse against the original paragraph (anchors on its last five words).
    parsed = parse_numbered_lines(resp, paragraph)
    parsed_json = json.dumps(parsed, ensure_ascii=False, indent=2) if show_parsed else ""
    return resp, parsed_json

def launch_app():
    """Assemble the Gradio Blocks UI and return the (unlaunched) demo."""
    with gr.Blocks(title="Argument Role Tagger (DeepSeek 1.5B + LoRA merged)") as demo:
        gr.Markdown("## Argument Role Tagger")
        gr.Markdown(
            "Paste a paragraph. The model will number sentences and label each as **claim**, **premise**, or **none**."
        )

        # Three-sentence example shown on load.
        sample_text = (
            "Governments should subsidize solar panels to accelerate clean energy adoption. "
            "Lowering installation costs would encourage more households to switch, reducing fossil fuel dependence. "
            "In the long run, this shift could stabilize energy prices and reduce environmental damage."
        )

        with gr.Row():
            # Left column: input paragraph plus generation controls.
            with gr.Column(scale=2):
                para_input = gr.Textbox(
                    label="Paragraph",
                    lines=10,
                    placeholder="Paste your paragraph…",
                    value=sample_text,
                )
                with gr.Row():
                    tokens_slider = gr.Slider(200, 4300, value=MAX_NEW_TOKENS_DEFAULT, step=100, label="Max new tokens")
                with gr.Row():
                    temp_slider = gr.Slider(0.0, 1.0, value=TEMPERATURE_DEFAULT, step=0.05, label="Temperature")
                    top_p_slider = gr.Slider(0.5, 1.0, value=TOP_P_DEFAULT, step=0.05, label="Top-p")
                json_toggle = gr.Checkbox(value=True, label="Show parsed JSON")
                analyze_btn = gr.Button("Analyze", variant="primary")

            # Right column: raw model text and (optionally) the parsed JSON.
            with gr.Column(scale=3):
                raw_box = gr.Textbox(label="Model Output (raw)", lines=18, show_copy_button=True)
                json_box = gr.Code(label="Parsed JSON", language="json")

        analyze_btn.click(
            analyze,
            inputs=[para_input, tokens_slider, temp_slider, top_p_slider, json_toggle],
            outputs=[raw_box, json_box],
        )

        gr.Markdown("### Tips")
        gr.Markdown(
            "- Set MODEL_PATH at the top to your merged model repo or local path.\n"
            "- For deterministic outputs, set Temperature=0.0 and Top-p=1.0.\n"
            "- Your training format (chat tokens + Task/Rules) is preserved in the prompt.\n"
            "- **Enhanced parsing**: Stops at last sentence using 5-word match to avoid gibberish."
        )

    return demo

if __name__ == "__main__":
    # Build the UI and serve it with a public Gradio share link.
    launch_app().launch(share=True)