# Hugging Face Space application (page-scrape status lines "Spaces: Sleeping" removed)
| import os | |
| import json | |
| import warnings | |
| import torch | |
| import gradio as gr | |
| import spaces | |
| from transformers import AutoTokenizer, AutoModelForCausalLM | |
# ---------- CONFIG ----------
# Hugging Face Hub repo id for the fine-tuned tagger (tokenizer + causal LM).
MODEL_PATH = "iqasimz/g5"
# Default generation budget; the UI slider ranges 50-5000.
MAX_NEW_TOKENS_DEFAULT = 500
# Nucleus-sampling cutoff passed to generate(); 1.0 effectively disables it
# (generation also runs with do_sample=False, i.e. greedy decoding).
TOP_P_DEFAULT = 1.0
# ---------------------------
# Suppress torch-emitted warnings to keep Space logs readable.
warnings.filterwarnings("ignore", module="torch")
# Process-wide cache: model_dir -> (tokenizer, model), so weights load only once.
_model_cache = {}
| def _ensure_pad_token(tokenizer): | |
| if tokenizer.pad_token is None: | |
| tokenizer.pad_token = tokenizer.eos_token | |
| return tokenizer | |
def load_model_to_cpu(model_dir: str):
    """Load tokenizer+model once on CPU; moved to GPU per request via @spaces.GPU."""
    # Fast path: reuse the already-loaded (tokenizer, model) pair.
    try:
        return _model_cache[model_dir]
    except KeyError:
        pass
    tokenizer = _ensure_pad_token(
        AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
    )
    model = AutoModelForCausalLM.from_pretrained(
        model_dir,
        trust_remote_code=True,
        torch_dtype=torch.float16,
        device_map=None,  # keep on CPU for caching
    )
    model.eval()
    _model_cache[model_dir] = (tokenizer, model)
    print(f"[cache] Loaded {model_dir} on CPU")
    return tokenizer, model
def build_prompt(paragraph: str) -> str:
    """Format the user paragraph into the EXACT structured instruction format.

    The wording and byte layout below must stay unchanged: the model was
    fine-tuned against this exact template.
    """
    task = (
        "Task: You are an expert argument analyst. Identify the role of each sentence within the context of the paragraph/debate/article like a true linguistics and argument expert."
        "Number the sentences in the paragraph and tag the role of each one.\n"
    )
    rules = (
        "Rules:\n"
        "- Do NOT change the text of any sentence.\n"
        "- Keep the original order.\n"
        "- Output exactly N lines, one per sentence.\n"
        "- Each line must be: \"<index> <original sentence> <role>\", where role ∈ {claim, premise, none}.\n"
        "\n"
        "- Do not add any explanations or extra text after the Nth line.\n"
    )
    return task + rules + f"Paragraph:\n{paragraph.strip()}"
| # ---------------- JSON Parsing Utilities ---------------- | |
def get_last_five_words(text: str) -> str:
    """Return the last five whitespace-separated words of *text*, space-joined.

    Shorter inputs return all of their words. Used by parse_numbered_lines as
    an anchor to recognize the paragraph's final sentence in the model output.
    """
    # words[-5:] already degrades gracefully for short inputs, so no length
    # guard is needed (the original `len(words) >= 3` branch was dead code:
    # both branches produced the identical join).
    return " ".join(text.strip().split()[-5:])
def extract_role_from_suffix(text_after_match: str) -> str:
    """Map the text trailing a matched sentence to "claim", "premise" or "none".

    Matching is case-insensitive and prefix-based (e.g. "Claims," -> "claim").
    Returns "none" when no role word is recognized.
    """
    text_after_match = text_after_match.strip().lower()
    for role in ("claim", "premise", "none"):
        if text_after_match.startswith(role):
            return role
    # Note: the original second pass over the first whitespace token was
    # unreachable — that token is a leading prefix of the stripped text, so any
    # match there was already caught by the prefix check above. Removed.
    return "none"
