import os
import json
import warnings
import torch
import gradio as gr
import spaces
from transformers import AutoTokenizer, AutoModelForCausalLM
# ---------- CONFIG ----------
# Default Gradio port; setdefault lets an externally-set env var win.
os.environ.setdefault("GRADIO_SERVER_PORT", "7860")
MODEL_PATH = "iqasimz/g2" # <- change to your repo or local dir
MAX_NEW_TOKENS_DEFAULT = 300
TEMPERATURE_DEFAULT = 0
TOP_P_DEFAULT = 1.0
# ---------------------------
# Silence torch's (noisy but harmless) warnings in the app logs.
warnings.filterwarnings("ignore", module="torch")
# Maps model_dir -> (tokenizer, model); populated lazily by load_model_to_cpu.
_model_cache = {}
def _ensure_pad_token(tokenizer):
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
return tokenizer
def load_model_to_cpu(model_dir: str):
    """Load tokenizer+model once on CPU; moved to GPU per request via @spaces.GPU."""
    cached = _model_cache.get(model_dir)
    if cached is not None:
        return cached
    # First request for this model: fetch tokenizer, ensure it can pad.
    tokenizer = _ensure_pad_token(
        AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
    )
    # fp16 weights, but kept on CPU here so the cached copy survives between
    # GPU leases; the caller moves it to CUDA per request.
    model = AutoModelForCausalLM.from_pretrained(
        model_dir,
        trust_remote_code=True,
        torch_dtype=torch.float16,
        device_map=None,
    )
    model.eval()
    _model_cache[model_dir] = (tokenizer, model)
    print(f"[cache] Loaded {model_dir} on CPU")
    return tokenizer, model
def build_inference_prompt(paragraph: str) -> str:
    """Wrap *paragraph* in the chat-style Task+Rules prompt used at training time."""
    # Match your training format EXACTLY (Task + Rules + Paragraph in user turn)
    # NOTE(review): "Task: ou are" looks like a typo for "You are", but since the
    # prompt must match the training data byte-for-byte, do NOT fix it here
    # without first confirming what the model was actually trained on.
    task_block = """Task: ou are an expert argument analyst. Identify the role of each sentence within the context of the paragraph/debate/article like a true linguistics and argument expert.Number the sentences in the paragraph and tag the role of each one.\n
Rules:\n
- Do NOT change the text of any sentence.\n
- Keep the original order.\n
- Output exactly N lines, one per sentence.\n
- Each line must be: "<index> <original sentence> <role>", where role ∈ {claim, premise, none}.\n
- Do not add any explanations or extra text after the Nth line.
"""
    # Chat-style formatting used during training: user turn ends with <|im_end|>,
    # and the prompt is left open at the assistant turn for generation.
    return (
        f"<|im_start|>user\n{task_block}\nParagraph:\n{paragraph}"
        f"<|im_end|>\n<|im_start|>assistant\n"
    )
def get_last_five_words(text: str) -> str:
    """Return the last 5 whitespace-separated words of *text* (all of them if fewer)."""
    # Slicing with [-5:] already yields the whole list when there are fewer
    # than five words, so the original `len(words) >= 3` guard was dead code.
    return " ".join(text.strip().split()[-5:])
def extract_role_from_suffix(text_after_match: str) -> str:
    """
    Extract role (claim, premise, none) from text after the 5-word match.

    A prefix match tolerates gibberish glued onto the role word, e.g.
    'claimabcd' -> 'claim'. Returns 'none' when no role word is found.
    """
    remainder = text_after_match.strip().lower()
    for role in ("claim", "premise", "none"):
        if remainder.startswith(role):
            return role
    # The original also re-tested the first whitespace token with the same
    # startswith check; that branch could never succeed once the full-string
    # prefix test above had failed (role words contain no whitespace), so the
    # dead fallback has been removed.
    return "none"  # default fallback
def parse_numbered_lines(text: str, original_paragraph: str):
    """
    Enhanced parsing with improved stopping criteria:
    1. Find exact match of last 5 words from input paragraph
    2. Look for role word after a space following the match
    3. Stop parsing after finding the last sentence to avoid gibberish

    Returns a list of dicts: {"index": int, "sentence": str, "role": str}.
    Lines that do not start with a digit, or cannot be split into
    "<index> <sentence> <role>", are skipped.
    """
    results = []
    lines = text.splitlines()
    # Get sentences from original paragraph for reference
    import re
    # NOTE: `sentences` is only used for the emptiness guard below; the actual
    # last-sentence detection relies on the 5-word suffix match instead.
    sentences = re.split(r'[.!?]+', original_paragraph.strip())
    sentences = [s.strip() for s in sentences if s.strip()]
    if not sentences:
        return results
    # Get last 5 words of the original paragraph
    last_five_words = get_last_five_words(original_paragraph)
    for line in lines:
        line = line.strip()
        # Model output lines must begin with the sentence index digit.
        if not line or not line[0].isdigit():
            continue
        try:
            # Parse index
            space_after_idx = line.find(" ")
            if space_after_idx == -1:
                continue
            idx = int(line[:space_after_idx])
            rest = line[space_after_idx + 1:].rstrip()
            # Check if this line contains the last 5 words (indicating last sentence)
            if last_five_words.lower() in rest.lower():
                # Find the position of the last 5 words
                match_pos = rest.lower().find(last_five_words.lower())
                if match_pos != -1:
                    # Extract sentence (everything up to and including the match)
                    sentence_end = match_pos + len(last_five_words)
                    sent = rest[:sentence_end].strip()
                    # Look for role after the match
                    text_after_match = rest[sentence_end:].strip()
                    role = "none"  # default
                    if text_after_match:
                        # Skip any immediate punctuation/spaces and look for role
                        text_after_match = text_after_match.lstrip(' .,!?')
                        role = extract_role_from_suffix(text_after_match)
                    results.append({"index": idx, "sentence": sent, "role": role})
                    # STOP parsing here - this is the last sentence
                    break
            else:
                # Regular parsing for non-last sentences: the role is assumed
                # to be the final whitespace-separated token of the line.
                last_space = rest.rfind(" ")
                if last_space == -1:
                    continue
                sent = rest[:last_space].strip()
                role_candidate = rest[last_space + 1:].strip().lower()
                # Clean role (handle gibberish suffixes)
                role = "none"
                for valid_role in ['claim', 'premise', 'none']:
                    if role_candidate.startswith(valid_role):
                        role = valid_role
                        break
                results.append({"index": idx, "sentence": sent, "role": role})
        except Exception as e:
            # Best-effort parsing: a malformed line is reported and skipped
            # rather than aborting the whole parse.
            print(f"Error parsing line '{line}': {e}")
            continue
    return results
