# devflow/app.py — Gradio UI for loading Sanskrit D3PM checkpoints and generating Devanagari output.
import copy
import json
import os
from datetime import datetime
import gradio as gr
import torch
from config import CONFIG
from inference import load_model, run_inference, _build_tokenizers, _resolve_device
# Directory where every UI generation is appended as per-experiment, per-day JSON logs.
RESULTS_DIR = "generated_results"
os.makedirs(RESULTS_DIR, exist_ok=True)
def discover_checkpoints():
    """Scan the known result directories for trained checkpoints.

    Looks for ``<root>/<experiment>/best_model.pt`` under each known results
    root and returns a list of metadata dicts (label, path, experiment, root),
    ordered by root then experiment name.
    """
    checkpoints = []
    for root in ("ablation_results", "results7", "results"):
        if not os.path.isdir(root):
            continue
        for experiment in sorted(os.listdir(root)):
            candidate = os.path.join(root, experiment, "best_model.pt")
            if os.path.exists(candidate):
                checkpoints.append(
                    {
                        "label": f"{experiment} [{root}]",
                        "path": candidate,
                        "experiment": experiment,
                        "root": root,
                    }
                )
    return checkpoints
def default_checkpoint_label():
    """Return the label preselected in the checkpoint dropdown.

    Prefers the T4 ablation checkpoint when present; otherwise falls back to
    the first discovered checkpoint, or ``None`` when nothing was found.
    """
    checkpoints = discover_checkpoints()
    if not checkpoints:
        return None
    # Build the suffix with os.path.join so it matches the os.path.join-built
    # paths from discover_checkpoints on every platform (the previous literal
    # "ablation_results/T4/best_model.pt" never matched on Windows).
    preferred_suffix = os.path.join("ablation_results", "T4", "best_model.pt")
    for item in checkpoints:
        if item["path"].endswith(preferred_suffix):
            return item["label"]
    return checkpoints[0]["label"]
def checkpoint_map():
    """Return a mapping from dropdown label to checkpoint metadata dict."""
    labels = {}
    for entry in discover_checkpoints():
        labels[entry["label"]] = entry
    return labels
def infer_model_type(experiment_name: str, root: str = "") -> str:
    """Infer the model architecture key from an experiment directory name.

    Falls back to the global CONFIG default when the name has no known prefix.
    """
    # Every run under ablation_results was trained as d3pm_cross_attention.
    if root == "ablation_results":
        return "d3pm_cross_attention"
    known_types = (
        "d3pm_cross_attention",
        "d3pm_encoder_decoder",
        "baseline_cross_attention",
        "baseline_encoder_decoder",
    )
    for model_type in known_types:
        if experiment_name.startswith(model_type):
            return model_type
    return CONFIG["model_type"]
def infer_include_negative(experiment_name: str, root: str = "") -> bool:
    """Infer whether the experiment was trained with negative examples.

    Falls back to the global CONFIG default when the name carries no marker.
    """
    # Ablation runs were all trained without negative examples.
    if root == "ablation_results":
        return False
    for marker, flag in (("_neg_True", True), ("_neg_False", False)):
        if marker in experiment_name:
            return flag
    return CONFIG["data"]["include_negative_examples"]
def build_runtime_cfg(ckpt_path: str):
    """Derive a per-checkpoint runtime config from the checkpoint location.

    Returns ``(cfg, device, experiment_name)``. Ablation runs encode their
    diffusion step count in the directory name (e.g. "T4" means 4 steps).
    """
    experiment_dir = os.path.dirname(ckpt_path)
    experiment = os.path.basename(experiment_dir)
    root = os.path.basename(os.path.dirname(experiment_dir))
    cfg = copy.deepcopy(CONFIG)
    cfg["model_type"] = infer_model_type(experiment, root=root)
    cfg["data"]["include_negative_examples"] = infer_include_negative(experiment, root=root)
    step_suffix = experiment[1:]
    if root == "ablation_results" and experiment.startswith("T") and step_suffix.isdigit():
        # Directory names like "T16" override both training-time and
        # inference-time step counts.
        cfg["model"]["diffusion_steps"] = int(step_suffix)
        cfg["inference"]["num_steps"] = int(step_suffix)
    device = _resolve_device(cfg)
    return cfg, device, experiment
