| | import os |
| | import json |
| | import warnings |
| | import torch |
| | import gradio as gr |
| | import spaces |
| | from transformers import AutoTokenizer, AutoModelForCausalLM |
| |
|
| | |
# --- Module configuration -------------------------------------------------
# Default Gradio server port for Hugging Face Spaces (only set if unset).
os.environ.setdefault("GRADIO_SERVER_PORT", "7860")
# Merged model repo id or local path; change this to point at your model.
MODEL_PATH = "iqasimz/g2"
# UI defaults: generation length, temperature (0 -> greedy), nucleus top-p
# (1.0 -> nucleus filtering disabled).
MAX_NEW_TOKENS_DEFAULT = 300
TEMPERATURE_DEFAULT = 0
TOP_P_DEFAULT = 1.0

# Silence noisy torch warnings in the Space logs.
warnings.filterwarnings("ignore", module="torch")
# Process-wide cache: model_dir -> (tokenizer, model), loaded once on CPU.
_model_cache = {}
|
| | def _ensure_pad_token(tokenizer): |
| | if tokenizer.pad_token is None: |
| | tokenizer.pad_token = tokenizer.eos_token |
| | return tokenizer |
| |
|
def load_model_to_cpu(model_dir: str):
    """Load and memoize (tokenizer, model) for *model_dir* on the CPU.

    The model stays on CPU here; ``analyze`` moves it to the GPU for the
    duration of each request via the ``@spaces.GPU`` decorator.
    """
    cached = _model_cache.get(model_dir)
    if cached is not None:
        return cached

    tokenizer = _ensure_pad_token(
        AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
    )

    model = AutoModelForCausalLM.from_pretrained(
        model_dir,
        trust_remote_code=True,
        torch_dtype=torch.float16,  # half precision; weights moved to GPU per call
        device_map=None,            # keep everything on CPU at load time
    )
    model.eval()

    _model_cache[model_dir] = (tokenizer, model)
    print(f"[cache] Loaded {model_dir} on CPU")
    return tokenizer, model
| |
|
def build_inference_prompt(paragraph: str) -> str:
    """Wrap *paragraph* in the chat-style prompt the model was trained on.

    Uses <|im_start|>/<|im_end|> chat markers around a fixed Task/Rules
    block that instructs the model to number each sentence and tag it as
    claim, premise, or none.
    """
    # BUG FIX: the prompt previously read "Task: ou are ..." (dropped "Y")
    # and "expert.Number" (missing space); both typos are corrected here.
    task_block = """Task: You are an expert argument analyst. Identify the role of each sentence within the context of the paragraph/debate/article like a true linguistics and argument expert. Number the sentences in the paragraph and tag the role of each one.\n
Rules:\n
- Do NOT change the text of any sentence.\n
- Keep the original order.\n
- Output exactly N lines, one per sentence.\n
- Each line must be: "<index> <original sentence> <role>", where role ∈ {claim, premise, none}.\n
- Do not add any explanations or extra text after the Nth line.
"""

    return (
        f"<|im_start|>user\n{task_block}\nParagraph:\n{paragraph}"
        f"<|im_end|>\n<|im_start|>assistant\n"
    )
| |
|
def get_last_five_words(text: str) -> str:
    """Return the last five whitespace-separated words of *text*.

    If the text has fewer than five words, all of them are returned.
    (The original ``len(words) >= 3`` guard was dead code: slicing with
    ``words[-5:]`` already returns the whole list for short inputs, so
    both branches produced the same result.)
    """
    return " ".join(text.split()[-5:])
| |
|
def extract_role_from_suffix(text_after_match: str) -> str:
    """Extract a role label from text trailing a matched sentence.

    Handles fused model output such as ``"claimabcd"`` -> ``"claim"`` by
    checking whether the trailing text starts with a known role word
    (case-insensitive).

    Returns one of ``"claim"``, ``"premise"`` or ``"none"`` — the last
    being the default when no role word is found.

    (The original also checked the first whitespace-separated word as a
    fallback, but that loop was dead code: role words contain no spaces,
    so a prefix match on the first word implies a prefix match on the
    whole string.)
    """
    text = text_after_match.strip().lower()
    # Role words do not prefix one another, so at most one can match.
    for role in ("claim", "premise", "none"):
        if text.startswith(role):
            return role
    return "none"