def parse_numbered_lines(text: str, original_paragraph: str):
    """Parse the model's "<index> <sentence> <role>" lines into structured dicts.

    Returns a list of {"index": int, "sentence": str, "role": str} entries.
    Parsing stops at the line containing the paragraph's final words (the model
    is instructed to emit nothing after the last sentence).
    """
    results = []
    lines = text.splitlines()
    import re  # local import kept as-is; only needed for the split below
    # Rough sentence split, used only to bail out on an empty/whitespace paragraph.
    sentences = re.split(r'[.!?]+', original_paragraph.strip())
    sentences = [s.strip() for s in sentences if s.strip()]
    if not sentences:
        return results
    # Anchor: the paragraph's closing words identify the final output line.
    last_five_words = get_last_five_words(original_paragraph)
    for line in lines:
        line = line.strip()
        # Only lines beginning with a sentence index are candidates.
        if not line or not line[0].isdigit():
            continue
        try:
            space_after_idx = line.find(" ")
            if space_after_idx == -1:
                continue
            idx = int(line[:space_after_idx])
            rest = line[space_after_idx + 1:].rstrip()
            if last_five_words.lower() in rest.lower():
                # Final sentence: everything up to the anchor's end is the
                # sentence, the remainder is the role; then stop parsing.
                match_pos = rest.lower().find(last_five_words.lower())
                if match_pos != -1:
                    sentence_end = match_pos + len(last_five_words)
                    sent = rest[:sentence_end].strip()
                    text_after_match = rest[sentence_end:].strip()
                    role = "none"
                    if text_after_match:
                        # Drop leading punctuation before role detection.
                        text_after_match = text_after_match.lstrip(' .,!?')
                        role = extract_role_from_suffix(text_after_match)
                    results.append({"index": idx, "sentence": sent, "role": role})
                    break
            else:
                # Ordinary line: treat the final whitespace-separated token as
                # the role and everything before it as the sentence.
                last_space = rest.rfind(" ")
                if last_space == -1:
                    continue
                sent = rest[:last_space].strip()
                role_candidate = rest[last_space + 1:].strip().lower()
                role = "none"
                for valid_role in ['claim', 'premise', 'none']:
                    if role_candidate.startswith(valid_role):
                        role = valid_role
                        break
                results.append({"index": idx, "sentence": sent, "role": role})
        except Exception as e:
            # Malformed line (e.g. non-integer index): report and keep going.
            print(f"Error parsing line '{line}': {e}")
            continue
    return results
| # -------------------------------------------------------- | |
def generate_text(paragraph, max_tokens, show_parsed):
    """Tag *paragraph* with the model and return (raw_output, parsed_json).

    Parameters:
        paragraph: user paragraph to analyze (str).
        max_tokens: generation budget from the UI slider (cast to int).
        show_parsed: when truthy, the second return value is the parsed JSON
            string; otherwise it is "".
    """
    if not paragraph.strip():
        return "Please enter some text.", ""
    tokenizer, model = load_model_to_cpu(MODEL_PATH)
    # NOTE(review): comments elsewhere mention per-request GPU allocation via
    # @spaces.GPU, but this function is not decorated — confirm whether the
    # decorator is required for ZeroGPU hardware.
    # Fall back to CPU when CUDA is unavailable instead of crashing on
    # model.to("cuda") (e.g. local runs or CPU-only Spaces).
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device)
    formatted_input = build_prompt(paragraph)
    messages = [{"role": "user", "content": formatted_input}]
    formatted_text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    inputs = tokenizer(formatted_text, return_tensors="pt").to(model.device)
    with torch.inference_mode():
        outputs = model.generate(
            **inputs,
            max_new_tokens=int(max_tokens),
            top_p=TOP_P_DEFAULT,
            do_sample=False,  # greedy decoding for deterministic tagging
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
            use_cache=True,
        )
    # Prefer slicing on the chat template's assistant marker; otherwise decode
    # only the newly generated tokens.
    full_response = tokenizer.decode(outputs[0], skip_special_tokens=False)
    if "<|Assistant|>" in full_response:
        response = full_response.split("<|Assistant|>")[-1]
        response = response.split("<|end▁of▁sentence|>")[0].strip()
    else:
        new_tokens = outputs[0][inputs.input_ids.shape[-1]:]
        response = tokenizer.decode(new_tokens, skip_special_tokens=True)
    parsed = parse_numbered_lines(response, paragraph)
    parsed_json = json.dumps(parsed, ensure_ascii=False, indent=2) if show_parsed else ""
    return response, parsed_json
def launch_app():
    """Build and return the Gradio Blocks UI for the argument role tagger."""
    # Derive the visible titles from MODEL_PATH so the UI always names the
    # model actually served (the hard-coded "iqasimz/g3" labels disagreed with
    # MODEL_PATH = "iqasimz/g5").
    with gr.Blocks(title=f"{MODEL_PATH} - Argument Role Tagger") as demo:
        gr.Markdown(f"# {MODEL_PATH} - Argument Role Tagger")
        gr.Markdown("Enter a paragraph, the model will number sentences and assign roles (claim, premise, none).")
        with gr.Row():
            with gr.Column():
                input_para = gr.Textbox(
                    label="Input Paragraph",
                    lines=8,
                    placeholder="Paste your paragraph here...",
                )
                max_tokens = gr.Slider(
                    minimum=50,
                    maximum=5000,
                    value=MAX_NEW_TOKENS_DEFAULT,
                    step=50,
                    label="Max New Tokens",
                )
                show_parsed = gr.Checkbox(value=True, label="Show parsed JSON")
                generate_btn = gr.Button("Analyze", variant="primary")
            with gr.Column():
                output_text = gr.Textbox(
                    label="Model Output",
                    lines=15,
                    show_copy_button=True,
                )
                parsed_out = gr.Code(
                    label="Parsed JSON",
                    language="json",
                )
        generate_btn.click(
            fn=generate_text,
            inputs=[input_para, max_tokens, show_parsed],
            outputs=[output_text, parsed_out],
        )
    return demo
if __name__ == "__main__":
    # Bind to all interfaces; the hosting platform supplies PORT (default 7860).
    port = int(os.getenv("PORT", "7860"))
    launch_app().launch(
        server_name="0.0.0.0",
        server_port=port,
        show_error=True,
    )