def load_selected_model(checkpoint_label):
    """Load the checkpoint behind ``checkpoint_label`` and build UI payloads.

    Returns a 4-tuple consumed by the Gradio outputs: the model bundle dict
    (kept in gr.State), a status markdown string, a model-info dict, and the
    configured inference step count (to sync the steps slider).

    Raises:
        gr.Error: if the label is no longer present in the checkpoint map.
    """
    entry = checkpoint_map().get(checkpoint_label)
    if entry is None:
        raise gr.Error("Selected checkpoint was not found. Refresh the dropdown.")
    ckpt_path = entry["path"]
    cfg, device, experiment = build_runtime_cfg(ckpt_path)
    model, cfg = load_model(ckpt_path, cfg, device)
    src_tok, tgt_tok = _build_tokenizers(cfg)
    bundle = {
        "ckpt_path": ckpt_path,
        "experiment": experiment,
        "device": str(device),
        "cfg": cfg,
        "model": model,
        "src_tok": src_tok,
        "tgt_tok": tgt_tok,
    }
    model_cfg = cfg["model"]
    model_info = {
        "checkpoint": ckpt_path,
        "experiment": experiment,
        "model_type": cfg["model_type"],
        "include_negatives": cfg["data"]["include_negative_examples"],
        "device": str(device),
        "max_seq_len": model_cfg["max_seq_len"],
        "diffusion_steps": model_cfg["diffusion_steps"],
        "d_model": model_cfg["d_model"],
        "n_layers": model_cfg["n_layers"],
        "n_heads": model_cfg["n_heads"],
    }
    status = f"Loaded `{experiment}` on `{device}`."
    return bundle, status, model_info, cfg["inference"]["num_steps"]
def apply_preset(preset_name):
    """Map a preset name to slider values.

    Returns a 5-tuple ``(temperature, top_k, repetition_penalty,
    diversity_penalty, num_steps)``; unknown names fall back to "Balanced".
    """
    balanced = (0.70, 40, 1.20, 0.0, 64)
    presets = {
        "Manual": (0.70, 40, 1.20, 0.0, 64),
        "Literal": (0.60, 20, 1.25, 0.0, 64),
        "Balanced": balanced,
        "Creative": (0.85, 80, 1.20, 0.2, 64),
    }
    return presets.get(preset_name, balanced)
def task_notes_md():
    """Markdown body for the "Task Details and Evaluation Notes" accordion."""
    notes = """
### Task Notes
**Task 1: KV Cache**
- Benchmark encoder caching vs standard generation.
- Best for engineering evaluation, not language quality evaluation.
**Task 2: Attention + Drift**
- Shows internal attention maps and output stabilization over diffusion steps.
- Useful for diagnostics and mentor discussion of model behavior.
**Task 3: Concept Vectors**
- Experimental PCA steering over decoder hidden states.
- Current outputs are exploratory, not strong semantic evidence yet.
**Task 4: Step Ablation**
- Requires retraining separate checkpoints for each diffusion step count.
- Use this UI for generation only; ablation analysis runs from `analysis/run_analysis.py`.
**Task 5: Quality Guidance**
- Advanced experimental feature in the analysis pipeline.
- Not exposed in this UI because the current evidence is still under validation.
"""
    return notes
def save_generation(experiment, record):
    """Append ``record`` to the per-experiment, per-day JSON log.

    The log file lives at ``RESULTS_DIR/<experiment>_ui_<YYYYMMDD>.json`` and
    holds a JSON list of generation records. Returns the path written.
    """
    ts = datetime.now().strftime("%Y%m%d")
    path = os.path.join(RESULTS_DIR, f"{experiment}_ui_{ts}.json")
    existing = []
    if os.path.exists(path):
        try:
            with open(path, "r", encoding="utf-8") as f:
                loaded = json.load(f)
            # Only reuse the file if it still has the expected list-of-records
            # shape; a hand-edited scalar/object would break `.append` below.
            if isinstance(loaded, list):
                existing = loaded
        except (json.JSONDecodeError, OSError):
            # A truncated/corrupt log must not block saving new generations;
            # start a fresh list rather than crashing every future save.
            existing = []
    existing.append(record)
    with open(path, "w", encoding="utf-8") as f:
        json.dump(existing, f, ensure_ascii=False, indent=2)
    return path
