import json
import os
import re
import warnings

import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# ---------- CONFIG ----------
MODEL_PATH = "iqasimz/g5"
MAX_NEW_TOKENS_DEFAULT = 500
TOP_P_DEFAULT = 1.0
# ---------------------------

warnings.filterwarnings("ignore", module="torch")

# Process-wide cache: model_dir -> (tokenizer, model), kept on CPU between requests.
_model_cache = {}


def _ensure_pad_token(tokenizer):
    """Guarantee the tokenizer has a pad token, falling back to EOS."""
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    return tokenizer


def load_model_to_cpu(model_dir: str):
    """Load tokenizer+model once on CPU; moved to GPU per request via @spaces.GPU.

    Returns:
        (tokenizer, model) tuple, memoized in ``_model_cache`` by ``model_dir``.
    """
    if model_dir in _model_cache:
        return _model_cache[model_dir]
    tok = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
    tok = _ensure_pad_token(tok)
    mdl = AutoModelForCausalLM.from_pretrained(
        model_dir,
        trust_remote_code=True,
        torch_dtype=torch.float16,
        device_map=None,  # keep on CPU for caching
    )
    mdl.eval()
    _model_cache[model_dir] = (tok, mdl)
    print(f"[cache] Loaded {model_dir} on CPU")
    return tok, mdl


def build_prompt(paragraph: str) -> str:
    """Format the user paragraph into the EXACT structured instruction format."""
    # BUG FIX: the first two instruction sentences were joined by implicit string
    # concatenation with no separator ("...argument expert.Number the sentences...");
    # a "\n" is added between them so the prompt reads as intended.
    # NOTE(review): the "- Each line must be: \" \"" rule below looks like a
    # template whose placeholder text was lost (e.g. stripped angle brackets);
    # it is preserved byte-for-byte here — confirm the intended line format.
    return (
        "Task: You are an expert argument analyst. Identify the role of each sentence within the context of the paragraph/debate/article like a true linguistics and argument expert.\n"
        "Number the sentences in the paragraph and tag the role of each one.\n"
        "Rules:\n"
        "- Do NOT change the text of any sentence.\n"
        "- Keep the original order.\n"
        "- Output exactly N lines, one per sentence.\n"
        "- Each line must be: \" \", where role ∈ {claim, premise, none}.\n"
        "\n"
        "- Do not add any explanations or extra text after the Nth line.\n"
        f"Paragraph:\n{paragraph.strip()}"
    )


# ---------------- JSON Parsing Utilities ----------------
def get_last_five_words(text: str) -> str:
    """Return the last (up to) five whitespace-separated words of *text*.

    The original conditional (``len(words) >= 3``) was redundant: when the
    text has fewer than five words, ``words[-5:]`` is already the whole list,
    so both branches produced the same value.
    """
    words = text.strip().split()
    return " ".join(words[-5:])


def extract_role_from_suffix(text_after_match: str) -> str:
    """Map the text that follows a matched sentence to a role label.

    Returns the first of 'claim'/'premise'/'none' that the stripped suffix
    starts with (case-insensitive); defaults to 'none'.
    """
    text_after_match = text_after_match.strip()
    for role in ('claim', 'premise', 'none'):
        if text_after_match.lower().startswith(role):
            return role
    # The original also prefix-tested the first word separately, but that
    # check is strictly implied by the whole-text prefix test above, so the
    # dead second loop is removed (behavior unchanged).
    return "none"


def parse_numbered_lines(text: str, original_paragraph: str):
    """Parse model output lines shaped like '<idx> <sentence> <role>'.

    Args:
        text: raw model response, one tagged sentence per line.
        original_paragraph: the user's input, used to anchor the final
            sentence via its last five words.

    Returns:
        List of ``{"index", "sentence", "role"}`` dicts. Lines that do not
        start with a digit are skipped. When a line contains the paragraph's
        last five words, the sentence/role split is anchored there and
        scanning stops (that line is assumed to be the final sentence);
        otherwise the last whitespace-separated token is the role candidate.
    """
    results = []
    lines = text.splitlines()
    # `re` is now imported at module level instead of inside this function.
    sentences = re.split(r'[.!?]+', original_paragraph.strip())
    sentences = [s.strip() for s in sentences if s.strip()]
    if not sentences:
        return results
    last_five_words = get_last_five_words(original_paragraph)
    for line in lines:
        line = line.strip()
        if not line or not line[0].isdigit():
            continue
        try:
            space_after_idx = line.find(" ")
            if space_after_idx == -1:
                continue
            idx = int(line[:space_after_idx])
            rest = line[space_after_idx + 1:].rstrip()
            if last_five_words.lower() in rest.lower():
                match_pos = rest.lower().find(last_five_words.lower())
                if match_pos != -1:
                    sentence_end = match_pos + len(last_five_words)
                    sent = rest[:sentence_end].strip()
                    text_after_match = rest[sentence_end:].strip()
                    role = "none"
                    if text_after_match:
                        text_after_match = text_after_match.lstrip(' .,!?')
                        role = extract_role_from_suffix(text_after_match)
                    results.append({"index": idx, "sentence": sent, "role": role})
                break  # final sentence located; stop scanning further lines
            else:
                last_space = rest.rfind(" ")
                if last_space == -1:
                    continue
                sent = rest[:last_space].strip()
                role_candidate = rest[last_space + 1:].strip().lower()
                role = "none"
                for valid_role in ('claim', 'premise', 'none'):
                    if role_candidate.startswith(valid_role):
                        role = valid_role
                        break
                results.append({"index": idx, "sentence": sent, "role": role})
        except Exception as e:
            # Best-effort parsing: report and skip malformed lines rather than abort.
            print(f"Error parsing line '{line}': {e}")
            continue
    return results
