# lora-qa / app.py — DistilBERT + LoRA SQuAD QA Gradio demo
import math
import html
import torch
import torch.nn as nn
import torch.nn.functional as F
import gradio as gr
from transformers import AutoTokenizer, AutoModelForQuestionAnswering, DistilBertModel
class LORALinear(nn.Module):
    """Linear layer with a LoRA (low-rank) adapter over frozen base weights.

    The base ``nn.Linear`` weight/bias tensors are tied (shared, not copied)
    and frozen; only the low-rank factors ``A`` (r x in_features) and ``B``
    (out_features x r) are trainable.  Forward computes::

        base(x) + (alpha / r) * (dropout(x) @ A.T) @ B.T

    ``B`` is zero-initialized, so the adapter is an exact no-op until trained.

    Args:
        base_linear: the frozen linear layer to wrap (weights are shared).
        r: LoRA rank; must be positive.
        alpha: LoRA scaling numerator (effective scale is ``alpha / r``).
        dropout: dropout probability applied to the adapter input only.

    Raises:
        ValueError: if ``r`` is not positive (the original code would have
            failed later with a bare ZeroDivisionError for r == 0).
    """

    def __init__(self, base_linear: nn.Linear, r: int = 8, alpha: float = 16, dropout: float = 0.1):
        super().__init__()
        if r <= 0:
            raise ValueError(f"LoRA rank must be positive, got r={r}")
        self.in_features = base_linear.in_features
        self.out_features = base_linear.out_features
        self.r = r
        self.alpha = alpha
        self.scaling = alpha / r
        # Tie to base weights and freeze them: only A/B receive gradients.
        self.weight = base_linear.weight
        self.bias = base_linear.bias
        for p in (self.weight, self.bias):
            if p is not None:
                p.requires_grad = False
        dev = self.weight.device
        # A gets kaiming init, B starts at zero -> delta is zero at init.
        self.A = nn.Parameter(torch.empty(r, self.in_features, device=dev))
        self.B = nn.Parameter(torch.zeros(self.out_features, r, device=dev))
        nn.init.kaiming_uniform_(self.A, a=math.sqrt(5))
        self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()

    def forward(self, x):
        """Return the frozen base projection plus the scaled low-rank update."""
        base = F.linear(x, self.weight, self.bias)
        delta = F.linear(self.dropout(x), self.A)  # (..., r)
        delta = F.linear(delta, self.B)            # (..., out_features)
        return base + self.scaling * delta
# Class of DistilBERT's attention submodule, used by add_lora_to_distilbert
# to locate injection points.  Imported directly instead of downloading and
# instantiating a full pretrained model at import time just to call type()
# on one submodule (the original pulled distilbert-base-uncased for this).
# isinstance() against the base class also matches attention subclasses that
# newer transformers releases may substitute at runtime.
from transformers.models.distilbert.modeling_distilbert import (
    MultiHeadSelfAttention as DistilBertSelfAttention,
)
def add_lora_to_distilbert(model, r=8, alpha=16, lora_dropout=0.1, targets=("q_lin", "v_lin")):
    """Swap the target projections of every DistilBERT attention module for
    LoRA-wrapped equivalents, mutating ``model`` in place.

    Args:
        model: the model to patch (scanned via ``model.modules()``).
        r, alpha, lora_dropout: forwarded to ``LORALinear``.
        targets: attribute names of the ``nn.Linear`` layers to wrap.

    Prints the number of layers swapped; returns None.
    """
    replaced = 0
    attn_modules = [mod for mod in model.modules() if isinstance(mod, DistilBertSelfAttention)]
    for attn in attn_modules:
        for attr_name in targets:
            candidate = getattr(attn, attr_name, None)
            if not isinstance(candidate, nn.Linear):
                continue
            wrapped = LORALinear(candidate, r=r, alpha=alpha, dropout=lora_dropout)
            setattr(attn, attr_name, wrapped)
            replaced += 1
    print(f"LoRA injected into {replaced} Linear(s): {targets}")
