aLLoyM / app.py
Playingyoyo's picture
Change diagram strategy
5702775 verified
import gradio as gr
from mcp.server.fastmcp import FastMCP
import theme
import css
import model
import config
import os
import json
import re
import time
import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from collections import defaultdict
# Load the model at import time.
model.load_model()

# --- CONSTANTS ---
# Upper bound on the number of phase slots handled by add_next_phase.
MAX_PHASES = 10

# Unique element names from config, sorted; used as dropdown choices.
ELEMENT_CHOICES = sorted(set(config.element_dict.values()))
PHASE_CHOICES = sorted(config.phase_list)

# Example chat to show initially
EXAMPLE_CHAT = [
    {
        "role": "user",
        "content": "What phases form when Gold (50%) + Silver (50%) are mixed at 800 K?",
    },
    {"role": "assistant", "content": "FCC_A1"},
]
# --- Animated dots generator ---
def get_dots(step):
    """Return the loading indicator for *step*, cycling ".", "..", "..."."""
    return "." * (step % 3 + 1)
# --- PHASE NAME UTILITIES ---
def sort_phase_name(phase_name):
    """Return *phase_name* with its ' + '-separated components sorted.

    Falsy input (``None`` or ``""``) is returned unchanged.
    """
    if phase_name:
        return ' + '.join(sorted(phase_name.split(' + ')))
    return phase_name
def normalize_phase_name(phase_text):
    """Return a canonical, order-independent form of a phase description.

    The text is split on every supported separator (" and ", " + ", ", ",
    " & ", " with "), so mixed forms such as "A and B, C" are handled. Each
    piece is stripped, empty pieces are dropped, and the remaining names are
    sorted and joined with " + ".

    Falsy input is returned unchanged, as is text that contains nothing but
    separators and whitespace.
    """
    if not phase_text:
        return phase_text

    pieces = [phase_text]
    for sep in (' and ', ' + ', ', ', ' & ', ' with '):
        pieces = [part for piece in pieces for part in piece.split(sep)]

    names = sorted(piece.strip() for piece in pieces if piece.strip())
    if not names:
        return phase_text
    return " + ".join(names)
def clean_phase_name(phase_name):
    """Strip disallowed characters and collapse whitespace in a phase name.

    Only word characters, whitespace and the symbols ``- + ( )`` are kept;
    runs of whitespace become a single space and both ends are trimmed.
    Falsy input is returned as-is.
    """
    if phase_name:
        kept = re.sub(r'[^\w\s\-+()]', '', phase_name)
        return re.sub(r'\s+', ' ', kept).strip()
    return phase_name
def extract_phases(answer):
    """Turn a model answer into a list holding at most one phase label.

    Empty answers and answers starting with "Error:" or "❌" give ``[]``.
    Otherwise the answer is trimmed, one trailing period is removed, and the
    text is split on every supported separator. Each fragment is cleaned with
    clean_phase_name; several phases are merged into a single sorted
    " + "-joined label.
    """
    if not answer or answer.startswith(("Error:", "❌")):
        return []

    text = answer.strip()
    if text.endswith('.'):
        text = text[:-1]

    fragments = [text]
    for delimiter in (' and ', ' + ', ', ', ' & ', ' with ', '\n'):
        fragments = [piece for fragment in fragments for piece in fragment.split(delimiter)]

    cleaned = [clean_phase_name(fragment) for fragment in fragments if fragment.strip()]
    names = [name for name in cleaned if name]
    if len(names) > 1:
        return [" + ".join(sorted(names))]
    return names
# --- DATA LOADING ---
def parse_question(question):
    """Parse element percentages and the temperature out of a question string.

    The question is expected to contain one or more "Name (N%)" pairs and a
    temperature written as "<number> K", e.g.
    "What phases form when Gold (50%) + Silver (50%) are mixed at 800 K?".

    Returns:
        A dict with 'elements' (names in order of appearance), 'percentages'
        (name -> float) and 'temperature' (float), or None when the input is
        not a string, has no "Name (N%)" pair, or has no temperature.
    """
    if not isinstance(question, str):
        return None

    pairs = re.findall(r'([A-Za-z]+)\s*\(\s*(\d+(?:\.\d+)?)\s*%\s*\)', question)
    if not pairs:
        return None

    # Accept a fractional part so "800.5 K" is read as 800.5 rather than 5.
    temp_match = re.search(r'(\d+(?:\.\d+)?)\s*K', question)
    if not temp_match:
        return None

    return {
        'elements': [name for name, _ in pairs],
        'percentages': {name: float(pct) for name, pct in pairs},
        'temperature': float(temp_match.group(1)),
    }