@spaces.GPU(duration=120)
def analyze(paragraph: str, max_new_tokens: int, temperature: float, top_p: float, show_parsed: bool):
    """Run the tagger on *paragraph*.

    Returns (raw assistant text, parsed JSON string or "" when show_parsed is off).
    """
    paragraph = (paragraph or "").strip()
    if not paragraph:
        return "Please paste a paragraph.", ""
    tokenizer, model = load_model_to_cpu(MODEL_PATH)
    model = model.to("cuda")
    prompt = build_inference_prompt(paragraph)
    inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
    # BUG FIX: the original used `temperature > 0.0 and top_p < 1.0`, which
    # silently fell back to greedy decoding at the app's default top_p of 1.0,
    # ignoring any non-zero temperature. Sampling is wanted whenever
    # temperature > 0; top_p only has meaning while sampling.
    do_sample = float(temperature) > 0.0
    gen_kwargs = dict(
        max_new_tokens=int(max_new_tokens),
        do_sample=do_sample,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        use_cache=True,
    )
    if do_sample:
        # Only pass sampling knobs when sampling, to avoid transformers'
        # "generation flags not used" warnings on the greedy path.
        gen_kwargs["temperature"] = float(temperature)
        gen_kwargs["top_p"] = float(top_p)
    with torch.inference_mode():
        output = model.generate(**inputs, **gen_kwargs)
    full = tokenizer.decode(output[0], skip_special_tokens=False)
    # Extract assistant segment (everything after the assistant turn marker,
    # truncated at the first <|im_end|>); fall back to the whole decode.
    if "<|im_start|>assistant\n" in full:
        resp = full.split("<|im_start|>assistant\n")[-1]
        resp = resp.split("<|im_end|>")[0].strip()
    else:
        resp = full.strip()
    # Updated parsing with original paragraph reference
    parsed = parse_numbered_lines(resp, paragraph)
    parsed_json = json.dumps(parsed, ensure_ascii=False, indent=2) if show_parsed else ""
    return resp, parsed_json
def launch_app():
    """Assemble the Gradio Blocks UI for the tagger and return it (not launched)."""
    with gr.Blocks(title="Argument Role Tagger (DeepSeek 1.5B + LoRA merged)") as ui:
        gr.Markdown("## Argument Role Tagger")
        gr.Markdown(
            "Paste a paragraph. The model will number sentences and label each as **claim**, **premise**, or **none**."
        )
        with gr.Row():
            # Left column: paragraph input plus generation controls.
            with gr.Column(scale=2):
                paragraph_box = gr.Textbox(
                    label="Paragraph",
                    lines=10,
                    placeholder="Paste your paragraph…",
                    value=("Governments should subsidize solar panels to accelerate clean energy adoption. "
                           "Lowering installation costs would encourage more households to switch, reducing fossil fuel dependence. "
                           "In the long run, this shift could stabilize energy prices and reduce environmental damage.")
                )
                with gr.Row():
                    tokens_slider = gr.Slider(200, 4300, value=MAX_NEW_TOKENS_DEFAULT, step=100, label="Max new tokens")
                with gr.Row():
                    temp_slider = gr.Slider(0.0, 1.0, value=TEMPERATURE_DEFAULT, step=0.05, label="Temperature")
                    top_p_slider = gr.Slider(0.5, 1.0, value=TOP_P_DEFAULT, step=0.05, label="Top-p")
                parsed_toggle = gr.Checkbox(value=True, label="Show parsed JSON")
                analyze_btn = gr.Button("Analyze", variant="primary")
            # Right column: raw model text and the structured parse.
            with gr.Column(scale=3):
                raw_output = gr.Textbox(label="Model Output (raw)", lines=18, show_copy_button=True)
                json_output = gr.Code(label="Parsed JSON", language="json")
        # Wire the button to the GPU-backed inference function.
        analyze_btn.click(
            analyze,
            inputs=[paragraph_box, tokens_slider, temp_slider, top_p_slider, parsed_toggle],
            outputs=[raw_output, json_output],
        )
        gr.Markdown("### Tips")
        gr.Markdown("- Set MODEL_PATH at the top to your merged model repo or local path.\n"
                    "- For deterministic outputs, set Temperature=0.0 and Top-p=1.0.\n"
                    "- Your training format (chat tokens + Task/Rules) is preserved in the prompt.\n"
                    "- **Enhanced parsing**: Stops at last sentence using 5-word match to avoid gibberish.")
    return ui
# Script entry point: build the UI and serve it with a public share link.
if __name__ == "__main__":
    launch_app().launch(share=True)