# ArgumentAnalyst — app.py
# Hugging Face Space that numbers the sentences of a paragraph and tags
# each one as claim / premise / none using a causal language model.
import os
import json
import warnings
import torch
import gradio as gr
import spaces
from transformers import AutoTokenizer, AutoModelForCausalLM
# ---------- CONFIG ----------
MODEL_PATH = "iqasimz/g5"  # Hugging Face model repo to load
MAX_NEW_TOKENS_DEFAULT = 500  # default generation budget (UI slider default)
TOP_P_DEFAULT = 1.0  # nucleus-sampling p; 1.0 = no truncation (decoding is greedy anyway)
# ---------------------------
warnings.filterwarnings("ignore", module="torch")  # silence noisy torch warnings in Space logs
# Process-wide cache: model_dir -> (tokenizer, model) kept on CPU between requests.
_model_cache = {}
def _ensure_pad_token(tokenizer):
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
return tokenizer
def load_model_to_cpu(model_dir: str):
    """Load tokenizer+model once on CPU; moved to GPU per request via @spaces.GPU."""
    cached = _model_cache.get(model_dir)
    if cached is not None:
        return cached
    tokenizer = _ensure_pad_token(
        AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
    )
    model = AutoModelForCausalLM.from_pretrained(
        model_dir,
        trust_remote_code=True,
        torch_dtype=torch.float16,
        device_map=None,  # stay on CPU so the cached copy is device-agnostic
    )
    model.eval()
    _model_cache[model_dir] = (tokenizer, model)
    print(f"[cache] Loaded {model_dir} on CPU")
    return tokenizer, model
def build_prompt(paragraph: str) -> str:
    """Format the user paragraph into the EXACT structured instruction format.

    Fix: the first instruction string previously ended without a newline,
    so Python's implicit string concatenation fused it with the next line
    into "...argument expert.Number the sentences..." — a "\n" separator
    is now included.
    """
    return (
        "Task: You are an expert argument analyst. Identify the role of each sentence within the context of the paragraph/debate/article like a true linguistics and argument expert.\n"
        "Number the sentences in the paragraph and tag the role of each one.\n"
        "Rules:\n"
        "- Do NOT change the text of any sentence.\n"
        "- Keep the original order.\n"
        "- Output exactly N lines, one per sentence.\n"
        "- Each line must be: \"<index> <original sentence> <role>\", where role ∈ {claim, premise, none}.\n"
        "\n"
        "- Do not add any explanations or extra text after the Nth line.\n"
        f"Paragraph:\n{paragraph.strip()}"
    )
# ---------------- JSON Parsing Utilities ----------------
def get_last_five_words(text: str) -> str:
    """Return up to the last five whitespace-separated words of *text*.

    Slicing with ``[-5:]`` already yields the entire list when it holds
    fewer than five items, so no explicit length check is needed.
    """
    return " ".join(text.strip().split()[-5:])
def extract_role_from_suffix(text_after_match: str) -> str:
    """Map trailing model output to one of the roles {claim, premise, none}.

    Matching is case-insensitive and prefix-based (e.g. "Claims" -> "claim");
    anything unrecognized defaults to "none".
    """
    roles = ("claim", "premise", "none")
    suffix = text_after_match.strip().lower()
    # First try the whole remaining text, then fall back to its first word.
    for candidate in (suffix, suffix.split()[0] if suffix.split() else ""):
        for role in roles:
            if candidate.startswith(role):
                return role
    return "none"
def parse_numbered_lines(text: str, original_paragraph: str):
    """Parse the model's numbered "<index> <sentence> <role>" output lines.

    Returns a list of dicts with keys "index", "sentence", "role".
    The paragraph's final words are used as a fingerprint to locate the
    last output line, where the role token must be split off at the match
    boundary; for all other lines the role is assumed to be the line's
    last whitespace-separated word.
    """
    results = []
    lines = text.splitlines()
    import re
    # Rough sentence split of the source paragraph — only used to return
    # early when the paragraph contains no sentences at all.
    sentences = re.split(r'[.!?]+', original_paragraph.strip())
    sentences = [s.strip() for s in sentences if s.strip()]
    if not sentences:
        return results
    # Fingerprint of the paragraph's ending (up to its last five words).
    last_five_words = get_last_five_words(original_paragraph)
    for line in lines:
        line = line.strip()
        # Only lines that begin with a sentence index are of interest.
        if not line or not line[0].isdigit():
            continue
        try:
            space_after_idx = line.find(" ")
            if space_after_idx == -1:
                continue
            idx = int(line[:space_after_idx])
            rest = line[space_after_idx + 1:].rstrip()
            if last_five_words.lower() in rest.lower():
                # This line contains the paragraph's final words: cut the
                # sentence exactly at the end of the fingerprint and treat
                # whatever follows as the role tag.
                match_pos = rest.lower().find(last_five_words.lower())
                if match_pos != -1:
                    sentence_end = match_pos + len(last_five_words)
                    sent = rest[:sentence_end].strip()
                    text_after_match = rest[sentence_end:].strip()
                    role = "none"
                    if text_after_match:
                        # Drop leading punctuation before role extraction.
                        text_after_match = text_after_match.lstrip(' .,!?')
                        role = extract_role_from_suffix(text_after_match)
                    results.append({"index": idx, "sentence": sent, "role": role})
                    # Final sentence reached — ignore anything the model
                    # emitted after it (per the prompt's "no extra text" rule).
                    break
            else:
                # Ordinary line: assume the trailing word is the role tag.
                last_space = rest.rfind(" ")
                if last_space == -1:
                    continue
                sent = rest[:last_space].strip()
                role_candidate = rest[last_space + 1:].strip().lower()
                role = "none"
                for valid_role in ['claim', 'premise', 'none']:
                    if role_candidate.startswith(valid_role):
                        role = valid_role
                        break
                results.append({"index": idx, "sentence": sent, "role": role})
        except Exception as e:
            # Best-effort parsing: log and skip malformed lines rather
            # than failing the whole request.
            print(f"Error parsing line '{line}': {e}")
            continue
    return results
