Spaces:
Sleeping
Sleeping
File size: 7,885 Bytes
efc82e9 07c0393 efc82e9 6b50cf8 efc82e9 4f236b0 efc82e9 e2946b2 efc82e9 e2946b2 0f7be3f e2946b2 0f7be3f e2946b2 0f7be3f e2946b2 07c0393 efc82e9 07c0393 e2946b2 07c0393 efc82e9 e2946b2 efc82e9 e2946b2 efc82e9 e2946b2 efc82e9 e2946b2 efc82e9 185e65b efc82e9 e2946b2 efc82e9 e2946b2 efc82e9 07c0393 efc82e9 e2946b2 efc82e9 e2946b2 efc82e9 e2946b2 efc82e9 07c0393 e2946b2 efc82e9 07c0393 efc82e9 07c0393 efc82e9 547aa17 efc82e9 e2946b2 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 | import os
import json
import warnings
import torch
import gradio as gr
import spaces
from transformers import AutoTokenizer, AutoModelForCausalLM
# ---------- CONFIG ----------
MODEL_PATH = "iqasimz/g5"            # Hugging Face model repo served by this Space
MAX_NEW_TOKENS_DEFAULT = 500         # default for the "Max New Tokens" slider
TOP_P_DEFAULT = 1.0                  # nucleus sampling p (generation uses do_sample=False, so effectively unused)
# ---------------------------
warnings.filterwarnings("ignore", module="torch")  # silence noisy torch warnings in Space logs
# Process-wide cache: model_dir -> (tokenizer, model), filled by load_model_to_cpu.
_model_cache = {}
def _ensure_pad_token(tokenizer):
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
return tokenizer
def load_model_to_cpu(model_dir: str):
    """Load tokenizer+model once on CPU; moved to GPU per request via @spaces.GPU.

    Returns the cached (tokenizer, model) pair when *model_dir* was loaded
    before; otherwise downloads/loads it, caches it, and returns it.
    """
    cached = _model_cache.get(model_dir)
    if cached is not None:
        return cached
    tokenizer = _ensure_pad_token(
        AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
    )
    model = AutoModelForCausalLM.from_pretrained(
        model_dir,
        trust_remote_code=True,
        torch_dtype=torch.float16,
        device_map=None,  # stay on CPU so the cache holds a CPU-resident copy
    )
    model.eval()  # inference only — disable dropout etc.
    _model_cache[model_dir] = (tokenizer, model)
    print(f"[cache] Loaded {model_dir} on CPU")
    return tokenizer, model
def build_prompt(paragraph: str) -> str:
    """Format the user paragraph into the EXACT structured instruction format."""
    # NOTE(review): the task sentence and "Number the sentences..." are fused
    # with no separator ("...argument expert.Number the sentences..."). Kept
    # byte-for-byte on purpose — the model was presumably tuned on this exact
    # prompt text; confirm before "fixing" the spacing.
    pieces = [
        "Task: You are an expert argument analyst. Identify the role of each sentence within the context of the paragraph/debate/article like a true linguistics and argument expert.",
        "Number the sentences in the paragraph and tag the role of each one.\n",
        "Rules:\n",
        "- Do NOT change the text of any sentence.\n",
        "- Keep the original order.\n",
        "- Output exactly N lines, one per sentence.\n",
        "- Each line must be: \"<index> <original sentence> <role>\", where role ∈ {claim, premise, none}.\n",
        "\n",
        "- Do not add any explanations or extra text after the Nth line.\n",
        f"Paragraph:\n{paragraph.strip()}",
    ]
    return "".join(pieces)
# ---------------- JSON Parsing Utilities ----------------
def get_last_five_words(text: str) -> str:
    """Return up to the last five whitespace-separated words of *text*.

    Fix: the original `len(words) >= 3` branch was dead code — `words[-5:]`
    already yields the whole list when there are fewer than five words, so
    both arms produced identical results. Behavior is unchanged.
    """
    words = text.strip().split()
    return " ".join(words[-5:])
def extract_role_from_suffix(text_after_match: str) -> str:
    """Map the text trailing a matched sentence to a role label.

    Returns "claim", "premise", or "none" when the stripped, lowercased
    suffix starts with that word; defaults to "none".

    Fix: the original re-checked the first whitespace-separated word in a
    second loop, but that check is implied by the whole-string prefix check
    (the roles contain no spaces), so it could never change the result; it
    is removed along with the no-op `.lower()` on the lowercase literals.
    """
    suffix = text_after_match.strip().lower()
    for role in ("claim", "premise", "none"):
        if suffix.startswith(role):
            return role
    return "none"
