"""Gradio Space: argument-role tagger.

Loads a merged causal LM once on CPU, moves it to GPU per request (Hugging
Face ZeroGPU via ``@spaces.GPU``), prompts it with the exact Task/Rules chat
format used during training, and parses the model's numbered
``<sentence> <role>`` output lines into structured JSON.
"""

import json
import os
import re
import warnings

import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# ---------- CONFIG ----------
os.environ.setdefault("GRADIO_SERVER_PORT", "7860")
MODEL_PATH = "iqasimz/g2"  # <- change to your repo or local dir
MAX_NEW_TOKENS_DEFAULT = 300
TEMPERATURE_DEFAULT = 0
TOP_P_DEFAULT = 1.0
# ---------------------------

warnings.filterwarnings("ignore", module="torch")

# Valid role labels, in the priority order used throughout parsing.
ROLE_WORDS = ("claim", "premise", "none")

# Process-wide cache: model_dir -> (tokenizer, model); weights stay on CPU
# between requests so from_pretrained only runs once per process.
_model_cache = {}


def _ensure_pad_token(tokenizer):
    """Fall back to the EOS token when the tokenizer defines no pad token."""
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    return tokenizer


def load_model_to_cpu(model_dir: str):
    """Load tokenizer+model once on CPU; moved to GPU per request via @spaces.GPU.

    Args:
        model_dir: HF Hub repo id or local directory of the merged model.

    Returns:
        ``(tokenizer, model)`` tuple, cached per ``model_dir`` for the
        lifetime of the process.
    """
    if model_dir in _model_cache:
        return _model_cache[model_dir]
    tok = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
    tok = _ensure_pad_token(tok)
    mdl = AutoModelForCausalLM.from_pretrained(
        model_dir,
        trust_remote_code=True,
        torch_dtype=torch.float16,  # model runs in fp16 when moved to GPU
        device_map=None,            # keep on CPU for caching
    )
    mdl.eval()
    _model_cache[model_dir] = (tok, mdl)
    print(f"[cache] Loaded {model_dir} on CPU")
    return tok, mdl


def build_inference_prompt(paragraph: str) -> str:
    """Wrap *paragraph* in the chat-style prompt used during training.

    NOTE(review): the "ou are" typo and the empty ``" "`` line-format hint in
    ``task_block`` look wrong, but the training format must be matched
    EXACTLY, so the string is preserved byte-for-byte — do not "fix" it
    without re-checking the training data.
    """
    # Match your training format EXACTLY (Task + Rules + Paragraph in user turn)
    task_block = """Task: ou are an expert argument analyst. Identify the role of each sentence within the context of the paragraph/debate/article like a true linguistics and argument expert.Number the sentences in the paragraph and tag the role of each one.\n Rules:\n - Do NOT change the text of any sentence.\n - Keep the original order.\n - Output exactly N lines, one per sentence.\n - Each line must be: " ", where role ∈ {claim, premise, none}.\n - Do not add any explanations or extra text after the Nth line. 
"""
    # Chat-style formatting used during training
    return (
        f"<|im_start|>user\n{task_block}\nParagraph:\n{paragraph}"
        f"<|im_end|>\n<|im_start|>assistant\n"
    )


def get_last_five_words(text: str) -> str:
    """Return the last 5 whitespace-separated words of *text* (all, if fewer).

    The original ``len(words) >= 3`` branch was redundant — ``words[-5:]``
    already equals ``words`` whenever there are fewer than five — so the
    simplification below is behavior-identical for every input.
    """
    return " ".join(text.strip().split()[-5:])


def extract_role_from_suffix(text_after_match: str) -> str:
    """Extract a role (claim/premise/none) from the text after the 5-word match.

    Handles gibberish suffixes glued to the role, e.g. ``'claimabcd' -> 'claim'``.
    Returns ``'none'`` when nothing recognizable is found.
    """
    text_after_match = text_after_match.strip()
    lowered = text_after_match.lower()
    # Role word at the very start of the remaining text (possibly fused).
    for role in ROLE_WORDS:
        if lowered.startswith(role):
            return role
    # Fallback: inspect only the first whitespace-separated token.
    tokens = text_after_match.split()
    first_word = tokens[0].lower() if tokens else ""
    for role in ROLE_WORDS:
        if first_word.startswith(role):
            return role
    return "none"  # default fallback


def parse_numbered_lines(text: str, original_paragraph: str):
    """Parse the model's numbered lines into ``[{index, sentence, role}, ...]``.

    Stopping criterion: the generated line containing the paragraph's last
    five words is treated as the final sentence; parsing stops there so any
    trailing gibberish the model emits is discarded.

    Args:
        text: Raw assistant output (one numbered sentence+role per line).
        original_paragraph: The user's input, used to locate the last sentence.

    Returns:
        List of dicts with keys ``index`` (int), ``sentence`` (str), ``role``
        (one of claim/premise/none).
    """
    results = []
    lines = text.splitlines()

    # Sentence split is only a sanity check that there is anything to tag.
    sentences = [s.strip() for s in re.split(r"[.!?]+", original_paragraph.strip()) if s.strip()]
    if not sentences:
        return results

    last_five_words = get_last_five_words(original_paragraph)

    for line in lines:
        line = line.strip()
        if not line or not line[0].isdigit():
            continue  # only "<number> <sentence> <role>" lines are interesting
        try:
            # Parse the leading sentence index.
            space_after_idx = line.find(" ")
            if space_after_idx == -1:
                continue
            idx = int(line[:space_after_idx])
            rest = line[space_after_idx + 1:].rstrip()

            if last_five_words.lower() in rest.lower():
                # This line contains the paragraph's closing words -> last sentence.
                match_pos = rest.lower().find(last_five_words.lower())
                if match_pos != -1:
                    sentence_end = match_pos + len(last_five_words)
                    sent = rest[:sentence_end].strip()
                    # The role (possibly fused with gibberish) follows the match.
                    text_after_match = rest[sentence_end:].strip()
                    role = "none"  # default
                    if text_after_match:
                        text_after_match = text_after_match.lstrip(' .,!?')
                        role = extract_role_from_suffix(text_after_match)
                    results.append({"index": idx, "sentence": sent, "role": role})
                    break  # STOP parsing here - this is the last sentence
            else:
                # Regular line: the role is the last space-separated token.
                last_space = rest.rfind(" ")
                if last_space == -1:
                    continue
                sent = rest[:last_space].strip()
                role_candidate = rest[last_space + 1:].strip().lower()
                role = "none"
                for valid_role in ROLE_WORDS:
                    if role_candidate.startswith(valid_role):
                        role = valid_role
                        break
                results.append({"index": idx, "sentence": sent, "role": role})
        except Exception as e:
            # Malformed line (e.g. non-numeric index) — skip it, keep going.
            print(f"Error parsing line '{line}': {e}")
            continue
    return results