# ---- Model setup (runs once at import time; requires network + checkpoint file) ----
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
qa = AutoModelForQuestionAnswering.from_pretrained("distilbert-base-uncased").to(device)
# Inject LoRA adapters into q_lin/v_lin before loading the fine-tuned weights.
add_lora_to_distilbert(qa, r=8, alpha=16, lora_dropout=0.1)
state = torch.load("distilbert_squad_lora.pt", map_location=device)
# strict=False tolerates missing/unexpected keys — presumably the checkpoint
# stores only the LoRA/QA-head tensors, not the full backbone (confirm).
_ = qa.load_state_dict(state, strict=False)
qa.eval()
def softmax(x: torch.Tensor) -> torch.Tensor:
    """Softmax over *all* elements of ``x``, preserving its shape.

    Behaviorally identical to the original hand-rolled, max-shifted
    exp/sum, but delegates to ``torch.softmax`` on the flattened tensor,
    which is numerically stable and implemented in C.
    """
    return torch.softmax(x.reshape(-1), dim=0).view_as(x)
def hue_for_prob(p: float) -> int:
    """Map a probability in [0, 1] to an HSL hue in [20, 120] (orange -> green)."""
    # Same clamp order as before: out-of-range and non-finite inputs are
    # pinned into [0, 1] before scaling.
    clamped = max(0.0, min(1.0, float(p)))
    return 20 + int(100 * clamped)
def topk_spans(question: str, context: str, top_k: int = 3, max_answer_len: int = 30):
    """Run extractive QA over ``context`` and return the top-k answer spans.

    Returns a 3-tuple:
      rows      -- list of [text, prob, start_tok, end_tok, score] table rows,
      html_str  -- the context as escaped HTML with chosen spans highlighted
                   (hue encodes confidence, tooltip carries details),
      top1      -- text of the best span, or "" when no span was found.
    """
    # Tokenize question+context as one pair; only the context half may be
    # truncated, and offset mapping ties tokens back to character positions.
    enc = tokenizer(
        question,
        context,
        return_tensors="pt",
        truncation="only_second",
        max_length=384,
        return_offsets_mapping=True,
    ).to(device)
    with torch.no_grad():
        out = qa(
            input_ids=enc["input_ids"],
            attention_mask=enc["attention_mask"],
        )
    start_logits = out.start_logits[0].cpu()
    end_logits = out.end_logits[0].cpu()
    # sequence ids come from the encoding
    seq_ids = enc.sequence_ids(0)  # list with None/0/1
    offsets = enc["offset_mapping"][0].cpu().tolist()
    # Only consider context tokens (seq id == 1)
    ctx_idx = [i for i, s in enumerate(seq_ids) if s == 1]
    if not ctx_idx:
        # Context fully truncated away (or empty): nothing to answer from.
        return [], html.escape(context), ""
    # Build candidate spans (score = start_logit + end_logit)
    cands = []
    for i_local, i in enumerate(ctx_idx):
        # End token limited to at most max_answer_len tokens after the start.
        for j in ctx_idx[i_local : i_local + int(max_answer_len)]:
            if j < i:
                continue
            s_log, e_log = start_logits[i], end_logits[j]
            cs, ce = offsets[i], offsets[j]
            if cs[0] is None or ce[1] is None:
                continue
            score = float(s_log + e_log)
            cands.append({
                "score": score,
                "start_tok": i,
                "end_tok": j,
                "start_char": cs[0],
                "end_char": ce[1],
            })
    if not cands:
        return [], html.escape(context), ""
    # Sort by score desc
    cands.sort(key=lambda x: x["score"], reverse=True)
    # Convert scores of a small pool to probabilities (softmax over the pool
    # only, so probs are relative to the shortlisted candidates).
    pool = cands[: max(50, int(top_k) * 10)]
    scores = torch.tensor([c["score"] for c in pool])
    probs = softmax(scores).tolist()
    for c, p in zip(pool, probs):
        c["prob"] = p
    # Greedily pick non-overlapping top_k
    chosen, used = [], []
    for c in pool:
        s, e = c["start_char"], c["end_char"]
        # Skip candidates whose character range overlaps one already chosen.
        if any(not (e <= us or s >= ue) for (us, ue) in used):
            continue
        used.append((s, e))
        c["text"] = context[s:e]
        chosen.append(c)
        if len(chosen) >= int(top_k):
            break
    # Build highlighted HTML (confidence mapped to hue); spans are emitted in
    # document order so the surrounding context stitches together correctly.
    chosen_sorted = sorted(chosen, key=lambda x: x["start_char"])
    html_out, prev = [], 0
    for c in chosen_sorted:
        s, e, p = c["start_char"], c["end_char"], c["prob"]
        hue = hue_for_prob(p)  # 20..120
        tip = f"p={p:.3f}; tokens=({c['start_tok']},{c['end_tok']})"
        html_out.append(html.escape(context[prev:s]))
        span = html.escape(context[s:e])
        html_out.append(
            f"<span title='{tip}' "
            f"style='background-color: hsl({hue}, 95%, 70%); "
            f"border-radius: 3px; padding: 2px 2px;'>{span}</span>"
        )
        prev = e
    html_out.append(html.escape(context[prev:]))
    # Table rows stay in score order (chosen), not document order.
    rows = [
        [c["text"], round(c["prob"], 4), c["start_tok"], c["end_tok"], round(c["score"], 3)]
        for c in chosen
    ]
    top1 = chosen[0]["text"] if chosen else ""
    return rows, "".join(html_out), top1
