# Source: augustocsc — GPT-2 Base trained on prefix dataset (682K), commit 5faf2eb (verified)
# Script for text generation with a trained model
# Seriguela project - interactive generation of symbolic expressions
import argparse
import os
import sys
import re
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteria, StoppingCriteriaList
from peft import PeftModel
# Add parent directory to path for imports
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from classes.expression import Expression
class ExpressionStoppingCriteria(StoppingCriteria):
    """Stop generation at natural expression boundaries.

    Each textual stop sequence is tokenized once at construction time; on
    every generation step the tail of the running sequence is compared
    against each pre-tokenized stop sequence.
    """

    def __init__(self, tokenizer, stop_sequences):
        self.tokenizer = tokenizer
        # Pre-tokenize every stop sequence (no special tokens added).
        self.stop_ids = [
            tokenizer.encode(seq, add_special_tokens=False)
            for seq in stop_sequences
        ]

    def __call__(self, input_ids, scores, **kwargs):
        # True as soon as the generated sequence ends with any stop sequence.
        sequence = input_ids[0]
        total = len(sequence)
        for candidate in self.stop_ids:
            n = len(candidate)
            if n and total >= n and sequence[-n:].tolist() == candidate:
                return True
        return False
def parse_args():
    """Parse and return the command-line arguments for the generator."""
    parser = argparse.ArgumentParser(description="Generate expressions with a trained model")
    add = parser.add_argument

    # Model location
    add("--model_path", type=str, required=True,
        help="Path to model (local or HuggingFace Hub)")
    add("--base_model", type=str, default=None,
        help="Base model for PEFT (if model_path is adapter)")

    # Prompt building arguments
    add("--num_vars", type=int, default=3,
        help="Number of variables (e.g., 3 for x_1, x_2, x_3)")
    add("--operators", type=str, default="+,-,*,/,sin,cos",
        help="Comma-separated operators (e.g., '+,-,*,/,sin,cos,log,sqrt,exp')")
    add("--constants", type=str, default="C",
        help="Constant symbol (default: C)")
    add("--format", type=str, default="infix", choices=["infix", "prefix"],
        help="Expression format (infix or prefix)")

    # Custom prompt
    add("--custom_prompt", type=str, default=None,
        help="Use a custom prompt instead of building one")

    # Generation parameters
    add("--num_generations", type=int, default=5,
        help="Number of expressions to generate")
    add("--max_new_tokens", type=int, default=64,
        help="Maximum new tokens to generate")
    add("--temperature", type=float, default=0.7,
        help="Sampling temperature (higher = more diverse)")
    add("--top_p", type=float, default=0.9,
        help="Top-p sampling parameter")
    add("--top_k", type=int, default=50,
        help="Top-k sampling parameter")

    # Behavior
    add("--validate", action="store_true",
        help="Validate generated expressions")
    add("--interactive", action="store_true",
        help="Run in interactive mode")
    add("--device", type=str, default="auto",
        help="Device to use (auto, cuda, cpu)")
    add("--seed", type=int, default=None,
        help="Random seed for reproducibility")

    return parser.parse_args()
def build_prompt(num_vars: int, operators: list, constants: str = "C",
                 format_type: str = "infix") -> str:
    """Build a prompt for expression generation.

    Args:
        num_vars: Number of variables; produces x_1 .. x_{num_vars}.
        operators: Operator symbols listed in the prompt.
        constants: Constant symbol shown in the prompt (default "C").
        format_type: "infix" or "prefix"; selects the expression label.

    Returns:
        The prompt text ending with the <|startofex|> marker.
    """
    variables = ", ".join(f"x_{i}" for i in range(1, num_vars + 1))
    ops = ", ".join(operators)
    # The only difference between the two formats is the final label.
    label = "Expression" if format_type == "infix" else "Prefix Expression"
    return (f"Variables: {variables}\n"
            f"Operators: {ops}\n"
            f"Constants: {constants}\n"
            f"{label}: <|startofex|>")