| |
|
def parse_numbered_lines(text: str, original_paragraph: str):
    """
    Enhanced parsing with improved stopping criteria:
    1. Find exact match of last 5 words from input paragraph
    2. Look for role word after a space following the match
    3. Stop parsing after finding the last sentence to avoid gibberish

    Returns a list of dicts: {"index": int, "sentence": str, "role": str}.
    """
    results = []
    lines = text.splitlines()

    # Split the paragraph into sentences only to detect the degenerate
    # empty-input case; the sentence list itself is not used further below.
    import re
    sentences = re.split(r'[.!?]+', original_paragraph.strip())
    sentences = [s.strip() for s in sentences if s.strip()]

    if not sentences:
        return results

    # Fingerprint of the paragraph's final words, used to recognise the
    # model's last output line and stop before any trailing gibberish.
    last_five_words = get_last_five_words(original_paragraph)

    for line in lines:
        line = line.strip()
        # Only consider lines that look like numbered output rows ("<n> ...").
        if not line or not line[0].isdigit():
            continue

        try:
            # Split "<index>" from the rest of the line at the first space.
            space_after_idx = line.find(" ")
            if space_after_idx == -1:
                continue

            idx = int(line[:space_after_idx])
            rest = line[space_after_idx + 1:].rstrip()

            # Case 1: the line contains the paragraph's closing words, so it
            # is the final sentence — parse it and stop (see `break` below).
            if last_five_words.lower() in rest.lower():
                match_pos = rest.lower().find(last_five_words.lower())
                if match_pos != -1:
                    # The sentence text ends where the 5-word match ends.
                    sentence_end = match_pos + len(last_five_words)
                    sent = rest[:sentence_end].strip()

                    # Whatever follows should be the role label, possibly
                    # fused to punctuation (e.g. ".claim").
                    text_after_match = rest[sentence_end:].strip()
                    role = "none"

                    if text_after_match:
                        # Drop leading punctuation before role extraction.
                        text_after_match = text_after_match.lstrip(' .,!?')
                        role = extract_role_from_suffix(text_after_match)

                    results.append({"index": idx, "sentence": sent, "role": role})

                    # Stop: anything generated after the last sentence is
                    # assumed to be gibberish.
                    break
            else:
                # Case 2: ordinary line — treat the last word as the role.
                last_space = rest.rfind(" ")
                if last_space == -1:
                    continue

                sent = rest[:last_space].strip()
                role_candidate = rest[last_space + 1:].strip().lower()

                # Validate against known roles; prefix match tolerates fused
                # output such as "claim." or "premise,".
                role = "none"
                for valid_role in ['claim', 'premise', 'none']:
                    if role_candidate.startswith(valid_role):
                        role = valid_role
                        break

                results.append({"index": idx, "sentence": sent, "role": role})

        except Exception as e:
            # Best-effort parsing: skip malformed lines rather than fail.
            print(f"Error parsing line '{line}': {e}")
            continue

    return results
