#!/usr/bin/env python
"""
Quick smoke-test script: load a fine-tuned adapter and generate a single
prediction. Uses the same prompt format as finetune_hx_unsloth.py so that
inference behavior stays aligned with training.
"""

from __future__ import annotations

import argparse
import os
from pathlib import Path

import pandas as pd
import torch
from peft import PeftConfig, PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# NOTE: these example labels ("Hate" / "Not Hate") do not match the
# hatespeech/offensive/normal label set named in the instruction below.
# They are kept verbatim to stay aligned with the prompt format from
# finetune_hx_unsloth.py; do not change them without retraining.
FEW_SHOT_EXAMPLES = [
    {
        "text": (
            "in britain the prime minister a childless woman london mayor a moslem invader in command "
            "of london police a lesbian in command of london fire brigade a fat woman scotland first "
            "minister a childless woman"
        ),
        "label": "Hate",
    },
    {
        "text": (
            "the non partisan congressional budget office cbo has released their report on sens "
            "lindsey graham and dick durbin the dream act the report shows that their legislation "
            "would add billions to the deficit over a ten year period and lots more immigrant scum"
        ),
        "label": "Hate",
    },
    {
        "text": (
            "i live and work with many legal mexican immigrants who are great citizens and trump supporters "
            "they have no problem with deporting illegals maga"
        ),
        "label": "Not Hate",
    },
]


def build_prompt(post: str) -> str:
    """Assemble the few-shot classification prompt; the trailing ':' cues the label."""
    post = (post or "").strip()
    sections = [
        ": You are a content moderation assistant. Label each post as hatespeech, "
        "offensive, or normal according to the HateXplain policy. Classify the "
        "following social media post. Respond with exactly one label: hatespeech, "
        "offensive, or normal."
    ]
    for idx, example in enumerate(FEW_SHOT_EXAMPLES, start=1):
        sections.append(f"Example {idx}:")
        sections.append(f": {example['text']}")
        sections.append(f": {example['label']}")
        sections.append("")
    sections.append(f": {post}")
    sections.append(":")
    return "\n".join(sections).strip()


def fetch_post(args: argparse.Namespace) -> str:
    """Return the post text from --text, or from the dataset row at --sample-index."""
    if args.text:
        return args.text
    df = pd.read_csv(args.dataset_path)
    if args.sample_index < 0 or args.sample_index >= len(df):
        raise IndexError(f"sample_index {args.sample_index} outside dataset size {len(df)}")
    # Cast to str so build_prompt never chokes on NaN or numeric cells.
    return str(df.iloc[args.sample_index][args.text_col])


DEFAULT_ADAPTER_DIR = Path(__file__).resolve().parent.parent / "adapter"
DEFAULT_BASE_MODEL = "meta-llama/Meta-Llama-3.1-8B-Instruct"


def load_adapter(
    adapter_dir: str,
    base_model: str | None = None,
    load_in_4bit: bool = True,
):
    """Load the PEFT adapter on top of its base model, optionally 4-bit quantized."""
    adapter_path = Path(adapter_dir)
    if base_model is None:
        # Prefer the base model recorded in the adapter's own config;
        # fall back to the default if the adapter does not record one.
        peft_config = PeftConfig.from_pretrained(adapter_path)
        base_model = peft_config.base_model_name_or_path or DEFAULT_BASE_MODEL

    tokenizer = AutoTokenizer.from_pretrained(
        adapter_path, use_fast=True, trust_remote_code=True
    )
    if tokenizer.pad_token is None and tokenizer.eos_token is not None:
        tokenizer.pad_token = tokenizer.eos_token

    torch_dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
    quantization_config = None
    if load_in_4bit:
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
        )

    model = AutoModelForCausalLM.from_pretrained(
        base_model,
        device_map="auto",
        torch_dtype=torch_dtype,
        trust_remote_code=True,
        quantization_config=quantization_config,
    )
    model = PeftModel.from_pretrained(model, adapter_path, is_trainable=False)
    model.eval()
    return model, tokenizer


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        description="Generate a single HX prediction using a fine-tuned adapter."
    )
    parser.add_argument(
        "--adapter-dir",
        default=str(DEFAULT_ADAPTER_DIR),
        help="Path to the trained adapter directory.",
    )
    parser.add_argument(
        "--base-model",
        default=None,
        help="Optional base model path/ID; if omitted we read the adapter_config "
        "metadata, falling back to the default base model.",
    )
    parser.add_argument(
        "--dataset-path",
        default="hx.csv",
        help="CSV to sample from when --text is not provided.",
    )
    parser.add_argument("--text-col", default="text", help="Column containing the post text.")
    parser.add_argument("--sample-index", type=int, default=0, help="Row index to read from the dataset.")
    parser.add_argument("--text", default=None, help="Override text instead of pulling from the dataset.")
    parser.add_argument("--max-new-tokens", type=int, default=8, help="Generation length.")
    parser.add_argument("--temperature", type=float, default=0.0, help="Sampling temperature (0 for greedy).")
    parser.add_argument("--top-p", type=float, default=0.9, help="Top-p sampling cutoff.")
    return parser.parse_args()


def main() -> None:
    args = parse_args()
    post = fetch_post(args)
    prompt = build_prompt(post)

    # 4-bit quantization requires CUDA; fall back to full precision on CPU.
    model, tokenizer = load_adapter(
        adapter_dir=args.adapter_dir,
        base_model=args.base_model,
        load_in_4bit=torch.cuda.is_available(),
    )
    tokenizer.padding_side = "left"
    tokenizer.truncation_side = "left"

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # Only pass sampling knobs when actually sampling; temperature 0 means greedy.
    do_sample = args.temperature > 0
    gen_kwargs = {
        "max_new_tokens": args.max_new_tokens,
        "do_sample": do_sample,
        "eos_token_id": tokenizer.eos_token_id,
        "pad_token_id": tokenizer.pad_token_id,
    }
    if do_sample:
        gen_kwargs["temperature"] = args.temperature
        gen_kwargs["top_p"] = args.top_p

    with torch.no_grad():
        output = model.generate(**inputs, **gen_kwargs)

    # Decode only the newly generated tokens, not the echoed prompt.
    completion = tokenizer.decode(
        output[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True
    ).strip()

    print("Prompt:\n", prompt)
    print("\nModel completion:\n", completion)


if __name__ == "__main__":
    os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
    main()
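
# ---------------------------------------------------------------------------
# Example invocations (a sketch: the script filename below is hypothetical,
# and hx.csv / the adapter directory are whatever your checkout provides;
# all flags shown are the ones defined in parse_args above):
#
#   python smoke_test_adapter.py --text "some post to classify"
#   python smoke_test_adapter.py --dataset-path hx.csv --sample-index 3
#   python smoke_test_adapter.py --temperature 0.7 --top-p 0.9 --max-new-tokens 8
#
# With the default temperature of 0.0, decoding is greedy, so rerunning on the
# same post should reproduce the same completion.
# ---------------------------------------------------------------------------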