def parse_numbered_lines(text: str, original_paragraph: str):
    """Parse the model's numbered output into structured records.

    Each parseable line "<index> <sentence> <role>" becomes a dict
    {"index": int, "sentence": str, "role": str}. A line that contains the
    paragraph's final words is treated as the last sentence and parsing
    stops after it. Lines that are blank, don't start with a digit, or
    can't be split are skipped; per-line errors are printed and ignored.

    Fixes (behavior preserved): the `in` check followed by `find()` did the
    same substring search twice — collapsed to a single `find()`; the nested
    `if match_pos != -1` was always true inside that branch and is removed;
    the loop-invariant lowercasing of the paragraph tail is hoisted.
    """
    import re  # local import kept, matching the original file layout

    results = []
    # Sentence split is only used to bail out on empty paragraphs.
    sentences = [s.strip() for s in re.split(r'[.!?]+', original_paragraph.strip()) if s.strip()]
    if not sentences:
        return results
    last_five_words = get_last_five_words(original_paragraph)
    last_five_lower = last_five_words.lower()  # invariant — compute once
    for raw_line in text.splitlines():
        line = raw_line.strip()
        if not line or not line[0].isdigit():
            continue
        try:
            space_after_idx = line.find(" ")
            if space_after_idx == -1:
                continue
            idx = int(line[:space_after_idx])
            rest = line[space_after_idx + 1:].rstrip()
            match_pos = rest.lower().find(last_five_lower)
            if match_pos != -1:
                # This line contains the paragraph's closing words: take
                # everything up to the match end as the sentence, read the
                # role from the remainder, then stop — it's the last line.
                sentence_end = match_pos + len(last_five_words)
                sent = rest[:sentence_end].strip()
                text_after_match = rest[sentence_end:].strip()
                role = "none"
                if text_after_match:
                    role = extract_role_from_suffix(text_after_match.lstrip(' .,!?'))
                results.append({"index": idx, "sentence": sent, "role": role})
                break
            # Otherwise assume the role label is the line's final word.
            last_space = rest.rfind(" ")
            if last_space == -1:
                continue
            sent = rest[:last_space].strip()
            role_candidate = rest[last_space + 1:].strip().lower()
            role = "none"
            for valid_role in ('claim', 'premise', 'none'):
                if role_candidate.startswith(valid_role):
                    role = valid_role
                    break
            results.append({"index": idx, "sentence": sent, "role": role})
        except Exception as e:
            print(f"Error parsing line '{line}': {e}")
            continue
    return results
# --------------------------------------------------------
@spaces.GPU(duration=120)
def generate_text(paragraph, max_tokens, show_parsed):
    """Run the role-tagging model on *paragraph*.

    Returns a 2-tuple: (raw model response text, parsed-results JSON string
    — empty string when *show_parsed* is falsy). Greedy decoding is used
    (do_sample=False), so top_p has no effect.
    """
    # Empty/whitespace input: short-circuit without touching the model.
    if not paragraph.strip():
        return "Please enter some text.", ""
    tokenizer, model = load_model_to_cpu(MODEL_PATH)
    # NOTE(review): .to("cuda") mutates the CACHED model in place; under
    # @spaces.GPU the device is reclaimed after the call — confirm the cached
    # copy still works (or is re-moved) on subsequent requests.
    model = model.to("cuda")
    formatted_input = build_prompt(paragraph)
    messages = [{"role": "user", "content": formatted_input}]
    # Wrap the prompt in the model's chat template (string only, no tensors yet).
    formatted_text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )
    inputs = tokenizer(formatted_text, return_tensors="pt").to(model.device)
    with torch.inference_mode():
        outputs = model.generate(
            **inputs,
            max_new_tokens=int(max_tokens),
            top_p=TOP_P_DEFAULT,
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
            use_cache=True,
        )
    # Decode WITH special tokens so the assistant marker can be located,
    # then cut the response out between the marker and the end token.
    # The markers look DeepSeek-style ("<|Assistant|>", "<|end▁of▁sentence|>")
    # — presumably matching this model's chat template; verify if changed.
    full_response = tokenizer.decode(outputs[0], skip_special_tokens=False)
    if "<|Assistant|>" in full_response:
        response = full_response.split("<|Assistant|>")[-1]
        response = response.split("<|end▁of▁sentence|>")[0].strip()
    else:
        # Fallback: strip the prompt tokens positionally and decode the rest.
        new_tokens = outputs[0][inputs.input_ids.shape[-1]:]
        response = tokenizer.decode(new_tokens, skip_special_tokens=True)
    parsed = parse_numbered_lines(response, paragraph)
    parsed_json = json.dumps(parsed, ensure_ascii=False, indent=2) if show_parsed else ""
    return response, parsed_json
def launch_app():
    """Build and return the Gradio Blocks UI for the argument role tagger.

    Wires (paragraph, max-tokens slider, show-parsed checkbox) into
    generate_text, rendering its (raw output, parsed JSON) pair.

    Fix: the window title and page header hard-coded "iqasimz/g3" while
    MODEL_PATH is "iqasimz/g5"; both are now derived from MODEL_PATH so the
    UI always names the model actually being served.
    """
    with gr.Blocks(title=f"{MODEL_PATH} - Argument Role Tagger") as demo:
        gr.Markdown(f"# {MODEL_PATH} - Argument Role Tagger")
        gr.Markdown("Enter a paragraph, the model will number sentences and assign roles (claim, premise, none).")
        with gr.Row():
            with gr.Column():
                input_para = gr.Textbox(
                    label="Input Paragraph",
                    lines=8,
                    placeholder="Paste your paragraph here..."
                )
                max_tokens = gr.Slider(
                    minimum=50,
                    maximum=5000,
                    value=MAX_NEW_TOKENS_DEFAULT,
                    step=50,
                    label="Max New Tokens"
                )
                show_parsed = gr.Checkbox(value=True, label="Show parsed JSON")
                generate_btn = gr.Button("Analyze", variant="primary")
            with gr.Column():
                output_text = gr.Textbox(
                    label="Model Output",
                    lines=15,
                    show_copy_button=True
                )
                parsed_out = gr.Code(
                    label="Parsed JSON",
                    language="json"
                )
        generate_btn.click(
            fn=generate_text,
            inputs=[input_para, max_tokens, show_parsed],
            outputs=[output_text, parsed_out]
        )
    return demo
if __name__ == "__main__":
    # Spaces/containers inject PORT; fall back to Gradio's default 7860.
    port = int(os.getenv("PORT", "7860"))
    demo = launch_app()
    demo.launch(
        server_name="0.0.0.0",  # bind all interfaces so the container is reachable
        server_port=port,
        show_error=True,
    )