@spaces.GPU(duration=120)
def analyze(paragraph: str, max_new_tokens: int, temperature: float, top_p: float, show_parsed: bool):
    """Tag *paragraph* and return ``(raw model text, parsed JSON string)``.

    The parsed JSON is empty when *show_parsed* is false.
    """
    paragraph = (paragraph or "").strip()
    if not paragraph:
        return "Please paste a paragraph.", ""

    tokenizer, model = load_model_to_cpu(MODEL_PATH)
    # Moved to GPU inside the @spaces.GPU lease; ZeroGPU manages the device
    # lifetime around this call.
    model = model.to("cuda")

    prompt = build_inference_prompt(paragraph)
    inputs = tokenizer(prompt, return_tensors="pt").to("cuda")

    # BUGFIX: the original used do_sample=(temperature > 0 AND top_p < 1.0),
    # so e.g. temperature=0.7 with top_p=1.0 silently ran greedy decoding and
    # ignored both knobs. Sampling is now governed by temperature alone, and
    # temperature/top_p are only passed when sampling is on (temperature=0
    # with do_sample=True is rejected by transformers).
    do_sample = float(temperature) > 0.0
    gen_kwargs = dict(
        max_new_tokens=int(max_new_tokens),
        do_sample=do_sample,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        use_cache=True,
    )
    if do_sample:
        gen_kwargs["temperature"] = float(temperature)
        gen_kwargs["top_p"] = float(top_p)

    with torch.inference_mode():
        output = model.generate(**inputs, **gen_kwargs)

    full = tokenizer.decode(output[0], skip_special_tokens=False)

    # Extract assistant segment from the chat-formatted transcript.
    if "<|im_start|>assistant\n" in full:
        resp = full.split("<|im_start|>assistant\n")[-1]
        resp = resp.split("<|im_end|>")[0].strip()
    else:
        resp = full.strip()

    # Updated parsing with original paragraph reference
    parsed = parse_numbered_lines(resp, paragraph)
    parsed_json = json.dumps(parsed, ensure_ascii=False, indent=2) if show_parsed else ""
    return resp, parsed_json


def launch_app():
    """Build and return the Gradio Blocks UI (not yet launched)."""
    with gr.Blocks(title="Argument Role Tagger (DeepSeek 1.5B + LoRA merged)") as demo:
        gr.Markdown("## Argument Role Tagger")
        gr.Markdown(
            "Paste a paragraph. The model will number sentences and label each as **claim**, **premise**, or **none**."
        )
        with gr.Row():
            with gr.Column(scale=2):
                paragraph = gr.Textbox(
                    label="Paragraph",
                    lines=10,
                    placeholder="Paste your paragraph…",
                    value=("Governments should subsidize solar panels to accelerate clean energy adoption. "
                           "Lowering installation costs would encourage more households to switch, reducing fossil fuel dependence. "
                           "In the long run, this shift could stabilize energy prices and reduce environmental damage.")
                )
                with gr.Row():
                    max_new_tokens = gr.Slider(200, 4300, value=MAX_NEW_TOKENS_DEFAULT, step=100, label="Max new tokens")
                with gr.Row():
                    temperature = gr.Slider(0.0, 1.0, value=TEMPERATURE_DEFAULT, step=0.05, label="Temperature")
                    top_p = gr.Slider(0.5, 1.0, value=TOP_P_DEFAULT, step=0.05, label="Top-p")
                show_parsed = gr.Checkbox(value=True, label="Show parsed JSON")
                run_btn = gr.Button("Analyze", variant="primary")
            with gr.Column(scale=3):
                raw_out = gr.Textbox(label="Model Output (raw)", lines=18, show_copy_button=True)
                parsed_out = gr.Code(label="Parsed JSON", language="json")

        run_btn.click(
            analyze,
            inputs=[paragraph, max_new_tokens, temperature, top_p, show_parsed],
            outputs=[raw_out, parsed_out],
        )

        gr.Markdown("### Tips")
        gr.Markdown("- Set MODEL_PATH at the top to your merged model repo or local path.\n"
                    "- For deterministic outputs, set Temperature=0.0 and Top-p=1.0.\n"
                    "- Your training format (chat tokens + Task/Rules) is preserved in the prompt.\n"
                    "- **Enhanced parsing**: Stops at last sentence using 5-word match to avoid gibberish.")
    return demo


if __name__ == "__main__":
    app = launch_app()
    app.launch(share=True)