def load_jsonl_data(filepath):
    """Load question/answer records from a JSONL file into plot entries.

    Each line is expected to be a JSON object with a "user" question and an
    "answer" string. Blank lines, invalid JSON, JSON values that are not
    objects, records whose question/answer are not strings, and questions
    that parse_question cannot parse are skipped.

    Returns:
        A list of dicts with 'elements', 'percentages', 'temperature' and
        'phases' keys; an empty list when the file does not exist.
    """
    if not os.path.exists(filepath):
        return []

    entries = []
    with open(filepath, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                data = json.loads(line)
            except json.JSONDecodeError:
                continue
            # A valid JSON line may still be a list/number/string; those have
            # no .get() and cannot describe a data point.
            if not isinstance(data, dict):
                continue

            question = data.get("user", "")
            answer = data.get("answer", "")
            if not isinstance(question, str) or not isinstance(answer, str):
                continue

            parsed = parse_question(question)
            if not parsed:
                continue

            # Unparseable answers still become a plotted point labelled "Unknown".
            phases = extract_phases(answer) or ["Unknown"]
            entries.append({
                'elements': parsed['elements'],
                'percentages': parsed['percentages'],
                'temperature': parsed['temperature'],
                'phases': [normalize_phase_name(p) for p in phases],
            })
    return entries
# --- BINARY PHASE DIAGRAM ---
def create_binary_phase_diagram(entries, output_base_path):
    """Render a temperature-vs-composition scatter plot, one colour per phase.

    Args:
        entries: list of dicts with 'elements', 'percentages', 'temperature'
            and 'phases' keys (the shape built by load_jsonl_data).
        output_base_path: output path; its extension, if any, is replaced by
            ".png" and ".svg".

    Returns:
        (png_path, svg_path), or (None, None) when *entries* is empty.
    """
    if not entries:
        return None, None

    elements = sorted({el for entry in entries for el in entry['elements']})
    if len(elements) == 1:
        x_label = f"Composition (% {elements[0]})"
        title = f"{elements[0]} Phase Diagram"
    elif len(elements) == 2:
        x_label = f"Composition (% {elements[1]})"
        title = f"{elements[0]}-{elements[1]} Phase Diagram"
    else:
        x_label = "Composition (%)"
        title = "Phase Diagram"

    all_phases = sorted({phase for entry in entries for phase in entry['phases']})
    colors = {phase: plt.cm.tab10(i % 10) for i, phase in enumerate(all_phases)}

    # The x axis follows the alphabetically last element; a single-element
    # system is always plotted at 100 %.
    x_element = elements[-1] if elements else None
    phase_points = defaultdict(list)
    for entry in entries:
        if len(elements) == 1:
            x = 100.0
        else:
            x = entry['percentages'].get(x_element, 0.0)
        for phase in entry['phases']:
            phase_points[phase].append((x, entry['temperature']))

    output_dir = os.path.dirname(output_base_path) if os.path.dirname(output_base_path) else "tmp"
    os.makedirs(output_dir, exist_ok=True)
    base_name = os.path.splitext(output_base_path)[0]
    png_path = f"{base_name}.png"
    svg_path = f"{base_name}.svg"

    fig, ax = plt.subplots(figsize=(12, 10))
    try:
        handles = []
        labels = []
        for phase in sorted(phase_points):
            x_vals, y_vals = zip(*phase_points[phase])
            scatter = ax.scatter(x_vals, y_vals,
                                 c=[colors[phase]],
                                 s=250,
                                 alpha=0.8,
                                 edgecolors='white',
                                 linewidths=2)
            handles.append(scatter)
            labels.append(phase)

        ax.set_title(title, fontsize=36, fontweight='bold', pad=20)
        ax.set_xlabel(x_label, fontsize=28)
        ax.set_ylabel("Temperature (K)", fontsize=28)

        all_temps = [entry['temperature'] for entry in entries]
        ax.set_xlim(-5, 105)
        ax.set_ylim(min(all_temps) - 50, max(all_temps) + 50)
        ax.tick_params(axis='both', labelsize=24)

        if handles:
            ax.legend(handles, labels,
                      title="Phases",
                      title_fontsize=24,
                      fontsize=20,
                      loc='upper center',
                      bbox_to_anchor=(0.5, -0.15),
                      ncol=min(4, len(handles)),
                      frameon=True)

        fig.tight_layout()
        fig.savefig(png_path, dpi=150, bbox_inches='tight', facecolor='white')
        fig.savefig(svg_path, format='svg', bbox_inches='tight')
    finally:
        # Close this specific figure even when plotting or saving fails, so
        # figures do not accumulate across requests.
        plt.close(fig)
    return png_path, svg_path
# --- TERNARY PHASE DIAGRAM ---
def get_ternary_coords(percentages, elements):
    """Convert a ternary composition to (x, y) plot coordinates.

    ``elements[0]`` contributes the vertical coordinate and ``elements[2]``
    the horizontal offset; missing percentages count as 0.
    """
    top = percentages.get(elements[0], 0.0)
    right = percentages.get(elements[2], 0.0)
    return (top / 2 + right) / 100.0, np.sqrt(3) * top / 2 / 100.0
def setup_ternary_axes(ax, elements):
    """Draw the triangle outline, light grid lines and corner labels on *ax*.

    ``elements[0]`` is labelled at the top, ``elements[1]`` at the bottom-left
    and ``elements[2]`` at the bottom-right.
    """
    height = np.sqrt(3.0) * 0.5
    ax.set_aspect('equal', 'datalim')
    ax.axis('off')

    # Triangle outline: base, left edge, right edge.
    outline = (
        ([0.0, 1.0], [0.0, 0.0]),
        ([0.0, 0.5], [0.0, height]),
        ([1.0, 0.5], [0.0, height]),
    )
    for xs, ys in outline:
        ax.plot(xs, ys, 'k-', lw=2)

    # Three families of grid lines at every tenth of the triangle.
    grid_style = dict(color='lightgray', lw=0.5, zorder=0, alpha=0.3)
    for i in range(1, 10):
        ax.plot([i / 20.0, 1.0 - i / 20.0], [height * i / 10.0, height * i / 10.0], **grid_style)
        ax.plot([i / 20.0, i / 10.0], [height * i / 10.0, 0.0], **grid_style)
        ax.plot([0.5 + i / 20.0, i / 10.0], [height * (1.0 - i / 10.0), 0.0], **grid_style)

    corner_labels = (
        (0.5, height + 0.05, elements[0], 'center', 'bottom'),
        (-0.05, -0.05, elements[1], 'right', 'top'),
        (1.05, -0.05, elements[2], 'left', 'top'),
    )
    for x, y, name, h_align, v_align in corner_labels:
        ax.text(x, y, str(name), fontsize=48, ha=h_align, va=v_align, fontweight='bold')
def create_ternary_phase_diagram(entries, output_base_path, temperature):
    """Render an isothermal ternary phase diagram as a scatter plot.

    Args:
        entries: list of dicts with 'elements', 'percentages' and 'phases'
            keys (the shape built by load_jsonl_data).
        output_base_path: output path; its extension, if any, is replaced by
            ".png" and ".svg".
        temperature: temperature shown in the title; when None the title
            omits it instead of failing.

    Returns:
        (png_path, svg_path), or (None, None) when *entries* is empty or the
        entries do not span exactly three distinct elements.
    """
    if not entries:
        return None, None

    elements = sorted({el for entry in entries for el in entry['elements']})
    if len(elements) != 3:
        return None, None

    all_phases = sorted({phase for entry in entries for phase in entry['phases']})
    if len(all_phases) <= 1:
        # A single phase would divide by zero in the spread below.
        phase_to_color = {phase: cm.rainbow(0.5) for phase in all_phases}
    else:
        phase_to_color = {
            phase: cm.rainbow(float(i) / (len(all_phases) - 1))
            for i, phase in enumerate(all_phases)
        }

    phase_points = defaultdict(list)
    for entry in entries:
        x, y = get_ternary_coords(entry['percentages'], elements)
        for phase in entry['phases']:
            phase_points[phase].append((x, y))

    system_name = "-".join(elements)
    if temperature is None:
        title = f"{system_name} Phase Diagram"
    else:
        title = f"{system_name} Phase Diagram @ {int(temperature)}K"

    output_dir = os.path.dirname(output_base_path) if os.path.dirname(output_base_path) else "tmp"
    os.makedirs(output_dir, exist_ok=True)
    base_name = os.path.splitext(output_base_path)[0]
    png_path = f"{base_name}.png"
    svg_path = f"{base_name}.svg"

    fig, ax = plt.subplots(figsize=(14, 14))
    try:
        setup_ternary_axes(ax, elements)
        ax.set_title(title, fontsize=40, fontweight='bold', pad=30)

        handles = []
        labels = []
        for phase in sorted(phase_points):
            x_vals, y_vals = zip(*phase_points[phase])
            scatter = ax.scatter(x_vals, y_vals,
                                 c=[phase_to_color[phase]],
                                 s=200,
                                 alpha=0.8,
                                 edgecolors='white',
                                 linewidths=1.5,
                                 zorder=10)
            handles.append(scatter)
            labels.append(phase)

        if handles:
            ax.legend(handles, labels,
                      title="Phases",
                      title_fontsize=28,
                      fontsize=22,
                      loc='upper center',
                      bbox_to_anchor=(0.5, -0.05),
                      ncol=min(3, len(handles)),
                      frameon=True)

        fig.tight_layout()
        fig.savefig(png_path, dpi=150, bbox_inches='tight', facecolor='white')
        fig.savefig(svg_path, format='svg', bbox_inches='tight')
    finally:
        # Close this specific figure even when plotting or saving fails.
        plt.close(fig)
    return png_path, svg_path
def generate_phase_diagram_from_jsonl(jsonl_path, output_base_path, is_ternary=False, temperature=None):
    """Load entries from *jsonl_path* and render the matching diagram type.

    Returns (png_path, svg_path) from the binary or ternary renderer, or
    (None, None) when the file yields no entries.
    """
    entries = load_jsonl_data(jsonl_path)
    if not entries:
        return None, None
    if not is_ternary:
        return create_binary_phase_diagram(entries, output_base_path)
    return create_ternary_phase_diagram(entries, output_base_path, temperature)
# --- TERNARY COMPOSITION GRID ---
def generate_ternary_compositions(step=20):
    """List (c1, c2, c3) integer percentage triples that sum to 100.

    c1 and c2 advance in increments of *step*; c3 is whatever remains.
    """
    return [
        (first, second, 100 - first - second)
        for first in range(0, 101, step)
        for second in range(0, 101 - first, step)
    ]
# --- EXPORT CHAT HISTORY ---
def export_chat_history(history):
    """Write user/assistant exchanges from *history* to tmp/chat_history.jsonl.

    Messages equal to an EXAMPLE_CHAT entry, image messages (dict content)
    and loading-dot placeholders are ignored. Each assistant reply that
    follows a user message is written as one {"user", "assistant"} line.

    Returns the output path, or None when there is nothing to export.
    """
    if not history:
        return None

    # Drop the placeholder conversation shown before any real input.
    real_history = [msg for msg in history if msg not in EXAMPLE_CHAT]
    if not real_history:
        return None

    os.makedirs("tmp", exist_ok=True)
    output_path = "tmp/chat_history.jsonl"
    with open(output_path, "w") as out:
        pending_question = None
        for message in real_history:
            if not isinstance(message, dict):
                continue
            role = message.get("role", "")
            content = message.get("content", "")

            # Dict content is an image payload, not text.
            if isinstance(content, dict):
                continue

            # List content: keep the text parts and join them.
            if isinstance(content, list):
                pieces = []
                for part in content:
                    if isinstance(part, dict) and "text" in part:
                        pieces.append(part["text"])
                    elif isinstance(part, str):
                        pieces.append(part)
                content = " ".join(pieces)

            # Loading-dot placeholders carry no information.
            if content in [".", "..", "..."]:
                continue

            if role == "user":
                pending_question = content
            elif role == "assistant" and pending_question:
                out.write(json.dumps({"user": pending_question, "assistant": content}) + "\n")
                pending_question = None
    return output_path
# --- INFERENCE HELPER (shared by UI and API) ---
def _run_single_inference(question, system_type="binary"):
    """Ask the model one question and return its answer as text.

    The last item of the model response is used: a dict yields its 'content'
    value, a list/tuple yields its second item (or its first when it has only
    one), anything else is stringified. An empty response yields
    "No response received.".
    """
    system_instruction = (
        f"You are an expert in phase diagrams, thermodynamics, and materials science, "
        f"specializing in {system_type} alloy systems."
    )
    response = model.run_inference(question, system_instruction, [])
    if not response:
        return "No response received."

    final = response[-1]
    if isinstance(final, dict):
        return final.get('content', str(final))
    if not isinstance(final, (list, tuple)):
        return str(final)
    return str(final[1] if len(final) > 1 else final[0])
def _run_batch_inference(questions, system_type="binary"):
    """Run multiple inferences back-to-back without delays.

    Each question goes through _run_single_inference; a question whose
    inference raises contributes "Error: <message>" instead of aborting the
    batch.

    Returns a list of answer strings, one per question.
    """
    answers = []
    for question in questions:
        try:
            answers.append(_run_single_inference(question, system_type))
        except Exception as exc:
            answers.append(f"Error: {exc}")
    return answers
# --- LOGIC FUNCTIONS ---
def generate_prompt_text(task_name, e1, e2, e3, p1, p2, p3, temp, temp_min, temp_max, ternary_temp, *phases):
    """Build the prompt text for the selected task from the current inputs.

    Elements with a blank name are ignored. Missing numeric inputs are shown
    as "___" (or 0 for a missing percentage). Returns a hint string when no
    element is set or the task name is not recognised.
    """
    wanted_phases = [p.strip() for p in phases if p and str(p).strip()]
    chosen = [
        (name, pct)
        for name, pct in ((e1, p1), (e2, p2), (e3, p3))
        if name and len(str(name).strip()) > 0
    ]
    if not chosen:
        return "Please define elements in Step 1."

    names = [name for name, _ in chosen]

    if task_name == "Phase name prediction":
        mixture = " + ".join(
            f"{name} ({0 if pct is None else pct}%)" for name, pct in chosen
        )
        shown_temp = "___" if temp is None else temp
        return f"What phases form when {mixture} are mixed at {shown_temp} K?"

    if task_name == "Experimental condition prediction":
        joined_names = " + ".join(names)
        phase_part = " + ".join(wanted_phases) if wanted_phases else "___"
        return f"Under what condition do {joined_names} form {phase_part}?"

    if task_name == "Phase diagram prediction":
        system = "-".join(names)
        if len(chosen) == 3:
            at_temp = "___" if ternary_temp is None else int(ternary_temp)
            return f"Draw the phase diagram for the {system} system at {at_temp} K."
        low = "___" if temp_min is None else int(temp_min)
        high = "___" if temp_max is None else int(temp_max)
        return f"Draw the phase diagram for the {system} system in the temperature range {low}-{high} K."

    return "Please select a task in Step 2."
def run_chat(prompt_text, e1, e2, e3, p1, p2, p3, history, task_name, temp_min, temp_max, ternary_temp, svg_state, is_running):
    """Main chat generator.

    Yields (history, svg_path, is_running, btn_update) tuples so the UI can
    refresh while work is in progress.

    For "Phase diagram prediction" with at least one element, a grid of
    phase-name questions is asked (every 50 K; compositions in 20 % steps),
    questions and answers are logged to tmp/q.jsonl and tmp/a.jsonl, and a
    diagram is rendered from the answers. Any other task sends *prompt_text*
    to the model once.

    p1, p2, p3 and is_running are accepted for the caller's signature but are
    not read here.
    """
    def running_btn():
        # Stop glyph shown while a request is in progress.
        return gr.update(value="■", elem_classes=["stop-btn"])

    def idle_btn():
        # Play glyph shown when the app is ready for a new request.
        return gr.update(value="▶", elem_classes=[])

    def answer_text(response):
        # Same extraction rule as _run_single_inference, for a non-empty response.
        last_msg = response[-1]
        if isinstance(last_msg, dict):
            return last_msg.get('content', str(last_msg))
        if isinstance(last_msg, (list, tuple)):
            return str(last_msg[1]) if len(last_msg) > 1 else str(last_msg[0])
        return str(last_msg)

    elements = [e1, e2, e3]
    active_elements = [x for x in elements if x and len(str(x).strip()) > 0]
    count = len(active_elements)

    # Clear example chat on first real input
    if history == EXAMPLE_CHAT or history is None:
        history = []
    else:
        history = list(history)
    current_svg = svg_state

    # --- PHASE DIAGRAM BATCH LOGIC ---
    if task_name == "Phase diagram prediction" and count >= 1:
        os.makedirs("tmp", exist_ok=True)
        q_file_path = "tmp/q.jsonl"
        a_file_path = "tmp/a.jsonl"
        is_ternary = (count == 3)
        elements_str = "-".join(active_elements)
        # Element names can be typed freely by the user; keep only characters
        # that are safe in a file name so the output stays inside tmp/.
        safe_name = re.sub(r'[^\w\-]+', '_', elements_str)
        if is_ternary:
            t_single = int(ternary_temp) if ternary_temp else 800
            t_start, t_end = t_single, t_single
            diagram_base = f"tmp/{safe_name}_{t_single}K"
        else:
            t_single = None
            t_start = int(temp_min) if temp_min else 300
            t_end = int(temp_max) if temp_max else 1000
            if t_start > t_end:
                t_start, t_end = t_end, t_start
            diagram_base = f"tmp/{safe_name}_{t_start}-{t_end}K"

        system_type = "ternary" if is_ternary else ("binary" if count == 2 else "elemental")
        system_instruction = (
            f"You are an expert in phase diagrams, thermodynamics, and materials science, "
            f"specializing in {system_type} alloy systems."
        )

        # The composition grid does not depend on temperature; build it once.
        if is_ternary:
            compositions = generate_ternary_compositions(step=20)
        elif count == 2:
            compositions = [(100 - i, i) for i in range(0, 101, 20)]
        else:
            compositions = [(100,)]

        generated_count = 0
        error_count = 0
        max_errors = 5
        dot_step = 0

        if is_ternary:
            history.append({"role": "assistant", "content": f"Starting {system_type} batch for **{elements_str}** at {t_single}K..."})
        else:
            history.append({"role": "assistant", "content": f"Starting {system_type} batch for **{elements_str}** ({t_start}K - {t_end}K)..."})
        yield history, current_svg, True, running_btn()

        with open(q_file_path, "w") as q_file, open(a_file_path, "w") as a_file:
            for t in range(t_start, t_end + 1, 50):
                for comp in compositions:
                    if error_count >= max_errors:
                        history.append({"role": "assistant", "content": f"Stopped (errors). Completed: {generated_count}"})
                        yield history, current_svg, False, idle_btn()
                        return
                    try:
                        if is_ternary:
                            c1, c2, c3 = comp
                            mix_str = f"{active_elements[0]} ({c1}%) + {active_elements[1]} ({c2}%) + {active_elements[2]} ({c3}%)"
                        elif count == 2:
                            c1, c2 = comp
                            # Pure end-members are asked as single-element questions.
                            if c2 == 0:
                                mix_str = f"{active_elements[0]} (100%)"
                            elif c2 == 100:
                                mix_str = f"{active_elements[1]} (100%)"
                            else:
                                mix_str = f"{active_elements[0]} ({c1}%) + {active_elements[1]} ({c2}%)"
                        else:
                            mix_str = f"{active_elements[0]} (100%)"

                        question = f"What phases form when {mix_str} are mixed at {t} K? Answer phase names only."
                        history.append({"role": "user", "content": question})
                        history.append({"role": "assistant", "content": get_dots(dot_step)})
                        dot_step += 1
                        yield history, current_svg, True, running_btn()

                        q_file.write(json.dumps({"user": question}) + "\n")
                        q_file.flush()
                        time.sleep(1.5)

                        answer = "Error"
                        for attempt in range(3):
                            try:
                                history[-1] = {"role": "assistant", "content": get_dots(dot_step)}
                                dot_step += 1
                                yield history, current_svg, True, running_btn()
                                temp_response = model.run_inference(question, system_instruction, [])
                                if temp_response:
                                    answer = answer_text(temp_response)
                                # A call that returns without raising ends the retries.
                                error_count = 0
                                break
                            except Exception as e:
                                if attempt < 2:
                                    history[-1] = {"role": "assistant", "content": f"Retry {attempt + 2}{get_dots(dot_step)}"}
                                    dot_step += 1
                                    yield history, current_svg, True, running_btn()
                                    # Back off a little longer after each failure.
                                    time.sleep(3 * (attempt + 1))
                                else:
                                    answer = f"Error: {str(e)[:50]}"
                                    error_count += 1

                        history[-1] = {"role": "assistant", "content": answer}
                        yield history, current_svg, True, running_btn()
                        a_file.write(json.dumps({"user": question, "answer": answer}) + "\n")
                        a_file.flush()
                        generated_count += 1
                    except Exception as e:
                        error_count += 1
                        last = history[-1] if history else None
                        last_content = last.get("content", "") if isinstance(last, dict) else ""
                        # Only overwrite a loading placeholder, never a real message.
                        if isinstance(last_content, str) and last_content.startswith("."):
                            history[-1] = {"role": "assistant", "content": f"Error: {str(e)[:50]}"}
                            yield history, current_svg, True, running_btn()
                        time.sleep(1)
                if error_count >= max_errors:
                    break

        history.append({"role": "assistant", "content": f"Generating {system_type} diagram{get_dots(dot_step)}"})
        yield history, current_svg, True, running_btn()
        try:
            png_path, svg_path = generate_phase_diagram_from_jsonl(
                a_file_path, diagram_base,
                is_ternary=is_ternary,
                temperature=t_single if is_ternary else None
            )
            if png_path and os.path.exists(png_path):
                current_svg = svg_path
                history[-1] = {"role": "assistant", "content": (
                    f"**Complete!** {generated_count} points generated.\n\n"
                    f"System: {elements_str} ({system_type})"
                )}
                history.append({
                    "role": "assistant",
                    "content": {"path": png_path, "alt_text": f"{elements_str} Phase Diagram"}
                })
                yield history, current_svg, False, idle_btn()
            else:
                history[-1] = {"role": "assistant", "content": f"Complete ({generated_count} points). Diagram failed."}
                yield history, current_svg, False, idle_btn()
        except Exception as e:
            print(f"Diagram error: {e}")
            history[-1] = {"role": "assistant", "content": f"Complete. Diagram error: {str(e)[:50]}"}
            yield history, current_svg, False, idle_btn()

    # --- STANDARD CHAT ---
    else:
        system_type = "ternary" if count == 3 else "binary" if count == 2 else "elemental"
        system_instruction = (
            f"You are an expert in phase diagrams, thermodynamics, and materials science, "
            f"specializing in {system_type} alloy systems."
        )
        history.append({"role": "user", "content": prompt_text})
        dot_step = 0
        history.append({"role": "assistant", "content": get_dots(dot_step)})
        yield history, current_svg, True, running_btn()

        # Short dot animation before the blocking model call.
        for _ in range(2):
            time.sleep(0.3)
            dot_step += 1
            history[-1] = {"role": "assistant", "content": get_dots(dot_step)}
            yield history, current_svg, True, running_btn()

        try:
            temp_response = model.run_inference(prompt_text, system_instruction, [])
            if temp_response:
                answer = answer_text(temp_response)
            else:
                answer = "No response received."
            history[-1] = {"role": "assistant", "content": answer}
            yield history, current_svg, False, idle_btn()
        except Exception as e:
            history[-1] = {"role": "assistant", "content": f"Error: {str(e)[:100]}"}
            yield history, current_svg, False, idle_btn()
# --- UI Helpers ---
def add_next_phase(count):
    """Reveal one more phase slot, capped at MAX_PHASES.

    Returns a list: the new count followed by one visibility update per slot.
    """
    shown = min(count + 1, MAX_PHASES)
    return [shown, *(gr.update(visible=(slot < shown)) for slot in range(MAX_PHASES))]
def update_perc_visibility(e2_text, e3_text):
    """Return updates for the three percentage fields based on elements 2 and 3.

    With element 3 set, fields 1 and 2 are editable and field 3 is shown
    read-only (AUTO). With only element 2 set, field 2 is shown read-only
    (AUTO) and field 3 is hidden. Otherwise fields 2 and 3 are hidden.
    """
    second_set = bool(e2_text) and len(str(e2_text).strip()) > 0
    third_set = bool(e3_text) and len(str(e3_text).strip()) > 0
    first_update = gr.update(interactive=True)

    if third_set:
        # Ternary: perc1 and perc2 editable, perc3 auto
        second_label = f"{str(e2_text).upper()} %" if second_set else "ELEMENT 2 %"
        third_label = f"{str(e3_text).upper()} % (AUTO)"
        return (first_update,
                gr.update(visible=True, interactive=True, label=second_label),
                gr.update(visible=True, interactive=False, label=third_label))

    if second_set:
        # Binary: perc1 editable, perc2 auto
        second_label = f"{str(e2_text).upper()} % (AUTO)"
        return (first_update,
                gr.update(visible=True, interactive=False, label=second_label),
                gr.update(visible=False))

    # Single element: only perc1 visible
    return (first_update, gr.update(visible=False), gr.update(visible=False))
def auto_calculate_balance(p1, p2, e2_text, e3_text):
    """Return updates that fill in the remaining percentage.

    With element 3 set, the second update carries ``max(0, 100 - p1 - p2)``;
    with only element 2 set, the first update carries ``max(0, 100 - p1)``.
    Missing percentages count as 0. Otherwise both updates are no-ops.
    """
    first = p1 or 0
    second = p2 or 0

    if e3_text and len(str(e3_text).strip()) > 0:
        # Ternary: auto-calculate perc3
        return gr.update(), gr.update(value=max(0, 100 - first - second))
    if e2_text and len(str(e2_text).strip()) > 0:
        # Binary: auto-calculate perc2
        return gr.update(value=max(0, 100 - first)), gr.update()
    return gr.update(), gr.update()
def update_label(element_name):
    """Return a label update: "<NAME> %" for a set element, else "ELEMENT 1 %"."""
    if not element_name or len(str(element_name).strip()) == 0:
        return gr.update(label="ELEMENT 1 %")
    return gr.update(label=f"{str(element_name).upper()} %")
def update_elem2_enabled(e1):
    """Enable elem2 when elem1 has a value; disable and clear it otherwise."""
    if e1 and len(str(e1).strip()) > 0:
        return gr.update(interactive=True)
    return gr.update(interactive=False, value=None)
def update_elem3_enabled(e2):
    """Enable elem3 when elem2 has a value; disable and clear it otherwise."""
    if e2 and len(str(e2).strip()) > 0:
        return gr.update(interactive=True)
    return gr.update(interactive=False, value=None)
def update_diag_temp_visibility(e1, e2, e3):
    """Return two visibility updates for the diagram temperature inputs.

    The first is visible unless all three elements are set; the second is
    visible only when all three are set.
    """
    active_count = sum(1 for x in (e1, e2, e3) if x and len(str(x).strip()) > 0)
    is_ternary = active_count == 3
    return gr.update(visible=not is_ternary), gr.update(visible=is_ternary)
def activate_step2(e1, e2, e3):
    """Mark the STEP 02 frame active once at least one element is entered."""
    any_set = any(x and len(str(x).strip()) > 0 for x in (e1, e2, e3))
    state_class = "step-active" if any_set else "step-inactive"
    return gr.update(elem_classes=["step-frame", state_class])
def check_can_generate(temp, perc1, temp_min, temp_max, ternary_temp, task_name, e1, e2, e3, *phases):
    """Decide whether the current inputs are complete for the selected task.

    Returns two updates: one setting ``interactive`` and one setting the
    step frame's active/inactive classes. Nothing is ready when no element
    is set or the task name is not recognised.
    """
    count = sum(1 for x in (e1, e2, e3) if x and len(str(x).strip()) > 0)
    if count == 0:
        return (gr.update(interactive=False),
                gr.update(elem_classes=["step-frame", "step-inactive"]))

    if task_name == "Phase name prediction":
        ready = temp is not None and perc1 is not None
    elif task_name == "Experimental condition prediction":
        ready = any(p and str(p).strip() for p in phases)
    elif task_name == "Phase diagram prediction":
        if count == 3:
            ready = ternary_temp is not None
        else:
            ready = temp_min is not None and temp_max is not None
    else:
        ready = False

    frame_classes = ["step-frame", "step-active"] if ready else ["step-frame", "step-inactive"]
    return (gr.update(interactive=ready), gr.update(elem_classes=frame_classes))
def select_task_phase():
    """Select the "Phase name prediction" task.

    Returns the task name, three visibility updates (first visible), three
    button-variant updates (first primary) and two no-op updates.
    """
    visibility = (gr.update(visible=True), gr.update(visible=False), gr.update(visible=False))
    variants = (gr.update(variant="primary"), gr.update(variant="secondary"), gr.update(variant="secondary"))
    return ("Phase name prediction", *visibility, *variants, gr.update(), gr.update())
def select_task_cond():
    """Select the "Experimental condition prediction" task.

    Returns the task name, three visibility updates (second visible), three
    button-variant updates (second primary) and two no-op updates.
    """
    visibility = (gr.update(visible=False), gr.update(visible=True), gr.update(visible=False))
    variants = (gr.update(variant="secondary"), gr.update(variant="primary"), gr.update(variant="secondary"))
    return ("Experimental condition prediction", *visibility, *variants, gr.update(), gr.update())
def select_task_diag(e1, e2, e3):
    """Select the "Phase diagram prediction" task.

    Returns the task name, three visibility updates (third visible), three
    button-variant updates (third primary) and two visibility updates: the
    first shown unless all three elements are set, the second shown only
    when all three are set.
    """
    active_count = sum(1 for x in (e1, e2, e3) if x and len(str(x).strip()) > 0)
    ternary_vis = active_count == 3
    binary_vis = not ternary_vis
    visibility = (gr.update(visible=False), gr.update(visible=False), gr.update(visible=True))
    variants = (gr.update(variant="secondary"), gr.update(variant="secondary"), gr.update(variant="primary"))
    return ("Phase diagram prediction", *visibility, *variants,
            gr.update(visible=binary_vis), gr.update(visible=ternary_vis))
# =====================================================================
# MCP API FUNCTIONS — clean endpoints for MCP tool exposure
# =====================================================================
def predict_phases_api(element1: str, element2: str = "", element3: str = "",
                       percent1: float = 100, percent2: float = 0, percent3: float = 0,
                       temperature_k: float = 800) -> str:
    """Predict which phases form for an alloy at a given composition and temperature.
    Provide 1-3 elements with their atomic percentages (must sum to 100) and a
    temperature in Kelvin. Returns the predicted equilibrium phase name(s).
    To build a phase diagram, call this tool many times (e.g. 50+ calls) across a
    grid of compositions and temperatures, then pass all results to the
    plot_phase_diagram tool to render the diagram.
    For a binary system (2 elements), vary element2 percent from 0 to 100 in steps
    of ~10-20, and temperature from min to max in steps of ~50K.
    For a ternary system (3 elements), fix temperature and vary compositions across
    the triangle in steps of ~20%.
    Example: element1="Gold", percent1=50, element2="Silver", percent2=50, temperature_k=800
    """
    # element1 is always included; the optional elements only when non-blank.
    components = [(element1, percent1)]
    for name, pct in ((element2, percent2), (element3, percent3)):
        if name and str(name).strip():
            components.append((name, pct))

    total = sum(pct for _, pct in components)
    if abs(total - 100) > 0.5:
        return f"Error: Percentages sum to {total}, expected 100."

    mixture = " + ".join(f"{name} ({pct}%)" for name, pct in components)
    question = f"What phases form when {mixture} are mixed at {int(temperature_k)} K?"
    sys_type = {3: "ternary", 2: "binary"}.get(len(components), "elemental")
    return _run_single_inference(question, sys_type)
def predict_conditions_api(element1: str, element2: str = "", element3: str = "",
                           target_phases: str = "") -> str:
    """Predict experimental conditions for given elements to form target phases.
    Provide 1-3 elements and one or more target phase names (comma- or +-separated).
    Returns predicted conditions (composition, temperature, etc.).
    Example: element1="Gold", element2="Silver", target_phases="FCC_A1"
    """
    if not target_phases or not str(target_phases).strip():
        return "Error: No target phases provided."

    # element1 is always included; the optional elements only when non-blank.
    names = [element1] + [
        extra for extra in (element2, element3) if extra and str(extra).strip()
    ]
    wanted = [piece.strip() for piece in re.split(r'[,+]', target_phases) if piece.strip()]

    joined_names = " + ".join(names)
    joined_phases = " + ".join(wanted)
    question = f"Under what condition do {joined_names} form {joined_phases}?"
    sys_type = {3: "ternary", 2: "binary"}.get(len(names), "elemental")
    return _run_single_inference(question, sys_type)
def plot_phase_diagram_api(data_points_json: str) -> str:
    """Generate a phase diagram image from pre-collected data points.
    This tool renders a phase diagram from data you have already collected by
    calling predict_phases multiple times. It does NOT run any model inference
    itself — it only plots.
    WORKFLOW to build a phase diagram:
    1. Call predict_phases ~50+ times across a grid of compositions/temperatures.
    2. Collect all results into the JSON format below.
    3. Call this tool ONCE with that JSON to render the diagram.
    Input: a JSON string with this structure:
    {
      "elements": ["Gold", "Silver"],
      "temperature_k": 800, // (optional, used as title for ternary)
      "points": [
        {
          "composition": {"Gold": 80, "Silver": 20},
          "temperature_k": 800,
          "phases": "FCC_A1"
        },
        {
          "composition": {"Gold": 50, "Silver": 50},
          "temperature_k": 1000,
          "phases": "FCC_A1 + LIQUID"
        }
      ]
    }
    - "elements": list of 2 or 3 element names (must match keys in composition dicts)
    - "points": list of data points, each with composition (dict of element->percent),
      temperature_k, and phases (string, use " + " to separate multi-phase regions)
    - For ternary systems, include "temperature_k" at top level for the diagram title
    Returns: a JSON string with the file paths of the generated PNG and SVG images,
    or an error message.
    """
    try:
        data = json.loads(data_points_json)
    except (json.JSONDecodeError, TypeError) as e:
        return json.dumps({"error": f"Invalid JSON: {e}"})
    # Valid JSON may still be a list or scalar, which has no "elements"/"points".
    if not isinstance(data, dict):
        return json.dumps({"error": "Invalid JSON: expected an object at the top level."})

    elements = data.get("elements", [])
    points = data.get("points", [])
    if not isinstance(elements, list) or len(elements) < 2:
        return json.dumps({"error": "Provide at least 2 element names in 'elements' list."})
    if not all(isinstance(el, str) for el in elements):
        return json.dumps({"error": "Element names in 'elements' must be strings."})
    if not isinstance(points, list) or not points:
        return json.dumps({"error": "No data points provided in 'points' list."})

    # Convert to internal entry format, skipping malformed points instead of
    # letting one bad record abort the whole request.
    entries = []
    for pt in points:
        if not isinstance(pt, dict):
            continue
        comp = pt.get("composition", {})
        if not isinstance(comp, dict):
            continue
        phase_str = pt.get("phases", "")
        if not phase_str or not str(phase_str).strip():
            continue
        try:
            temp = float(pt.get("temperature_k", 800))
            # Build percentages dict — ensure all elements are present
            percentages = {el: float(comp.get(el, 0)) for el in elements}
        except (TypeError, ValueError):
            continue

        # Parse phases from the string
        phases = extract_phases(str(phase_str))
        if not phases:
            phases = [normalize_phase_name(str(phase_str))]
        phases = [normalize_phase_name(p) for p in phases]

        entries.append({
            'elements': [el for el in elements if percentages.get(el, 0) > 0] or [elements[0]],
            'percentages': percentages,
            'temperature': temp,
            'phases': phases,
        })
    if not entries:
        return json.dumps({"error": "No valid data points after parsing."})

    # Determine diagram type
    is_ternary = len(elements) == 3
    elements_str = "-".join(sorted(elements))
    # Element names come from the caller; keep only file-name-safe characters
    # so the output cannot be written outside tmp/.
    safe_name = re.sub(r'[^\w\-]+', '_', elements_str)
    os.makedirs("tmp", exist_ok=True)

    try:
        if is_ternary:
            try:
                temp_label = int(data.get("temperature_k", entries[0]['temperature']))
            except (TypeError, ValueError):
                temp_label = int(entries[0]['temperature'])
            diagram_base = f"tmp/{safe_name}_{temp_label}K"
            png_path, svg_path = create_ternary_phase_diagram(entries, diagram_base, temp_label)
        else:
            temps = [e['temperature'] for e in entries]
            t_min, t_max = int(min(temps)), int(max(temps))
            diagram_base = f"tmp/{safe_name}_{t_min}-{t_max}K"
            png_path, svg_path = create_binary_phase_diagram(entries, diagram_base)
    except Exception as e:
        # This function is an API boundary that always answers with JSON.
        return json.dumps({"error": f"Diagram generation failed: {str(e)[:100]}"})

    if png_path and os.path.exists(png_path):
        return json.dumps({
            "status": "success",
            "system": elements_str,
            "type": "ternary" if is_ternary else "binary",
            "num_points": len(entries),
            "png_path": png_path,
            "svg_path": svg_path,
        })
    return json.dumps({"error": "Diagram generation failed."})
# --- UI CONSTRUCTION ---
# Builds the whole Gradio app in one Blocks context: the title area, the
# three-step layout (elements -> task details -> chat), the UI event wiring,
# and the hidden components that back the named API endpoints.
with gr.Blocks() as demo:
    # State
    svg_state = gr.State(None)    # written by run_chat, read by update_svg_download
    is_running = gr.State(False)  # passed into and returned from run_chat

    with gr.Column(elem_classes=["title-area"]):
        gr.Markdown("# aLLoyM")
        gr.Markdown("Interactive Alloy Design & Phase Diagram Assistant", elem_classes=["subtitle"])
        gr.Markdown(
            "Ask questions to aLLoyM, our Mistral-based model fine-tuned on Computational Phase Diagram Database (CPDDB) "
            "and assessments based on CALPHAD (CALculation of PHAse Diagrams). "
            "Check out our [**paper**](https://www.nature.com/articles/s41524-026-01966-6) for more information. "
            "You can also download the [**weights**](https://huggingface.co/Playingyoyo/aLLoyM) and fine-tune, "
            "or download the [**data**](https://huggingface.co/datasets/Playingyoyo/aLLoyM-dataset). "
            "Like all AI models, aLLoyM may hallucinate outside its training scope—but this also means you can explore alloy compositions that have never been experimentally tested.",
            elem_classes=["description"]
        )

    with gr.Row(elem_classes=["main-row"]):
        # ---- Left side: inputs -------------------------------------------
        with gr.Column(scale=4):
            with gr.Column(elem_classes=["step-frame", "step-active"]) as step1_frame:
                gr.Markdown("STEP 01 // CHOOSE ELEMENTS", elem_classes=["step-label"])
                gr.Markdown("*Type, or choose from below for more reliable answers*", elem_classes=["step-hint"])
                with gr.Column():
                    # Elements 2 and 3 start locked; the elem1/elem2 change
                    # handlers registered below update them.
                    elem1 = gr.Dropdown(label="ELEMENT 1", choices=ELEMENT_CHOICES, value=None, allow_custom_value=True)
                    elem2 = gr.Dropdown(label="ELEMENT 2 (optional)", choices=ELEMENT_CHOICES, value=None, allow_custom_value=True, interactive=False)
                    elem3 = gr.Dropdown(label="ELEMENT 3 (optional)", choices=ELEMENT_CHOICES, value=None, allow_custom_value=True, interactive=False)

            with gr.Column(elem_classes=["step-frame", "step-inactive"]) as step2_frame:
                gr.Markdown("STEP 02 // CHOOSE TASK AND INSERT DETAILS", elem_classes=["step-label"])
                current_task = gr.State("Phase name prediction")
                with gr.Column():
                    btn_phase = gr.Button("PHASE PREDICTION", elem_classes=["task-card"], variant="primary")
                    btn_cond = gr.Button("EXP. CONDITIONS", elem_classes=["task-card"], variant="secondary")
                    btn_diag = gr.Button("PHASE DIAGRAM", elem_classes=["task-card"], variant="secondary")

                # Task 1 inputs: temperature plus per-element percentages.
                with gr.Group(visible=True) as group_phase:
                    with gr.Column():
                        temp_input = gr.Number(label="TEMPERATURE (K)", value=800, step=50)
                        perc1 = gr.Number(label="ELEMENT 1 %", value=100)
                        perc2 = gr.Number(label="ELEMENT 2 % (AUTO)", value=0, interactive=False, visible=False)
                        perc3 = gr.Number(label="ELEMENT 3 % (AUTO)", value=0, interactive=False, visible=False)

                # Task 2 inputs: MAX_PHASES dropdowns are created up front and
                # only the first is visible; add_phase_btn is wired to
                # add_next_phase below.
                phase_inputs = []
                phase_count = gr.State(1)
                with gr.Group(visible=False) as group_cond:
                    gr.Markdown("*Type, or choose from below for more reliable answers*", elem_classes=["step-hint"])
                    for i in range(MAX_PHASES):
                        phase_inputs.append(
                            gr.Dropdown(choices=PHASE_CHOICES, value=None, allow_custom_value=True, show_label=False, visible=(i == 0), container=False)
                        )
                    add_phase_btn = gr.Button("+", variant="secondary")

                # Task 3 inputs: a temperature range or a single temperature;
                # which column is shown is decided by update_diag_temp_visibility.
                with gr.Group(visible=False) as group_diag:
                    with gr.Column(visible=True) as binary_temp_group:
                        temp_min = gr.Number(label="MIN TEMPERATURE (K)", value=300, step=50)
                        temp_max = gr.Number(label="MAX TEMPERATURE (K)", value=1000, step=50)
                    with gr.Column(visible=False) as ternary_temp_group:
                        ternary_temp = gr.Number(label="TEMPERATURE (K)", value=800, step=50)

        # ---- Right side: chat --------------------------------------------
        with gr.Column(scale=6):
            with gr.Column(elem_classes=["step-frame", "step-inactive"], elem_id="right-panel") as chat_frame:
                gr.Markdown("STEP 03 // CHAT", elem_classes=["step-label"])
                chatbot = gr.Chatbot(label="", show_label=False, height=500, layout="panel", value=EXAMPLE_CHAT)
                with gr.Row():
                    download_svg_btn = gr.DownloadButton("📥 DOWNLOAD SVG", visible=True, scale=1)
                    download_chat_btn = gr.DownloadButton("📥 DOWNLOAD CHAT (JSONL)", visible=True, scale=1)
                with gr.Row(elem_classes=["input-area"]):
                    prompt_preview = gr.Textbox(
                        label="USER INPUT (AUTO-GENERATED)",
                        value="Select a task...",
                        lines=2,
                        interactive=False,
                        scale=8
                    )
                    send_btn = gr.Button("▶", scale=1, interactive=False, elem_id="send-btn")

    # --- Events (all with api_name=False to hide from MCP) ---
    for elem in [elem1, elem2, elem3]:
        elem.change(fn=activate_step2, inputs=[elem1, elem2, elem3], outputs=step2_frame, show_progress="hidden", api_name=False)

    # Enable elem2 when elem1 is filled, enable elem3 when elem2 is filled
    elem1.change(fn=update_elem2_enabled, inputs=elem1, outputs=elem2, show_progress="hidden", api_name=False)
    elem2.change(fn=update_elem3_enabled, inputs=elem2, outputs=elem3, show_progress="hidden", api_name=False)
    elem1.change(update_label, elem1, perc1, show_progress="hidden", api_name=False)

    # Update percentage visibility when elem2 or elem3 changes
    for elem in [elem2, elem3]:
        elem.change(fn=update_perc_visibility, inputs=[elem2, elem3], outputs=[perc1, perc2, perc3], api_name=False)

    for elem in [elem1, elem2, elem3]:
        elem.change(fn=update_diag_temp_visibility,
                    inputs=[elem1, elem2, elem3],
                    outputs=[binary_temp_group, ternary_temp_group],
                    show_progress="hidden", api_name=False)

    # Editing either editable percentage re-runs the balance calculation.
    for perc in [perc1, perc2]:
        perc.change(fn=auto_calculate_balance, inputs=[perc1, perc2, elem2, elem3], outputs=[perc2, perc3], api_name=False)

    # All three task buttons update the same set of components.
    task_outputs = [current_task, group_phase, group_cond, group_diag, btn_phase, btn_cond, btn_diag, binary_temp_group, ternary_temp_group]
    btn_phase.click(select_task_phase, None, task_outputs, api_name=False)
    btn_cond.click(select_task_cond, None, task_outputs, api_name=False)
    btn_diag.click(select_task_diag, [elem1, elem2, elem3], task_outputs, api_name=False)

    add_phase_btn.click(fn=add_next_phase, inputs=phase_count, outputs=[phase_count] + phase_inputs, api_name=False)

    # Any change to a prompt-relevant input regenerates the preview text and
    # re-evaluates whether the send button / chat frame should be enabled.
    all_prompt_inputs = [current_task, elem1, elem2, elem3, perc1, perc2, perc3, temp_input, temp_min, temp_max, ternary_temp] + phase_inputs
    check_inputs = [temp_input, perc1, temp_min, temp_max, ternary_temp, current_task, elem1, elem2, elem3] + phase_inputs
    for inp in all_prompt_inputs:
        inp.change(fn=generate_prompt_text, inputs=all_prompt_inputs, outputs=prompt_preview, api_name=False)
        inp.change(fn=check_can_generate,
                   inputs=check_inputs,
                   outputs=[send_btn, chat_frame],
                   show_progress="hidden", api_name=False)

    # Main send button (UI only, not exposed as MCP)
    send_btn.click(
        fn=run_chat,
        inputs=[prompt_preview, elem1, elem2, elem3, perc1, perc2, perc3, chatbot, current_task, temp_min, temp_max, ternary_temp, svg_state, is_running],
        outputs=[chatbot, svg_state, is_running, send_btn],
        show_progress="hidden",
        api_name=False
    )

    def _download_button_for(path):
        """Return a visible DownloadButton whose value is `path` if that file exists, else None."""
        target = path if path and os.path.exists(path) else None
        return gr.DownloadButton(value=target, visible=True)

    # Update download button when SVG is available
    def update_svg_download(svg_path):
        """Point the SVG download button at `svg_path` when that file exists.

        Args:
            svg_path: Value of `svg_state`; may be None or a path that is
                not on disk, in which case the button gets no file.

        Returns:
            A visible `gr.DownloadButton` update.
        """
        return _download_button_for(svg_path)

    svg_state.change(fn=update_svg_download, inputs=svg_state, outputs=download_svg_btn, api_name=False)

    # Update chat download button when chat changes
    def update_chat_download(history):
        """Export the chat history and point the chat download button at the result.

        Args:
            history: Current value of the chatbot component.

        Returns:
            A visible `gr.DownloadButton` update; it carries the path returned
            by `export_chat_history` only if that path exists on disk.
        """
        return _download_button_for(export_chat_history(history))

    chatbot.change(fn=update_chat_download, inputs=chatbot, outputs=download_chat_btn, api_name=False)

    # =================================================================
    # MCP API ENDPOINTS — these are the only tools exposed via MCP
    # =================================================================
    # Hidden components for API inputs/outputs
    with gr.Row(visible=False):
        api_out = gr.Textbox()
        # --- Task 1: Phase Prediction ---
        api_e1 = gr.Textbox()
        api_e2 = gr.Textbox()
        api_e3 = gr.Textbox()
        api_p1 = gr.Number()
        api_p2 = gr.Number()
        api_p3 = gr.Number()
        api_temp = gr.Number()
        api_btn_phase = gr.Button()
        # --- Task 2: Condition Prediction ---
        api_cond_e1 = gr.Textbox()
        api_cond_e2 = gr.Textbox()
        api_cond_e3 = gr.Textbox()
        api_cond_phases = gr.Textbox()
        api_btn_cond = gr.Button()
        # --- Task 3: Plot Phase Diagram (from pre-collected data) ---
        api_plot_data = gr.Textbox()
        api_btn_plot = gr.Button()
        api_out2 = gr.Textbox()
        api_out3 = gr.Textbox()

    # These are the only listeners registered with an api_name.
    api_btn_phase.click(
        fn=predict_phases_api,
        inputs=[api_e1, api_e2, api_e3, api_p1, api_p2, api_p3, api_temp],
        outputs=api_out,
        api_name="predict_phases"
    )
    api_btn_cond.click(
        fn=predict_conditions_api,
        inputs=[api_cond_e1, api_cond_e2, api_cond_e3, api_cond_phases],
        outputs=api_out2,
        api_name="predict_conditions"
    )
    api_btn_plot.click(
        fn=plot_phase_diagram_api,
        inputs=[api_plot_data],
        outputs=api_out3,
        api_name="plot_phase_diagram"
    )
if __name__ == "__main__":
    # Script entry point: start the Gradio app with the project's theme and
    # stylesheet, and with Gradio's MCP server option switched on.
    launch_options = {
        "theme": theme.get_theme(),
        "css": css.get_css(),
        "mcp_server": True,
    }
    demo.launch(**launch_options)