|
|
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextStreamer |
|
|
import torch |
|
|
import warnings |
|
|
import re |
|
|
import logging |
|
|
import argparse |
|
|
from znum import Znum, Topsis, Promethee, Beast |
|
|
from helpers.utils import SYSTEM_PROMPT, DEFAULT_QUERY |
|
|
|
|
|
# Configure root logging once for the whole script; all progress/diagnostic
# output below goes through this module-level logger.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
# Trapezoidal fuzzy numbers [a1, a2, a3, a4] used as the A component of a
# Z-number (see parse_znum_pair), keyed by the integer rating the model emits.
# NOTE(review): presumably a 1-5 linguistic scale (very low .. very high) —
# confirm against the fine-tuning data / SYSTEM_PROMPT.
A_MAP = {
    1: [2, 3, 3, 4],
    2: [4, 5, 5, 6],
    3: [6, 7, 7, 8],
    4: [8, 9, 9, 10],
    5: [10, 11, 11, 12],
}
|
|
|
|
|
# Trapezoidal fuzzy numbers for the B (reliability/confidence) component of a
# Z-number, keyed by the integer rating the model emits (see parse_znum_pair).
# Values live on [0, 1], consistent with a confidence measure.
B_MAP = {
    1: [0.2, 0.3, 0.3, 0.4],
    2: [0.3, 0.4, 0.4, 0.5],
    3: [0.4, 0.5, 0.5, 0.6],
    4: [0.5, 0.6, 0.6, 0.7],
    5: [0.6, 0.7, 0.7, 0.8],
}
|
|
|
|
|
|
|
|
def parse_znum_pair(pair_str: str) -> Znum | None:
    """Convert an 'N:M' rating string into a Znum via A_MAP/B_MAP.

    The A-side rating is taken by absolute value; any malformed or
    out-of-range pair is logged and mapped to ``None``.
    """
    try:
        pieces = pair_str.strip().split(':')
        # Anything other than exactly two colon-separated fields is rejected
        # silently (no log) — matches the caller's tolerance for stray cells.
        if len(pieces) != 2:
            return None
        a_key = abs(int(pieces[0]))
        b_key = int(pieces[1])
        if a_key in A_MAP and b_key in B_MAP:
            return Znum(A_MAP[a_key], B_MAP[b_key])
        logger.warning(f"Invalid Z-number pair: {pair_str}")
        return None
    except (ValueError, KeyError) as err:
        logger.warning(f"Failed to parse Z-number pair '{pair_str}': {err}")
        return None
|
|
|
|
|
|
|
|
def parse_markdown_table(text: str) -> dict:
    """Parse the markdown table in the model output into a structured dict.

    Expected layout: header row (criteria names), a criteria-type row,
    one row per alternative, and a final weights row.  Returns keys
    'criteria', 'types', 'alternatives' (name -> list of cell strings)
    and 'weights'; returns {} when the table is too short.
    """
    separator = re.compile(r'^\|[-:\s|]+\|$')
    rows = []
    for raw in text.strip().split('\n'):
        candidate = raw.strip()
        # Keep only real table rows; drop the |---|---| separator line.
        if candidate and '|' in candidate and not separator.match(candidate):
            rows.append(candidate)

    if len(rows) < 4:
        logger.warning("Table has fewer than expected rows")
        return {}

    def cells_of(row: str) -> list:
        # Split on pipes and discard empty cells (leading/trailing '|').
        return [cell for cell in (part.strip() for part in row.split('|')) if cell]

    header_cells = cells_of(rows[0])
    type_cells = cells_of(rows[1])
    weight_cells = cells_of(rows[-1])

    # Everything between the type row and the weights row is an alternative;
    # first cell is its name, the rest are per-criterion values.
    alternatives = {}
    for body_row in rows[2:-1]:
        cells = cells_of(body_row)
        if cells:
            alternatives[cells[0]] = cells[1:]

    return {
        'criteria': header_cells[1:] if header_cells else [],
        'types': type_cells[1:] if len(type_cells) > 1 else [],
        'alternatives': alternatives,
        'weights': weight_cells[1:] if len(weight_cells) > 1 else [],
    }
|
|
|
|
|
# Silence the noisy "Chat template ..." UserWarning emitted by transformers
# when applying the chat template below.
warnings.filterwarnings(
    "ignore",
    message="Chat template .*",
    category=UserWarning,
)

# Command-line interface: choice of MCDM method plus the free-text decision
# query sent to the model (defaults come from helpers.utils).
parser = argparse.ArgumentParser(description='Z-number decision matrix extraction and MCDM')
parser.add_argument('--method', type=str, choices=['topsis', 'promethee'], default='topsis',
                    help='MCDM method to use (default: topsis)')
