
# Calibrated Latent Uncertainty Estimation

A multimodal Natural Language Inference (NLI) model that estimates the probability of a claim being true given supporting evidence in text, audio, or video form.

Built on Qwen2.5-Omni-3B with LoRA fine-tuning, the model outputs a continuous probability score (0–1) rather than discrete entailment/neutral/contradiction labels.


## How It Works

The model is prompted with a claim and a piece of evidence (a text sentence, an audio file, or a video file). It outputs a single special probability token `<CON_{idx}>` from a vocabulary of 100 such tokens whose associated values are uniformly spaced between 0.0 and 1.0. The final score is computed as the expected value of these token values under the softmax distribution over the probability-token vocabulary.

This approach allows the model to express graded confidence rather than hard categorical predictions.
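As a minimal sketch of the expected-value computation, assuming token `<CON_i>` maps to the value `i / 99` so that the 100 values span [0.0, 1.0] uniformly (the exact index-to-value mapping is an assumption here; the real one lives in `utils.extract_answer`):

```python
import numpy as np

def expected_probability(logits_at_answer: np.ndarray, con_token_ids: list[int]) -> float:
    """Expected value over the probability tokens' softmax distribution.

    Assumes <CON_i> encodes the value i / (N - 1) for N = 100 tokens
    (hypothetical indexing for illustration).
    """
    # Restrict attention to the probability-token sub-vocabulary
    con_logits = logits_at_answer[con_token_ids]
    con_logits = con_logits - con_logits.max()  # numerical stability
    probs = np.exp(con_logits) / np.exp(con_logits).sum()
    values = np.arange(len(con_token_ids)) / (len(con_token_ids) - 1)
    return float((probs * values).sum())
```

Because the score is an expectation rather than an argmax, a bimodal distribution over the tokens still yields a single calibrated value.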


## Requirements

```bash
pip install torch transformers peft qwen-omni-utils
```

Ensure you have the model weights at `./model` and the LoRA adapter at `./lora`, or update the paths accordingly.


## Quick Start

### Text-only Inference

Given a claim and an evidence sentence, estimate the probability that the claim is entailed:

```python
import torch
from transformers import Qwen2_5OmniProcessor, Qwen2_5OmniThinkerForConditionalGeneration
from peft import PeftModel

from utils import PromptBuilder, extract_answer

MODEL_PATH = "./model"
LORA_PATH = "./lora"

processor = Qwen2_5OmniProcessor.from_pretrained(MODEL_PATH)

# Load the base model in bfloat16 and attach the LoRA adapter
model = Qwen2_5OmniThinkerForConditionalGeneration.from_pretrained(
    MODEL_PATH,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
model = PeftModel.from_pretrained(model, LORA_PATH)
model.eval()

prompt = PromptBuilder.build_messages_text(
    claim="The man is speaking loudly.",
    sentence="The man is angry.",
)
text = processor.apply_chat_template(prompt, add_generation_prompt=True, tokenize=False)

inputs = processor(text=text, return_tensors="pt", padding=True)
inputs = inputs.to(model.device).to(model.dtype)
with torch.no_grad():
    outputs = model(**inputs)

answer = extract_answer(outputs.logits, processor)
print(answer)  # e.g. <answer>0.73</answer>
```

### Audio Inference

Given an audio file and a claim, estimate whether the claim is true:

```python
import torch
from transformers import Qwen2_5OmniProcessor, Qwen2_5OmniThinkerForConditionalGeneration
from peft import PeftModel
from qwen_omni_utils import process_mm_info

from utils import PromptBuilder, extract_answer

MODEL_PATH = "./model"
LORA_PATH = "./lora"

processor = Qwen2_5OmniProcessor.from_pretrained(MODEL_PATH)

# Load the base model in bfloat16 and attach the LoRA adapter
model = Qwen2_5OmniThinkerForConditionalGeneration.from_pretrained(
    MODEL_PATH,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
model = PeftModel.from_pretrained(model, LORA_PATH)
model.eval()

prompt = PromptBuilder.build_messages_audio(
    claim="The man is speaking loudly.",
    audio_path="data/audio/audio_00000.wav",
)

USE_AUDIO_IN_VIDEO = True
text = processor.apply_chat_template(prompt, add_generation_prompt=True, tokenize=False)
audios, images, videos = process_mm_info(prompt, use_audio_in_video=USE_AUDIO_IN_VIDEO)

inputs = processor(
    text=text, audio=audios, images=images, videos=videos,
    return_tensors="pt", padding=True, use_audio_in_video=USE_AUDIO_IN_VIDEO,
)
inputs = inputs.to(model.device).to(model.dtype)
with torch.no_grad():
    outputs = model(**inputs)

answer = extract_answer(outputs.logits, processor)
print(answer)  # e.g. <answer>0.81</answer>
```

### Video Inference

Given a video file and a claim, estimate whether the claim is true:

```python
import torch
from transformers import Qwen2_5OmniProcessor, Qwen2_5OmniThinkerForConditionalGeneration
from peft import PeftModel
from qwen_omni_utils import process_mm_info

from utils import PromptBuilder, extract_answer

MODEL_PATH = "./model"
LORA_PATH = "./lora"

processor = Qwen2_5OmniProcessor.from_pretrained(MODEL_PATH)

# Load the base model in bfloat16 and attach the LoRA adapter
model = Qwen2_5OmniThinkerForConditionalGeneration.from_pretrained(
    MODEL_PATH,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
model = PeftModel.from_pretrained(model, LORA_PATH)
model.eval()

prompt = PromptBuilder.build_messages_vl(
    claim="The man is speaking loudly.",
    video_path="data/video/video_00000.mp4",
)

# Use the video's audio track as additional evidence
USE_AUDIO_IN_VIDEO = True
text = processor.apply_chat_template(prompt, add_generation_prompt=True, tokenize=False)
audios, images, videos = process_mm_info(prompt, use_audio_in_video=USE_AUDIO_IN_VIDEO)

inputs = processor(
    text=text, audio=audios, images=images, videos=videos,
    return_tensors="pt", padding=True, use_audio_in_video=USE_AUDIO_IN_VIDEO,
)
inputs = inputs.to(model.device).to(model.dtype)
with torch.no_grad():
    outputs = model(**inputs)

answer = extract_answer(outputs.logits, processor)
print(answer)  # e.g. <answer>0.65</answer>
```

## Output Format

`extract_answer` returns a string of the form:

```
<answer>0.734</answer>
```

The value is a float in [0, 1] representing the estimated probability that the claim is true given the evidence.

Interpretation:

- Values near 1.0 → the claim is very likely true (entailment)
- Values near 0.5 → uncertain (neutral)
- Values near 0.0 → the claim is very likely false (contradiction)
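If discrete entailment/neutral/contradiction labels are needed downstream, the continuous score can be bucketed. The cut-offs below (1/3 and 2/3) are illustrative assumptions, not values prescribed by the model; they should be tuned on a validation set.

```python
def score_to_label(p: float, lo: float = 1 / 3, hi: float = 2 / 3) -> str:
    """Map a calibrated probability to a coarse NLI label.

    The thresholds are illustrative defaults, not part of the model.
    """
    if p >= hi:
        return "entailment"
    if p <= lo:
        return "contradiction"
    return "neutral"
```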

## Project Structure

```
UMUI-model/
├── model/                  # Qwen2.5-Omni-3B base model weights
├── lora/                   # LoRA adapter weights (fine-tuned)
│   ├── adapter_config.json
│   └── adapter_model.safetensors
├── utils.py                # PromptBuilder class and extract_answer function
├── prompt.py               # System and user prompt templates
└── README.md
```

## Model Details

| Property | Value |
| --- | --- |
| Base model | Qwen/Qwen2.5-Omni-3B |
| Fine-tuning method | LoRA (PEFT) |
| LoRA rank (r) | 16 |
| LoRA alpha | 32 |
| LoRA dropout | 0.1 |
| Target modules | All attention q/k/v/o projections + `lm_head` |
| Precision | bfloat16 |
| PEFT version | 0.17.1 |

The LoRA adapter targets the self-attention projection layers (`q_proj`, `k_proj`, `v_proj`, `o_proj`) across all transformer layers, as well as the language model head (`lm_head`), which is necessary to support the custom probability-token vocabulary.

## API Reference

### `PromptBuilder`

| Method | Arguments | Description |
| --- | --- | --- |
| `build_messages_text` | `claim: str`, `sentence: str` | Build prompt for text-only NLI |
| `build_messages_audio` | `claim: str`, `audio_path: str` | Build prompt for audio NLI |
| `build_messages_vl` | `claim: str`, `video_path: str` | Build prompt for video NLI (audio + visual) |

All methods return a list of message dicts compatible with `processor.apply_chat_template`.
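For orientation, the returned message lists presumably follow the standard Qwen2.5-Omni chat format. The structure below is an assumption for illustration; the actual system prompt and wording live in `prompt.py`:

```python
# Hypothetical shape of PromptBuilder.build_messages_audio(...) output,
# mirroring the Qwen2.5-Omni multimodal message convention.
messages = [
    {"role": "system", "content": [
        {"type": "text", "text": "<system prompt from prompt.py>"},
    ]},
    {"role": "user", "content": [
        {"type": "audio", "audio": "data/audio/audio_00000.wav"},  # evidence
        {"type": "text", "text": "Claim: The man is speaking loudly."},  # hypothesis
    ]},
]
```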

### `extract_answer(logits, processor) -> str`

Takes the raw model logits (shape `[batch, seq_len, vocab_size]`) and the processor, and returns a probability score wrapped in `<answer>...</answer>`.