def clean_generated_text(text: str, max_consecutive: int = 2, max_occurrence_ratio: float = 0.15) -> str:
    """Trim pathological repetition from diffusion output.

    Keeps Sanskrit tokens intact while (1) capping consecutive repeats of a
    token at ``max_consecutive`` and (2) capping any single token's total
    count at ``max_occurrence_ratio`` of the sequence (min 3), then
    normalizing spacing around danda punctuation.
    """
    words = text.split()
    if not words:
        # Whitespace-only input normalizes to the empty string.
        return ""
    # Pass 1: cap consecutive repeats of the same token.
    deduped = []
    streak_token = None
    streak_len = 0
    for word in words:
        streak_len = streak_len + 1 if word == streak_token else 1
        streak_token = word
        if streak_len <= max_consecutive:
            deduped.append(word)
    # Pass 2: cap the global count of any single token (collapse guard).
    if deduped:
        cap = max(3, int(len(deduped) * max_occurrence_ratio))
        seen = {}
        kept = []
        for word in deduped:
            seen[word] = seen.get(word, 0) + 1
            if seen[word] <= cap:
                kept.append(word)
        deduped = kept
    result = " ".join(deduped)
    # Danda (।) and double danda (॥) attach directly to the preceding word.
    result = result.replace(" ।", "।").replace(" ॥", "॥")
    return " ".join(result.split())
def generate_from_ui(
    model_bundle,
    input_text,
    temperature,
    top_k,
    repetition_penalty,
    diversity_penalty,
    num_steps,
    clean_output,
):
    """Run one generation with the loaded bundle and the current slider values.

    Returns ``(output_text, status_markdown, record)`` and appends the record
    to the per-experiment JSON log via save_generation.

    Raises:
        gr.Error: when no model is loaded or the input box is empty.
    """
    if not model_bundle:
        raise gr.Error("Load a model first.")
    prompt = input_text.strip()
    if not prompt:
        raise gr.Error("Enter input text first.")
    # Deep-copy the bundle config so slider overrides never leak between runs.
    cfg = copy.deepcopy(model_bundle["cfg"])
    cfg["inference"].update(
        {
            "temperature": float(temperature),
            "top_k": int(top_k),
            "repetition_penalty": float(repetition_penalty),
            "diversity_penalty": float(diversity_penalty),
            "num_steps": int(num_steps),
        }
    )
    device = torch.device(model_bundle["device"])
    input_ids = torch.tensor(
        [model_bundle["src_tok"].encode(prompt)],
        dtype=torch.long,
        device=device,
    )
    generated = run_inference(model_bundle["model"], input_ids, cfg)
    # Drop special tokens (ids 0-4) before decoding.
    kept_ids = [tok_id for tok_id in generated[0].tolist() if tok_id > 4]
    raw_output_text = model_bundle["tgt_tok"].decode(kept_ids).strip()
    output_text = clean_generated_text(raw_output_text) if clean_output else raw_output_text
    if not output_text:
        output_text = "(empty output)"
    record = {
        "timestamp": datetime.now().isoformat(timespec="seconds"),
        "experiment": model_bundle["experiment"],
        "checkpoint": model_bundle["ckpt_path"],
        "input_text": input_text,
        "raw_output_text": raw_output_text,
        "output_text": output_text,
        "clean_output": bool(clean_output),
        "temperature": float(temperature),
        "top_k": int(top_k),
        "repetition_penalty": float(repetition_penalty),
        "diversity_penalty": float(diversity_penalty),
        "num_steps": int(num_steps),
    }
    log_path = save_generation(model_bundle["experiment"], record)
    status = f"Generated with `{model_bundle['experiment']}`. Saved to `{log_path}`."
    return output_text, status, record