def load_model_and_tokenizer(model_path: str, base_model: str = None, device: str = "auto"):
    """Load model and tokenizer, handling both full models and PEFT adapters.

    Args:
        model_path: Local directory or HuggingFace Hub id of the model/adapter.
        base_model: Base model to load first when model_path is a PEFT adapter.
        device: "auto" picks cuda when available, otherwise cpu.

    Returns:
        (model, tokenizer, device) with the model moved to *device* and in
        eval mode.
    """
    print(f"Loading model from: {model_path}")

    # Resolve "auto" to a concrete device.
    if device == "auto":
        device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    tokenizer = AutoTokenizer.from_pretrained(model_path)
    if tokenizer.pad_token is None:
        # GPT-2-style tokenizers ship without a pad token; reuse EOS.
        tokenizer.pad_token = tokenizer.eos_token

    # A local directory containing adapter_config.json is a PEFT adapter.
    adapter_cfg = os.path.join(model_path, "adapter_config.json")
    is_peft = os.path.isdir(model_path) and os.path.exists(adapter_cfg)

    if is_peft or base_model:
        # Load the base model, then attach and merge the adapter.
        base = base_model or "gpt2"
        print(f"Loading base model: {base}")
        model = AutoModelForCausalLM.from_pretrained(base)
        model.resize_token_embeddings(len(tokenizer))
        print("Loading PEFT adapter...")
        model = PeftModel.from_pretrained(model, model_path)
        model = model.merge_and_unload()  # Merge for faster inference
    else:
        model = AutoModelForCausalLM.from_pretrained(model_path)
        model.resize_token_embeddings(len(tokenizer))

    model = model.to(device)
    model.eval()
    return model, tokenizer, device
def generate_expressions(model, tokenizer, prompt: str, device: str,
                         num_generations: int = 5, max_new_tokens: int = 64,
                         temperature: float = 0.7, top_p: float = 0.9,
                         top_k: int = 50):
    """Generate expressions from a prompt.

    Args:
        model: Causal LM used for sampling.
        tokenizer: Tokenizer matching the model.
        prompt: Text prompt, normally ending with <|startofex|>.
        device: Device string the input tensors are moved to.
        num_generations: Number of sampled sequences to return.
        max_new_tokens: Cap on newly generated tokens per sequence.
        temperature: Sampling temperature (higher = more diverse).
        top_p: Nucleus sampling parameter.
        top_k: Top-k sampling parameter.

    Returns:
        List of decoded strings. Special tokens are kept so the
        <|startofex|>/<|endofex|> markers survive for extract_expression.
    """
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
    inputs = {k: v.to(device) for k, v in inputs.items()}
    # Get special token IDs - prefer <|endofex|> as EOS
    end_token_id = tokenizer.convert_tokens_to_ids("<|endofex|>")
    # Fix: some tokenizers return None (rather than the unk id) for unknown
    # tokens; without this guard eos_token_id=None would silently disable
    # EOS-based stopping below.
    if end_token_id is None or end_token_id == tokenizer.unk_token_id:
        print("Warning: <|endofex|> not in tokenizer, using default eos_token_id")
        end_token_id = tokenizer.eos_token_id
    # Create stopping criteria to stop at natural expression boundaries (backup)
    stop_sequences = ["\nvars:", "\nVariables:", "\nOperators:", "\n\n"]
    stopping_criteria = StoppingCriteriaList([
        ExpressionStoppingCriteria(tokenizer, stop_sequences)
    ])
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            do_sample=True,
            num_return_sequences=num_generations,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=end_token_id,  # Use <|endofex|> as EOS
            stopping_criteria=stopping_criteria,  # Keep as backup
        )
    generated = tokenizer.batch_decode(outputs, skip_special_tokens=False)
    return generated
def extract_expression(output: str) -> str:
    """Extract the expression from generated output.

    Tries, in order: the <|startofex|>...<|endofex|> marker pair, the text
    after <|startofex|> cut at a known boundary, an "expr:"/"Expression:"
    label, and finally the first line of the raw output.
    """
    START = "<|startofex|>"
    END = "<|endofex|>"
    has_start = START in output

    # Best case: both markers present and in the right order.
    if has_start and END in output:
        begin = output.find(START) + len(START)
        finish = output.find(END)
        if begin < finish:
            return output[begin:finish].strip()

    # Fallback: take what follows the start marker, cut at the first
    # known boundary, and keep only the first line (length-capped).
    if has_start:
        begin = output.find(START) + len(START)
        text = output[begin:].strip()
        for boundary in ("\nvars:", "\nVariables:", "\nOperators:", "\n\n", "<|endoftext|>"):
            if boundary in text:
                text = text.split(boundary)[0].strip()
                break
        text = text.split("\n")[0].strip()
        return text[:150]

    # Last resort: look for "expr:" or "Expression:" pattern
    labeled = re.search(r'(?:expr|Expression):\s*(.+?)(?:\n|$)', output, re.IGNORECASE)
    if labeled:
        return labeled.group(1).strip()

    # Give up: first line of the output, length-capped.
    return output.strip().split("\n")[0][:100]
