# test2 / app.py — Hugging Face Space by iqasimz (commit 1901d9d)
import json
import os
import re
import warnings

import gradio as gr
import spaces
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
# ---------- CONFIG ----------
os.environ.setdefault("GRADIO_SERVER_PORT", "7860")  # standard Spaces port; an existing env override wins
MODEL_PATH = "iqasimz/g2" # <- change to your repo or local dir
MAX_NEW_TOKENS_DEFAULT = 300  # UI default for the "Max new tokens" slider
TEMPERATURE_DEFAULT = 0  # 0 => greedy (deterministic) decoding by default
TOP_P_DEFAULT = 1.0  # 1.0 => no nucleus truncation by default
# ---------------------------
warnings.filterwarnings("ignore", module="torch")  # silence torch warnings in Space logs
_model_cache = {}  # model_dir -> (tokenizer, model); populated lazily by load_model_to_cpu
def _ensure_pad_token(tokenizer):
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
return tokenizer
def load_model_to_cpu(model_dir: str):
    """Load tokenizer+model once on CPU; moved to GPU per request via @spaces.GPU."""
    cached = _model_cache.get(model_dir)
    if cached is not None:
        return cached
    tokenizer = _ensure_pad_token(
        AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
    )
    model = AutoModelForCausalLM.from_pretrained(
        model_dir,
        trust_remote_code=True,
        torch_dtype=torch.float16,  # model runs in fp16 when moved to GPU
        device_map=None,            # keep on CPU for caching
    )
    model.eval()
    _model_cache[model_dir] = (tokenizer, model)
    print(f"[cache] Loaded {model_dir} on CPU")
    return tokenizer, model
def build_inference_prompt(paragraph: str) -> str:
    """Wrap *paragraph* in the chat-style prompt used during fine-tuning.

    Returns a single user turn (Task + Rules + Paragraph) followed by an
    opened assistant turn, delimited with <|im_start|>/<|im_end|> tokens.
    """
    # Match your training format EXACTLY (Task + Rules + Paragraph in user turn)
    # NOTE(review): "Task: ou are" looks like a typo for "You are", and the
    # trailing "\n" escapes combined with the literal line breaks produce blank
    # lines between rules. If the model was fine-tuned on this exact text,
    # "fixing" it would change the prompt distribution — confirm against the
    # training data before touching this string.
    task_block = """Task: ou are an expert argument analyst. Identify the role of each sentence within the context of the paragraph/debate/article like a true linguistics and argument expert.Number the sentences in the paragraph and tag the role of each one.\n
Rules:\n
- Do NOT change the text of any sentence.\n
- Keep the original order.\n
- Output exactly N lines, one per sentence.\n
- Each line must be: "<index> <original sentence> <role>", where role ∈ {claim, premise, none}.\n
- Do not add any explanations or extra text after the Nth line.
"""
    # Chat-style formatting used during training
    return (
        f"<|im_start|>user\n{task_block}\nParagraph:\n{paragraph}"
        f"<|im_end|>\n<|im_start|>assistant\n"
    )
def get_last_five_words(text: str) -> str:
    """Return the last (up to) five whitespace-separated words of *text*.

    Negative slicing already clamps short inputs, so a text with fewer than
    five words is returned whole. (The original ``len(words) >= 3`` branch
    was dead code: both branches produced the identical string.)
    """
    return " ".join(text.split()[-5:])
def extract_role_from_suffix(text_after_match: str) -> str:
    """Extract the role label from model text following a matched sentence.

    The model may glue junk onto the label (e.g. ``'claimabcd'``), so a
    case-insensitive prefix test is used. Returns ``'claim'``, ``'premise'``,
    or ``'none'``; falls back to ``'none'`` when no label is recognized.
    """
    remainder = text_after_match.strip().lower()
    for role in ("claim", "premise", "none"):
        if remainder.startswith(role):
            return role
    # The original also prefix-tested the first whitespace token, but that
    # check could never succeed once the whole-string prefix test failed
    # (the first token starts at position 0) — it was dead code.
    return "none"  # default fallback
def parse_numbered_lines(text: str, original_paragraph: str):
    """Parse the model's numbered output into ``[{'index', 'sentence', 'role'}, ...]``.

    Strategy:
      1. Grab the last five words of *original_paragraph*; a model line
         containing them is treated as the final sentence, and parsing stops
         there to avoid trailing gibberish.
      2. Every other numbered line is split at its last space: everything
         before is the sentence, the final token is the role (prefix-matched
         so glued-on junk like ``'claimabcd'`` still yields ``'claim'``).

    Returns an empty list when the paragraph has no sentence content or no
    model line could be parsed. Malformed lines are logged and skipped
    (best-effort parsing of free-form model output).
    """
    results = []
    lines = text.splitlines()
    # The sentence split is only used to bail out on an empty /
    # punctuation-only paragraph. (`import re` hoisted to module level.)
    sentences = re.split(r'[.!?]+', original_paragraph.strip())
    sentences = [s.strip() for s in sentences if s.strip()]
    if not sentences:
        return results
    last_five_words = get_last_five_words(original_paragraph)
    for line in lines:
        line = line.strip()
        # Only lines that begin with a sentence index are candidates.
        if not line or not line[0].isdigit():
            continue
        try:
            space_after_idx = line.find(" ")
            if space_after_idx == -1:
                continue
            idx = int(line[:space_after_idx])
            rest = line[space_after_idx + 1:].rstrip()
            if last_five_words.lower() in rest.lower():
                # This line holds the paragraph's final sentence.
                match_pos = rest.lower().find(last_five_words.lower())
                if match_pos != -1:
                    # Sentence = everything up to and including the match.
                    sentence_end = match_pos + len(last_five_words)
                    sent = rest[:sentence_end].strip()
                    text_after_match = rest[sentence_end:].strip()
                    role = "none"  # default when nothing follows the sentence
                    if text_after_match:
                        # Skip punctuation between sentence and label.
                        text_after_match = text_after_match.lstrip(' .,!?')
                        role = extract_role_from_suffix(text_after_match)
                    results.append({"index": idx, "sentence": sent, "role": role})
                    break  # last sentence reached — ignore any trailing output
            else:
                # Regular (non-final) sentence: role is the last token.
                last_space = rest.rfind(" ")
                if last_space == -1:
                    continue
                sent = rest[:last_space].strip()
                role_candidate = rest[last_space + 1:].strip().lower()
                role = "none"
                for valid_role in ('claim', 'premise', 'none'):
                    if role_candidate.startswith(valid_role):
                        role = valid_role
                        break
                results.append({"index": idx, "sentence": sent, "role": role})
        except Exception as e:
            # Deliberate best-effort behavior: log and move on.
            print(f"Error parsing line '{line}': {e}")
            continue
    return results