# (question, context) pairs pre-filled by the UI's "Try these:" examples widget.
EXAMPLES = [
    [
        "Who set the national goal of landing a man on the Moon?",
        "The Apollo program was the third United States human spaceflight program carried out by NASA, which accomplished landing the first humans on the Moon from 1969 to 1972. Apollo ran from 1961 to 1972 and was dedicated to President John F. Kennedy's national goal of 'landing a man on the Moon and returning him safely to the Earth' by the end of the 1960s."
    ],
    [
        "What is Einstein best known for?",
        "Albert Einstein developed the theory of relativity, one of the two pillars of modern physics. His work is also known for its influence on the philosophy of science. Einstein is best known for his mass–energy equivalence formula E=mc², which has been called 'the world's most famous equation'."
    ],
    [
        "How many countries does the Amazon rainforest spread across?",
        "The Amazon rainforest, also known in English as Amazonia, is a moist broadleaf tropical rainforest in the Amazon biome that covers most of the Amazon basin of South America. This basin encompasses 7,000,000 km², of which 5,500,000 km² are covered by the rainforest. This region includes territory belonging to nine nations."
    ],
]
# ---- Gradio UI: inputs on the left, predictions on the right ----
with gr.Blocks(theme=gr.themes.Soft(), css="mark{background:#ffd54f;}") as demo:
    gr.Markdown("## 🟡 DistilBERT + LoRA SQuAD QA\nEnter a context and a question. The model returns **Top-K spans** with **color-coded confidence** and **tooltips** (hover).")
    with gr.Row():
        with gr.Column(scale=1):
            # Input column: question, context, and decoding knobs.
            q = gr.Textbox(label="Question")
            ctx = gr.Textbox(label="Context", lines=10)
            with gr.Row():
                topk = gr.Slider(1, 5, value=3, step=1, label="Top-K Spans")
                maxlen = gr.Slider(5, 60, value=30, step=1, label="Max Answer Length (tokens)")
            btn = gr.Button("Get Answer", variant="primary")
            gr.Examples(examples=EXAMPLES, inputs=[q, ctx], label="Try these:")
        with gr.Column(scale=1):
            # Output column: best answer, candidate table, highlighted context.
            top1 = gr.Textbox(label="Top-1 Predicted Answer")
            table = gr.Dataframe(
                headers=["answer", "prob", "start_tok", "end_tok", "score"],
                datatype=["str", "number", "number", "number", "number"],
                label="Top-K candidates"
            )
            highlighted = gr.HTML(label="Highlighted Context (hover for details)")

    def ui_wrapper(question, context, top_k, max_answer_len):
        """Reorder topk_spans' (rows, html, top1) tuple to match the output widgets."""
        rows, highlighted_html, top_answer = topk_spans(
            question, context, int(top_k), int(max_answer_len)
        )
        return top_answer, rows, highlighted_html

    btn.click(ui_wrapper, inputs=[q, ctx, topk, maxlen], outputs=[top1, table, highlighted])
# Start the Gradio server when this file is executed directly.
if __name__ == "__main__":
    demo.launch()