def validate_expression(expr_str: str, is_prefix: bool = False) -> dict:
    """Validate an expression string by parsing it with Expression.

    Returns a dict with keys "valid" (bool), "error" (str or None) and
    "sympy_str" (str or None). Note: "valid" is set as soon as parsing
    succeeds, before sympy rendering is attempted.
    """
    outcome = {
        "valid": False,
        "error": None,
        "sympy_str": None
    }
    if not expr_str:
        outcome["error"] = "Empty expression"
        return outcome
    try:
        parsed = Expression(expr_str, is_prefix=is_prefix)
        outcome["valid"] = True
        outcome["sympy_str"] = parsed.sympy_str()
    except Exception as exc:
        # Any parse/render failure is reported rather than raised.
        outcome["error"] = str(exc)
    return outcome
def print_generation_result(idx: int, expr_str: str, validation: dict = None):
    """Print one generated expression, optionally with its validation result.

    Args:
        idx: Zero-based index; displayed one-based.
        expr_str: The extracted expression text.
        validation: Optional dict from validate_expression; when falsy,
            only the expression line is printed.
    """
    print(f"\n[{idx + 1}] {expr_str}")
    if not validation:
        return
    if validation["valid"]:
        print(f" Status: VALID")
        # Only show the sympy rendering when it differs from the raw text.
        if validation["sympy_str"] != expr_str:
            print(f" Sympy: {validation['sympy_str']}")
    else:
        print(f" Status: INVALID - {validation['error']}")
def _generate_and_report(model, tokenizer, prompt, device, args, settings,
                         is_prefix, do_validate):
    """Generate expressions for *prompt* and print each result.

    Shared by the /gen command and the free-text prompt path (previously
    duplicated). When do_validate is True every expression is validated
    and a "Valid: x/y" summary is printed; otherwise results are printed
    without validation and without a summary.
    """
    print(f"\nGenerating {settings['num_generations']} expressions...")
    print("-"*40)
    outputs = generate_expressions(
        model, tokenizer, prompt, device,
        num_generations=settings["num_generations"],
        temperature=settings["temperature"],
        top_p=args.top_p,
        top_k=args.top_k,
        max_new_tokens=args.max_new_tokens
    )
    valid_count = 0
    for i, output in enumerate(outputs):
        expr_str = extract_expression(output)
        validation = validate_expression(expr_str, is_prefix) if do_validate else None
        print_generation_result(i, expr_str, validation)
        if validation and validation["valid"]:
            valid_count += 1
    if do_validate:
        print("-"*40)
        print(f"Valid: {valid_count}/{len(outputs)}")