parser.add_argument('--query', '-q', type=str, default=DEFAULT_QUERY,
                    help='Decision query to process')
args = parser.parse_args()
|
|
|
|
|
# 8-bit weight quantization (bitsandbytes) to shrink the model's GPU footprint.
qconfig = BitsAndBytesConfig(
    load_in_8bit=True,
)

model_name = "nuriyev/Qwen3-4B-znum-decision-matrix"
print(f"Loading model: {model_name}...")
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    # dtype applies to the modules bitsandbytes leaves un-quantized.
    dtype=torch.bfloat16,
    # BUG FIX: qconfig was built but never passed, so the model silently
    # loaded un-quantized; wiring it in makes load_in_8bit take effect.
    quantization_config=qconfig,
)
print("Model loaded successfully!\n")
|
|
|
|
|
# Chat messages: the fine-tuned system prompt plus the user's decision query.
messages = [
    {"role": "system", "content": SYSTEM_PROMPT},
    {"role": "user", "content": args.query},
]

# Render the chat template to a plain prompt string; Qwen3 "thinking" mode is
# disabled so the model emits the decision table directly.
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True, enable_thinking=False)

inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

# Stream decoded tokens to stdout as they are generated.
streamer = TextStreamer(tokenizer, skip_special_tokens=True)

# BUG FIX: the previous `temperature=1` was a no-op — sampling is off by
# default (do_sample=False), so transformers ignored it and emitted a
# UserWarning.  Dropped; decoding stays greedy, capped at 8192 total tokens.
output_ids = model.generate(**inputs, max_length=8192, streamer=streamer)

# Strip the prompt tokens so only the newly generated text is decoded.
generated_ids = output_ids[0][inputs['input_ids'].shape[1]:]
generated_text = tokenizer.decode(generated_ids, skip_special_tokens=True)
|
|
|
|
|
|
|
|
logger.info("Parsing decision matrix from model output...")
matrix = parse_markdown_table(generated_text)

if matrix:
    logger.info(f"Criteria: {matrix['criteria']}")
    logger.info(f"Types: {matrix['types']}")

    # Criteria weights: each 'N:M' cell becomes a Znum (None if unparseable;
    # a None here will surface as an error inside the solver).
    znum_weights = [parse_znum_pair(w) for w in matrix['weights']]
    logger.info("Weights as Znum:")
    # FIX: dropped enumerate() — the index was never used (also below).
    for name, zw in zip(matrix['criteria'], znum_weights):
        logger.info(f"  {name}: {zw}")

    # One Znum row per alternative, in criteria order.
    znum_alternatives = {}
    for alt_name, values in matrix['alternatives'].items():
        znum_values = [parse_znum_pair(v) for v in values]
        znum_alternatives[alt_name] = znum_values
        logger.info(f"Alternative '{alt_name}' as Znum:")
        for crit, zv in zip(matrix['criteria'], znum_values):
            logger.info(f"  {crit}: {zv}")

    matrix['znum_weights'] = znum_weights
    matrix['znum_alternatives'] = znum_alternatives

    # NOTE(review): anything other than 'benefit' (case-insensitive) silently
    # becomes COST — a typo in the model's type row would flip the
    # optimization direction; consider validating the type strings.
    criteria_types = [
        Beast.CriteriaType.BENEFIT if t.lower() == 'benefit' else Beast.CriteriaType.COST
        for t in matrix['types']
    ]

    # Row layout passed to the solver: weights row, one row per alternative,
    # then the criteria-type row.
    alt_names = list(znum_alternatives.keys())
    alt_rows = [znum_alternatives[name] for name in alt_names]
    table = [znum_weights] + alt_rows + [criteria_types]

    logger.info(f"\nApplying {args.method.upper()} method...")

    # argparse restricts --method to exactly these two choices.
    solver = Topsis(table) if args.method == 'topsis' else Promethee(table)
    solver.solve()

    logger.info(f"\n{'='*50}")
    logger.info(f"RESULTS ({args.method.upper()})")
    logger.info(f"{'='*50}")
    logger.info(f"Best alternative: {alt_names[solver.index_of_best_alternative]}")
    logger.info(f"Worst alternative: {alt_names[solver.index_of_worst_alternative]}")
    logger.info(f"\nRanking (best to worst):")
    for rank, idx in enumerate(solver.ordered_indices, 1):
        logger.info(f"  {rank}. {alt_names[idx]}")
else:
    logger.error("Failed to parse decision matrix")