@spaces.GPU(duration=120)
def analyze(paragraph: str, max_new_tokens: int, temperature: float, top_p: float, show_parsed: bool):
    """Run the tagger on *paragraph*; return (raw model text, parsed JSON string).

    Executed on GPU via ZeroGPU: the cached CPU model is moved to CUDA for
    the duration of the call. Returns a user-facing message and an empty
    string when the paragraph is blank. The JSON output is empty unless
    *show_parsed* is set.
    """
    paragraph = (paragraph or "").strip()
    if not paragraph:
        return "Please paste a paragraph.", ""
    tokenizer, model = load_model_to_cpu(MODEL_PATH)
    model = model.to("cuda")
    prompt = build_inference_prompt(paragraph)
    inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
    # Sample whenever temperature > 0. The original condition also required
    # top_p < 1.0, which silently forced greedy decoding for settings like
    # temperature=0.7 / top_p=1.0 — a valid sampling configuration.
    do_sample = float(temperature) > 0.0
    gen_kwargs = dict(
        max_new_tokens=int(max_new_tokens),
        do_sample=do_sample,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        use_cache=True,
    )
    if do_sample:
        # Only pass sampling knobs when sampling: transformers warns when
        # temperature/top_p accompany do_sample=False, and temperature=0
        # is rejected outright by sampling validation.
        gen_kwargs["temperature"] = float(temperature)
        gen_kwargs["top_p"] = float(top_p)
    with torch.inference_mode():
        output = model.generate(**inputs, **gen_kwargs)
    full = tokenizer.decode(output[0], skip_special_tokens=False)
    # Keep only the assistant turn of the chat transcript.
    if "<|im_start|>assistant\n" in full:
        resp = full.split("<|im_start|>assistant\n")[-1]
        resp = resp.split("<|im_end|>")[0].strip()
    else:
        resp = full.strip()
    # Updated parsing with original paragraph reference
    parsed = parse_numbered_lines(resp, paragraph)
    parsed_json = json.dumps(parsed, ensure_ascii=False, indent=2) if show_parsed else ""
    return resp, parsed_json
def launch_app():
    """Assemble the Gradio Blocks UI and return it (caller launches it)."""
    with gr.Blocks(title="Argument Role Tagger (DeepSeek 1.5B + LoRA merged)") as demo:
        gr.Markdown("## Argument Role Tagger")
        gr.Markdown(
            "Paste a paragraph. The model will number sentences and label each as **claim**, **premise**, or **none**."
        )
        with gr.Row():
            # Left column: inputs and generation controls.
            with gr.Column(scale=2):
                paragraph_box = gr.Textbox(
                    label="Paragraph",
                    lines=10,
                    placeholder="Paste your paragraph…",
                    value=("Governments should subsidize solar panels to accelerate clean energy adoption. "
                           "Lowering installation costs would encourage more households to switch, reducing fossil fuel dependence. "
                           "In the long run, this shift could stabilize energy prices and reduce environmental damage.")
                )
                with gr.Row():
                    tokens_slider = gr.Slider(200, 4300, value=MAX_NEW_TOKENS_DEFAULT, step=100, label="Max new tokens")
                with gr.Row():
                    temp_slider = gr.Slider(0.0, 1.0, value=TEMPERATURE_DEFAULT, step=0.05, label="Temperature")
                    topp_slider = gr.Slider(0.5, 1.0, value=TOP_P_DEFAULT, step=0.05, label="Top-p")
                    json_toggle = gr.Checkbox(value=True, label="Show parsed JSON")
                analyze_btn = gr.Button("Analyze", variant="primary")
            # Right column: model output panes.
            with gr.Column(scale=3):
                raw_box = gr.Textbox(label="Model Output (raw)", lines=18, show_copy_button=True)
                json_box = gr.Code(label="Parsed JSON", language="json")
        analyze_btn.click(
            analyze,
            inputs=[paragraph_box, tokens_slider, temp_slider, topp_slider, json_toggle],
            outputs=[raw_box, json_box],
        )
        gr.Markdown("### Tips")
        gr.Markdown("- Set MODEL_PATH at the top to your merged model repo or local path.\n"
                    "- For deterministic outputs, set Temperature=0.0 and Top-p=1.0.\n"
                    "- Your training format (chat tokens + Task/Rules) is preserved in the prompt.\n"
                    "- **Enhanced parsing**: Stops at last sentence using 5-word match to avoid gibberish.")
    return demo
if __name__ == "__main__":
    # Build the UI and expose it with a public share link.
    demo_app = launch_app()
    demo_app.launch(share=True)