# --------------------------------------------------------
@spaces.GPU(duration=120)
def generate_text(paragraph, max_tokens, show_parsed):
    """Tag *paragraph* with the model; return (raw model text, parsed JSON).

    Runs on a ZeroGPU worker for at most 120 s: the CPU-cached model is
    moved to CUDA per request. ``parsed JSON`` is "" when *show_parsed*
    is falsy.
    """
    if not paragraph.strip():
        return "Please enter some text.", ""
    tokenizer, model = load_model_to_cpu(MODEL_PATH)
    model = model.to("cuda")
    formatted_input = build_prompt(paragraph)
    messages = [{"role": "user", "content": formatted_input}]
    # Wrap the prompt in the model's own chat template.
    formatted_text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )
    inputs = tokenizer(formatted_text, return_tensors="pt").to(model.device)
    with torch.inference_mode():
        outputs = model.generate(
            **inputs,
            max_new_tokens=int(max_tokens),
            top_p=TOP_P_DEFAULT,
            do_sample=False,  # greedy decoding for reproducible tagging
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
            use_cache=True,
        )
    # Keep special tokens so the assistant marker below can be found.
    full_response = tokenizer.decode(outputs[0], skip_special_tokens=False)
    # NOTE(review): "<|Assistant|>" / "<|end▁of▁sentence|>" look like
    # DeepSeek-style chat-template markers — confirm they match this
    # model's template before relying on this branch.
    if "<|Assistant|>" in full_response:
        response = full_response.split("<|Assistant|>")[-1]
        response = response.split("<|end▁of▁sentence|>")[0].strip()
    else:
        # Fallback: slice off the prompt tokens and decode only the
        # newly generated continuation.
        new_tokens = outputs[0][inputs.input_ids.shape[-1]:]
        response = tokenizer.decode(new_tokens, skip_special_tokens=True)
    parsed = parse_numbered_lines(response, paragraph)
    parsed_json = json.dumps(parsed, ensure_ascii=False, indent=2) if show_parsed else ""
    return response, parsed_json
def launch_app():
    """Build and return the Gradio Blocks UI (launched from __main__).

    Fix: the window title and heading were hard-coded to "iqasimz/g3"
    while MODEL_PATH points at "iqasimz/g5"; both are now derived from
    MODEL_PATH so the UI cannot drift out of sync with the served model.
    """
    title = f"{MODEL_PATH} - Argument Role Tagger"
    with gr.Blocks(title=title) as demo:
        gr.Markdown(f"# {title}")
        gr.Markdown("Enter a paragraph, the model will number sentences and assign roles (claim, premise, none).")
        with gr.Row():
            with gr.Column():
                input_para = gr.Textbox(
                    label="Input Paragraph",
                    lines=8,
                    placeholder="Paste your paragraph here..."
                )
                max_tokens = gr.Slider(
                    minimum=50,
                    maximum=5000,
                    value=MAX_NEW_TOKENS_DEFAULT,
                    step=50,
                    label="Max New Tokens"
                )
                show_parsed = gr.Checkbox(value=True, label="Show parsed JSON")
                generate_btn = gr.Button("Analyze", variant="primary")
            with gr.Column():
                output_text = gr.Textbox(
                    label="Model Output",
                    lines=15,
                    show_copy_button=True
                )
                parsed_out = gr.Code(
                    label="Parsed JSON",
                    language="json"
                )
        generate_btn.click(
            fn=generate_text,
            inputs=[input_para, max_tokens, show_parsed],
            outputs=[output_text, parsed_out]
        )
    return demo
if __name__ == "__main__":
    # Build the UI and serve it. Binding to 0.0.0.0 and honoring the PORT
    # environment variable matches the Hugging Face Spaces container contract.
    app = launch_app()
    app.launch(
        server_name="0.0.0.0",
        server_port=int(os.getenv("PORT", "7860")),
        show_error=True
    )