# --- Gradio UI layout and event wiring --------------------------------------
with gr.Blocks(title="Sanskrit D3PM Studio") as demo:
    # Holds the bundle dict returned by load_selected_model between events.
    model_state = gr.State(None)
    gr.Markdown(
        """
# Sanskrit D3PM Studio
Load any available checkpoint, generate Devanagari output from Roman/IAST Sanskrit,
and inspect the settings used for evaluation or demos.
"""
    )
    # Checkpoint selection row: dropdown plus refresh/load controls.
    with gr.Row():
        with gr.Column(scale=2):
            checkpoint_dropdown = gr.Dropdown(
                label="Available Checkpoints",
                choices=list(checkpoint_map().keys()),
                value=default_checkpoint_label(),
                interactive=True,
            )
        with gr.Column(scale=1):
            refresh_btn = gr.Button("Refresh List")
            load_btn = gr.Button("Load Model", variant="primary")
    load_status = gr.Markdown("Select a checkpoint and load it.")
    model_info = gr.JSON(label="Loaded Model Info")
    # Generation row: input/output text on the left, sampling knobs on the right.
    with gr.Row():
        with gr.Column(scale=2):
            input_text = gr.Textbox(
                label="Input Text (Roman / IAST Sanskrit)",
                placeholder="dharmo rakṣati rakṣitaḥ",
                lines=4,
            )
            output_text = gr.Textbox(
                label="Generated Output (Devanagari)",
                lines=6,
                interactive=False,
            )
            generate_btn = gr.Button("Generate", variant="primary")
        with gr.Column(scale=1):
            # Preset radio drives the sliders below via apply_preset.
            preset = gr.Radio(
                ["Manual", "Literal", "Balanced", "Creative"],
                value="Balanced",
                label="Inference Preset",
            )
            temperature = gr.Slider(0.4, 1.2, value=0.70, step=0.05, label="Temperature")
            top_k = gr.Slider(5, 100, value=40, step=1, label="Top-K")
            repetition_penalty = gr.Slider(1.0, 3.0, value=1.20, step=0.05, label="Repetition Penalty")
            diversity_penalty = gr.Slider(0.0, 1.0, value=0.0, step=0.05, label="Diversity Penalty")
            num_steps = gr.Slider(1, 128, value=64, step=1, label="Inference Steps")
            clean_output = gr.Checkbox(value=True, label="Clean Output (dedupe loops)")
    run_status = gr.Markdown("")
    run_record = gr.JSON(label="Last Generation Metadata")
    with gr.Accordion("Task Details and Evaluation Notes", open=False):
        task_notes = gr.Markdown(task_notes_md())
    gr.Examples(
        examples=[
            ["dharmo rakṣati rakṣitaḥ"],
            ["satyameva jayate"],
            ["ahaṃ brahmāsmi"],
            ["yatra nāryastu pūjyante"],
        ],
        inputs=[input_text],
        label="Quick Examples",
    )
    def refresh_checkpoints():
        # Re-scan the result directories and reset the dropdown selection.
        choices = list(checkpoint_map().keys())
        value = choices[0] if choices else None
        return gr.Dropdown(choices=choices, value=value)
    refresh_btn.click(fn=refresh_checkpoints, outputs=[checkpoint_dropdown])
    # Loading a model also syncs the steps slider to the checkpoint's config.
    load_btn.click(
        fn=load_selected_model,
        inputs=[checkpoint_dropdown],
        outputs=[model_state, load_status, model_info, num_steps],
    )
    preset.change(
        fn=apply_preset,
        inputs=[preset],
        outputs=[temperature, top_k, repetition_penalty, diversity_penalty, num_steps],
    )
    generate_btn.click(
        fn=generate_from_ui,
        inputs=[
            model_state,
            input_text,
            temperature,
            top_k,
            repetition_penalty,
            diversity_penalty,
            num_steps,
            clean_output,
        ],
        outputs=[output_text, run_status, run_record],
    )
    # Pressing Enter in the input box triggers the same generation path.
    input_text.submit(
        fn=generate_from_ui,
        inputs=[
            model_state,
            input_text,
            temperature,
            top_k,
            repetition_penalty,
            diversity_penalty,
            num_steps,
            clean_output,
        ],
        outputs=[output_text, run_status, run_record],
    )
if __name__ == "__main__":
    # Honor GRADIO_SERVER_PORT when set (e.g. by a hosting platform);
    # otherwise let Gradio pick its default port.
    port_env = os.environ.get("GRADIO_SERVER_PORT")
    demo.launch(
        server_name="127.0.0.1",
        server_port=int(port_env) if port_env is not None else None,
        share=False,
    )