| |
|
@spaces.GPU(duration=120)
def analyze(paragraph: str, max_new_tokens: int, temperature: float, top_p: float, show_parsed: bool):
    """Tag each sentence of *paragraph* as claim/premise/none.

    Runs on GPU for up to 120s per request (ZeroGPU). Returns a tuple
    ``(raw_model_output, parsed_json)``; ``parsed_json`` is "" when
    *show_parsed* is False.
    """
    paragraph = (paragraph or "").strip()
    if not paragraph:
        return "Please paste a paragraph.", ""

    tokenizer, model = load_model_to_cpu(MODEL_PATH)
    model = model.to("cuda")

    prompt = build_inference_prompt(paragraph)
    inputs = tokenizer(prompt, return_tensors="pt").to("cuda")

    # BUG FIX: sampling was previously enabled only when temperature > 0 AND
    # top_p < 1.0, so e.g. temperature=0.7 with top_p=1.0 silently fell back
    # to greedy decoding. Sampling is governed by temperature alone; top_p=1.0
    # merely disables nucleus filtering.
    do_sample = float(temperature) > 0.0
    gen_kwargs = dict(
        max_new_tokens=int(max_new_tokens),
        do_sample=do_sample,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        use_cache=True,
    )
    if do_sample:
        # Only pass sampling knobs when sampling: transformers warns (and
        # newer versions may error) on temperature=0 with do_sample=False.
        gen_kwargs["temperature"] = float(temperature)
        gen_kwargs["top_p"] = float(top_p)

    with torch.inference_mode():
        output = model.generate(**inputs, **gen_kwargs)

    full = tokenizer.decode(output[0], skip_special_tokens=False)

    # Extract only the assistant turn from the chat-formatted transcript.
    if "<|im_start|>assistant\n" in full:
        resp = full.split("<|im_start|>assistant\n")[-1]
        resp = resp.split("<|im_end|>")[0].strip()
    else:
        resp = full.strip()

    parsed = parse_numbered_lines(resp, paragraph)
    parsed_json = json.dumps(parsed, ensure_ascii=False, indent=2) if show_parsed else ""
    return resp, parsed_json
| |
|
def launch_app():
    """Build and return the Gradio Blocks UI (not launched here).

    Layout: a 2-column row — inputs (paragraph, generation sliders,
    Analyze button) on the left, outputs (raw model text, parsed JSON)
    on the right — followed by a Tips section.
    """
    with gr.Blocks(title="Argument Role Tagger (DeepSeek 1.5B + LoRA merged)") as demo:
        gr.Markdown("## Argument Role Tagger")
        gr.Markdown(
            "Paste a paragraph. The model will number sentences and label each as **claim**, **premise**, or **none**."
        )

        with gr.Row():
            # Left column: user inputs and generation controls.
            with gr.Column(scale=2):
                paragraph = gr.Textbox(
                    label="Paragraph",
                    lines=10,
                    placeholder="Paste your paragraph…",
                    # Pre-filled example so the demo works with one click.
                    value=("Governments should subsidize solar panels to accelerate clean energy adoption. "
                           "Lowering installation costs would encourage more households to switch, reducing fossil fuel dependence. "
                           "In the long run, this shift could stabilize energy prices and reduce environmental damage.")
                )
                with gr.Row():
                    max_new_tokens = gr.Slider(200, 4300, value=MAX_NEW_TOKENS_DEFAULT, step=100, label="Max new tokens")
                with gr.Row():
                    temperature = gr.Slider(0.0, 1.0, value=TEMPERATURE_DEFAULT, step=0.05, label="Temperature")
                    top_p = gr.Slider(0.5, 1.0, value=TOP_P_DEFAULT, step=0.05, label="Top-p")
                show_parsed = gr.Checkbox(value=True, label="Show parsed JSON")
                run_btn = gr.Button("Analyze", variant="primary")

            # Right column: model output panels.
            with gr.Column(scale=3):
                raw_out = gr.Textbox(label="Model Output (raw)", lines=18, show_copy_button=True)
                parsed_out = gr.Code(label="Parsed JSON", language="json")

        # Wire the button to the GPU-decorated inference function.
        run_btn.click(
            analyze,
            inputs=[paragraph, max_new_tokens, temperature, top_p, show_parsed],
            outputs=[raw_out, parsed_out],
        )

        gr.Markdown("### Tips")
        gr.Markdown("- Set MODEL_PATH at the top to your merged model repo or local path.\n"
                    "- For deterministic outputs, set Temperature=0.0 and Top-p=1.0.\n"
                    "- Your training format (chat tokens + Task/Rules) is preserved in the prompt.\n"
                    "- **Enhanced parsing**: Stops at last sentence using 5-word match to avoid gibberish.")

    return demo
| |
|
if __name__ == "__main__":
    # Build the UI and serve it; share=True also creates a public gradio link.
    launch_app().launch(share=True)