# --------------------------------------------------------


@spaces.GPU(duration=120)
def generate_text(paragraph, max_tokens, show_parsed):
    """Tag *paragraph* with the model.

    Returns:
        (raw model response, parsed-JSON string or "" if *show_parsed* is falsy).
    """
    if not paragraph.strip():
        return "Please enter some text.", ""
    tokenizer, model = load_model_to_cpu(MODEL_PATH)
    # NOTE(review): the cached model is moved to CUDA here and never moved
    # back, so after the first request the cache holds a GPU-resident model.
    # Repeated .to("cuda") calls are no-ops then, but confirm this is the
    # intended lifecycle under ZeroGPU.
    model = model.to("cuda")
    formatted_input = build_prompt(paragraph)
    messages = [{"role": "user", "content": formatted_input}]
    formatted_text = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = tokenizer(formatted_text, return_tensors="pt").to(model.device)
    with torch.inference_mode():
        outputs = model.generate(
            **inputs,
            max_new_tokens=int(max_tokens),
            top_p=TOP_P_DEFAULT,
            do_sample=False,  # greedy decoding; top_p has no effect when not sampling
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
            use_cache=True,
        )
    full_response = tokenizer.decode(outputs[0], skip_special_tokens=False)
    # Presumably DeepSeek-style chat markers: slice out the assistant turn if
    # present, otherwise fall back to decoding only the newly generated tokens.
    if "<|Assistant|>" in full_response:
        response = full_response.split("<|Assistant|>")[-1]
        response = response.split("<|end▁of▁sentence|>")[0].strip()
    else:
        new_tokens = outputs[0][inputs.input_ids.shape[-1]:]
        response = tokenizer.decode(new_tokens, skip_special_tokens=True)
    parsed = parse_numbered_lines(response, paragraph)
    parsed_json = json.dumps(parsed, ensure_ascii=False, indent=2) if show_parsed else ""
    return response, parsed_json


def launch_app():
    """Build the Gradio Blocks UI and wire the Analyze button."""
    # BUG FIX: the title/header hard-coded "iqasimz/g3" while MODEL_PATH is
    # "iqasimz/g5"; derive the displayed name from MODEL_PATH so they cannot
    # drift apart again.
    app_title = f"{MODEL_PATH} - Argument Role Tagger"
    with gr.Blocks(title=app_title) as demo:
        gr.Markdown(f"# {app_title}")
        gr.Markdown("Enter a paragraph, the model will number sentences and assign roles (claim, premise, none).")
        with gr.Row():
            with gr.Column():
                input_para = gr.Textbox(
                    label="Input Paragraph",
                    lines=8,
                    placeholder="Paste your paragraph here..."
                )
                max_tokens = gr.Slider(
                    minimum=50,
                    maximum=5000,
                    value=MAX_NEW_TOKENS_DEFAULT,
                    step=50,
                    label="Max New Tokens"
                )
                show_parsed = gr.Checkbox(value=True, label="Show parsed JSON")
                generate_btn = gr.Button("Analyze", variant="primary")
            with gr.Column():
                output_text = gr.Textbox(
                    label="Model Output",
                    lines=15,
                    show_copy_button=True
                )
                parsed_out = gr.Code(
                    label="Parsed JSON",
                    language="json"
                )
        generate_btn.click(
            fn=generate_text,
            inputs=[input_para, max_tokens, show_parsed],
            outputs=[output_text, parsed_out]
        )
    return demo


if __name__ == "__main__":
    app = launch_app()
    app.launch(
        server_name="0.0.0.0",
        server_port=int(os.getenv("PORT", "7860")),
        show_error=True
    )