def interactive_mode(model, tokenizer, device, args):
    """Run a small REPL for expression generation.

    Commands adjust in-memory settings (variables, operators, format,
    temperature, generation count, custom prompt); any non-command input
    is treated as a prompt and generated from immediately.
    """
    print("\n" + "="*60)
    print("SERIGUELA - Interactive Expression Generator")
    print("="*60)
    print("Commands:")
    print(" /vars N - Set number of variables (e.g., /vars 3)")
    print(" /ops +,-,* - Set operators (e.g., /ops +,-,*,sin)")
    print(" /format X - Set format (infix or prefix)")
    print(" /temp T - Set temperature (e.g., /temp 0.8)")
    print(" /n N - Set number of generations (e.g., /n 10)")
    print(" /prompt - Show current prompt")
    print(" /gen - Generate with current settings")
    print(" /custom TEXT - Use custom prompt")
    print(" /quit - Exit")
    print("="*60)

    # Session settings, seeded from the CLI arguments.
    settings = {
        "num_vars": args.num_vars,
        "operators": args.operators.split(","),
        "format": args.format,
        "temperature": args.temperature,
        "num_generations": args.num_generations,
        "custom_prompt": None
    }
    is_prefix = settings["format"] == "prefix"

    while True:
        try:
            user_input = input("\n> ").strip()
        except (EOFError, KeyboardInterrupt):
            print("\nGoodbye!")
            break
        if not user_input:
            continue

        if user_input.startswith("/"):
            parts = user_input.split(maxsplit=1)
            cmd = parts[0].lower()
            arg = parts[1] if len(parts) > 1 else None

            if cmd == "/quit" or cmd == "/exit":
                print("Goodbye!")
                break
            elif cmd == "/vars" and arg:
                try:
                    settings["num_vars"] = int(arg)
                    print(f"Variables set to {settings['num_vars']}")
                except ValueError:
                    print("Invalid number")
            elif cmd == "/ops" and arg:
                settings["operators"] = [op.strip() for op in arg.split(",")]
                print(f"Operators set to: {settings['operators']}")
            elif cmd == "/format" and arg:
                if arg.lower() in ["infix", "prefix"]:
                    settings["format"] = arg.lower()
                    is_prefix = settings["format"] == "prefix"
                    print(f"Format set to {settings['format']}")
                else:
                    print("Invalid format. Use 'infix' or 'prefix'")
            elif cmd == "/temp" and arg:
                try:
                    settings["temperature"] = float(arg)
                    print(f"Temperature set to {settings['temperature']}")
                except ValueError:
                    print("Invalid temperature")
            elif cmd == "/n" and arg:
                try:
                    settings["num_generations"] = int(arg)
                    print(f"Number of generations set to {settings['num_generations']}")
                except ValueError:
                    print("Invalid number")
            elif cmd == "/prompt":
                prompt = build_prompt(
                    settings["num_vars"],
                    settings["operators"],
                    "C",
                    settings["format"]
                )
                print(f"\nCurrent prompt:\n{prompt}")
            elif cmd == "/custom" and arg:
                settings["custom_prompt"] = arg
                print(f"Custom prompt set")
            elif cmd == "/gen":
                # /gen always validates (matches original behavior).
                if settings["custom_prompt"]:
                    prompt = settings["custom_prompt"]
                else:
                    prompt = build_prompt(
                        settings["num_vars"],
                        settings["operators"],
                        "C",
                        settings["format"]
                    )
                _generate_and_report(model, tokenizer, prompt, device, args,
                                     settings, is_prefix, True)
            else:
                print(f"Unknown command: {cmd}")
        else:
            # Free text: treat as a custom prompt; validation only when
            # --validate was passed (matches original behavior).
            prompt = user_input if "<|startofex|>" in user_input else user_input + " <|startofex|>"
            _generate_and_report(model, tokenizer, prompt, device, args,
                                 settings, is_prefix, args.validate)
def main():
    """CLI entry point: parse args, load the model, then either run the
    interactive REPL or perform a single batch of generations."""
    args = parse_args()

    # Set seed if provided (sampling reproducibility).
    if args.seed is not None:
        torch.manual_seed(args.seed)

    # Load model
    model, tokenizer, device = load_model_and_tokenizer(
        args.model_path, args.base_model, args.device
    )

    # Interactive mode
    if args.interactive:
        interactive_mode(model, tokenizer, device, args)
        return

    # Build or use custom prompt
    if args.custom_prompt:
        prompt = args.custom_prompt
    else:
        operators = [op.strip() for op in args.operators.split(",")]
        prompt = build_prompt(
            args.num_vars,
            operators,
            args.constants,
            args.format
        )

    print("\n" + "="*60)
    print("SERIGUELA - Expression Generator")
    print("="*60)
    print(f"Model: {args.model_path}")
    print(f"Format: {args.format}")
    print(f"Temperature: {args.temperature}")
    print("-"*60)
    print("Prompt:")
    print(prompt)
    print("-"*60)

    # Generate
    is_prefix = args.format == "prefix"
    outputs = generate_expressions(
        model, tokenizer, prompt, device,
        num_generations=args.num_generations,
        max_new_tokens=args.max_new_tokens,
        temperature=args.temperature,
        top_p=args.top_p,
        top_k=args.top_k
    )

    print(f"\nGenerated {len(outputs)} expressions:")
    print("-"*60)
    valid_count = 0
    for i, output in enumerate(outputs):
        expr_str = extract_expression(output)
        validation = validate_expression(expr_str, is_prefix) if args.validate else None
        print_generation_result(i, expr_str, validation)
        if validation and validation["valid"]:
            valid_count += 1

    if args.validate:
        print("-"*60)
        total = len(outputs)
        # Fix: guard against ZeroDivisionError when --num_generations is 0.
        pct = (valid_count / total * 100) if total else 0.0
        print(f"\nSummary: {valid_count}/{total} valid expressions ({pct:.1f}%)")
        print("="*60)


if __name__ == "__main__":
    main()