Spaces:
Running on Zero
Running on Zero
| import gradio as gr | |
| from mcp.server.fastmcp import FastMCP | |
| import theme | |
| import css | |
| import model | |
| import config | |
| import os | |
| import json | |
| import re | |
| import time | |
| import numpy as np | |
| import matplotlib | |
| matplotlib.use('Agg') | |
| import matplotlib.pyplot as plt | |
| import matplotlib.cm as cm | |
| from collections import defaultdict | |
# Load the model once, at import time, before any UI callback runs.
model.load_model()

# --- CONSTANTS ---
# Number of phase input slots managed by add_next_phase.
MAX_PHASES = 10
# Distinct element names (the values of config.element_dict), alphabetised.
ELEMENT_CHOICES = sorted(set(config.element_dict.values()))
PHASE_CHOICES = sorted(config.phase_list)

# Placeholder conversation shown before the first real question; run_chat and
# export_chat_history compare against it so it never leaks into real history.
EXAMPLE_CHAT = [
    {"role": "user", "content": "What phases form when Gold (50%) + Silver (50%) are mixed at 800 K?"},
    {"role": "assistant", "content": "FCC_A1"},
]
| # --- Animated dots generator --- | |
def get_dots(step):
    """Return a loading indicator of 1, 2 or 3 dots that cycles with *step*."""
    return "." * (step % 3 + 1)
| # --- PHASE NAME UTILITIES --- | |
def sort_phase_name(phase_name):
    """Return *phase_name* with its " + "-separated components sorted.

    Components are not stripped. Falsy input (None or "") is returned unchanged.
    """
    if phase_name:
        return " + ".join(sorted(phase_name.split(" + ")))
    return phase_name
def normalize_phase_name(phase_text):
    """Return a canonical, sorted " + "-joined form of a phase description.

    Components may be separated by any mix of " and ", " + ", ", ", " & " or
    " with ". Each component is stripped and empty pieces are dropped. With
    two or more components the result is the sorted components joined by
    " + ". Otherwise the text falls back to sort_phase_name. Falsy input is
    returned unchanged.
    """
    if not phase_text:
        return phase_text
    # Split on every separator at once. The previous version used only the
    # first separator it found, so "A and C + B" became "A + C + B" (unsorted).
    parts = [piece.strip() for piece in re.split(r" and | \+ |, | & | with ", phase_text)]
    parts = [part for part in parts if part]
    if len(parts) > 1:
        return " + ".join(sorted(parts))
    return sort_phase_name(phase_text)
def clean_phase_name(phase_name):
    """Remove punctuation from a phase name and normalise its whitespace.

    Word characters, whitespace, '-', '+', '(' and ')' are kept. Runs of
    whitespace collapse to one space and the ends are trimmed. Falsy input is
    returned unchanged.
    """
    if phase_name:
        without_punct = re.sub(r'[^\w\s\-+()]', '', phase_name)
        return re.sub(r'\s+', ' ', without_punct).strip()
    return phase_name
def extract_phases(answer):
    """Parse a model answer into a list holding at most one canonical phase string.

    Returns [] for empty answers and for error answers, i.e. anything whose
    stripped text starts with "Error" (both "Error: ..." and a bare "Error")
    or with the "β" marker. Otherwise one trailing period is removed and the
    answer is split on " and ", " + ", ", ", " & ", " with " and newlines.
    Each piece is cleaned with clean_phase_name. Several pieces are returned
    as a single sorted " + "-joined string.
    """
    answer = (answer or "").strip()
    # Bare "Error" must be caught too; otherwise it was plotted as a phase.
    # NOTE(review): "β" looks like a mis-encoded marker character (mojibake);
    # confirm the intended literal before relying on this check.
    if not answer or answer.startswith(("Error", "β")):
        return []
    if answer.endswith('.'):
        answer = answer[:-1]
    pieces = [answer]
    for sep in (' and ', ' + ', ', ', ' & ', ' with ', '\n'):
        pieces = [sub for piece in pieces for sub in piece.split(sep)]
    phases = [clean_phase_name(piece) for piece in pieces if piece.strip()]
    phases = [phase for phase in phases if phase]
    if len(phases) > 1:
        return [" + ".join(sorted(phases))]
    return phases
| # --- DATA LOADING --- | |
def parse_question(question):
    """Extract the composition and temperature from a phase question.

    Elements are words written as "Name (<number>%)". The temperature is the
    first number (integer or decimal) followed by "K".

    Returns:
        {'elements': [names in order], 'percentages': {name: float},
        'temperature': float}, or None if no element or no temperature is found.
    """
    elements = re.findall(r'([A-Za-z]+)\s*\(\s*(\d+(?:\.\d+)?)\s*%\s*\)', question)
    if not elements:
        return None
    # Allow a fractional part: the old integer-only pattern matched just the
    # trailing digits of "800.5 K" and returned 5 K.
    temp_match = re.search(r'(\d+(?:\.\d+)?)\s*K', question)
    if not temp_match:
        return None
    return {
        'elements': [name for name, _ in elements],
        'percentages': {name: float(pct) for name, pct in elements},
        'temperature': float(temp_match.group(1)),
    }
def load_jsonl_data(filepath):
    """Load batch answers from a JSONL file into plottable entries.

    Each line should be a JSON object with a "user" question and an "answer"
    string, which is the format run_chat writes to its answers file. The
    following are skipped:
    - blank lines and malformed JSON;
    - records that are not objects, or whose question is not a string;
    - questions that parse_question cannot parse.

    Answers with no extractable phase are recorded as "Unknown".

    Returns:
        A list of dicts with keys 'elements', 'percentages', 'temperature'
        and 'phases'; [] if *filepath* does not exist.
    """
    if not os.path.exists(filepath):
        return []
    entries = []
    with open(filepath, "r", encoding="utf-8") as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line:
                continue
            try:
                data = json.loads(line)
            except json.JSONDecodeError:
                continue
            # A valid JSON line may still be a list/string/number; .get would crash.
            if not isinstance(data, dict):
                continue
            question = data.get("user", "")
            if not isinstance(question, str):
                continue
            answer = data.get("answer", "")
            if not isinstance(answer, str):
                answer = ""
            parsed = parse_question(question)
            if not parsed:
                continue
            phases = extract_phases(answer) or ["Unknown"]
            entries.append({
                'elements': parsed['elements'],
                'percentages': parsed['percentages'],
                'temperature': parsed['temperature'],
                'phases': [normalize_phase_name(p) for p in phases],
            })
    return entries
| # --- BINARY PHASE DIAGRAM --- | |
def create_binary_phase_diagram(entries, output_base_path):
    """Render a temperature-vs-composition scatter diagram and save it as PNG + SVG.

    Args:
        entries: dicts as produced by load_jsonl_data (keys 'elements',
            'percentages', 'temperature', 'phases').
        output_base_path: output path. Any extension is replaced by ".png"
            and ".svg", and a missing parent directory is created.

    Returns:
        (png_path, svg_path), or (None, None) when *entries* is empty.

    The x value is 100 for a single-element system. Otherwise it is the
    percentage of the alphabetically last element, which is the second
    element for a binary system.
    """
    if not entries:
        return None, None

    elements = sorted({el for entry in entries for el in entry['elements']})
    if len(elements) == 1:
        x_label = f"Composition (% {elements[0]})"
    elif len(elements) == 2:
        x_label = f"Composition (% {elements[1]})"
    else:
        x_label = "Composition (%)"

    all_phases = sorted({phase for entry in entries for phase in entry['phases']})
    n_phases = len(all_phases)
    # tab10 has only 10 colours. Cycling it gave phase 11+ the same colour as
    # an earlier phase, so switch to larger palettes when needed.
    if n_phases <= 10:
        colors = {phase: plt.cm.tab10(i) for i, phase in enumerate(all_phases)}
    elif n_phases <= 20:
        colors = {phase: plt.cm.tab20(i) for i, phase in enumerate(all_phases)}
    else:
        colors = {phase: plt.cm.rainbow(i / (n_phases - 1)) for i, phase in enumerate(all_phases)}

    phase_points = defaultdict(list)
    for entry in entries:
        if len(elements) == 1:
            x = 100.0
        else:
            # For two elements elements[-1] is elements[1].
            x = entry['percentages'].get(elements[-1], 0.0)
        for phase in entry['phases']:
            phase_points[phase].append((x, entry['temperature']))

    fig, ax = plt.subplots(figsize=(12, 10))
    try:
        handles, labels = [], []
        for phase in sorted(phase_points):
            x_vals, y_vals = zip(*phase_points[phase])
            handles.append(ax.scatter(x_vals, y_vals,
                                      c=[colors[phase]],
                                      s=250,
                                      alpha=0.8,
                                      edgecolors='white',
                                      linewidths=2))
            labels.append(phase)

        if len(elements) == 2:
            title = f"{elements[0]}-{elements[1]} Phase Diagram"
        elif len(elements) == 1:
            title = f"{elements[0]} Phase Diagram"
        else:
            title = "Phase Diagram"
        ax.set_title(title, fontsize=36, fontweight='bold', pad=20)
        ax.set_xlabel(x_label, fontsize=28)
        ax.set_ylabel("Temperature (K)", fontsize=28)
        all_temps = [entry['temperature'] for entry in entries]
        ax.set_xlim(-5, 105)
        ax.set_ylim(min(all_temps) - 50, max(all_temps) + 50)
        ax.tick_params(axis='both', labelsize=24)
        if handles:
            ax.legend(handles, labels,
                      title="Phases",
                      title_fontsize=24,
                      fontsize=20,
                      loc='upper center',
                      bbox_to_anchor=(0.5, -0.15),
                      ncol=min(4, len(handles)),
                      frameon=True)
        fig.tight_layout()

        # Only create the directory the files are actually written to.
        output_dir = os.path.dirname(output_base_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
        base_name = os.path.splitext(output_base_path)[0]
        png_path = f"{base_name}.png"
        svg_path = f"{base_name}.svg"
        fig.savefig(png_path, dpi=150, bbox_inches='tight', facecolor='white')
        fig.savefig(svg_path, format='svg', bbox_inches='tight')
    finally:
        # Always release the figure, even if drawing or saving fails.
        plt.close(fig)
    return png_path, svg_path
| # --- TERNARY PHASE DIAGRAM --- | |
def get_ternary_coords(percentages, elements):
    """Map a three-element composition to Cartesian (x, y) inside the unit ternary triangle.

    elements[0] sits at the apex (0.5, sqrt(3)/2), elements[1] at (0, 0) and
    elements[2] at (1, 0). Missing elements count as 0%. Percentages are not
    renormalised, so they should sum to 100.
    """
    apex_pct = percentages.get(elements[0], 0.0)
    right_pct = percentages.get(elements[2], 0.0)
    x = (apex_pct / 2 + right_pct) / 100.0
    y = np.sqrt(3) * apex_pct / 2 / 100.0
    return x, y
def setup_ternary_axes(ax, elements):
    """Draw an empty ternary triangle with a 10% grid and corner labels on *ax*.

    The triangle has vertices (0, 0), (1, 0) and (0.5, sqrt(3)/2).
    elements[0] labels the apex, elements[1] the bottom-left corner and
    elements[2] the bottom-right corner. The axis frame is hidden and the
    aspect ratio is fixed.
    """
    h = np.sqrt(3.0) * 0.5
    ax.set_aspect('equal', 'datalim')
    ax.axis('off')

    # Outline: base, left edge, right edge.
    outline = (
        ([0.0, 1.0], [0.0, 0.0]),
        ([0.0, 0.5], [0.0, h]),
        ([1.0, 0.5], [0.0, h]),
    )
    for xs, ys in outline:
        ax.plot(xs, ys, 'k-', lw=2)

    # Faint grid lines parallel to each edge, every 10%.
    grid_style = dict(color='lightgray', lw=0.5, zorder=0, alpha=0.3)
    for step in range(1, 10):
        level = h * step / 10.0
        ax.plot([step / 20.0, 1.0 - step / 20.0], [level, level], **grid_style)
        ax.plot([step / 20.0, step / 10.0], [level, 0.0], **grid_style)
        ax.plot([0.5 + step / 20.0, step / 10.0], [h * (1.0 - step / 10.0), 0.0], **grid_style)

    # (x, y, index into elements, horizontal alignment, vertical alignment)
    corner_labels = (
        (0.5, h + 0.05, 0, 'center', 'bottom'),
        (-0.05, -0.05, 1, 'right', 'top'),
        (1.05, -0.05, 2, 'left', 'top'),
    )
    for x, y, index, ha, va in corner_labels:
        ax.text(x, y, str(elements[index]), fontsize=48, ha=ha, va=va, fontweight='bold')
def create_ternary_phase_diagram(entries, output_base_path, temperature):
    """Render an isothermal ternary scatter diagram and save it as PNG + SVG.

    Args:
        entries: dicts as produced by load_jsonl_data. Together they must
            mention exactly three elements.
        output_base_path: output path. Any extension is replaced by ".png"
            and ".svg", and a missing parent directory is created.
        temperature: temperature (K) shown in the title. If None, the
            entries' temperature is used when they all share one value;
            otherwise the title omits the temperature.

    Returns:
        (png_path, svg_path), or (None, None) when there are no entries or
        they do not span exactly three elements.
    """
    if not entries:
        return None, None
    elements = sorted({el for entry in entries for el in entry['elements']})
    if len(elements) != 3:
        return None, None

    all_phases = sorted({phase for entry in entries for phase in entry['phases']})
    if len(all_phases) <= 1:
        phase_to_color = {all_phases[0]: cm.rainbow(0.5)} if all_phases else {}
    else:
        last_index = len(all_phases) - 1
        phase_to_color = {
            phase: cm.rainbow(float(i) / last_index)
            for i, phase in enumerate(all_phases)
        }

    phase_points = defaultdict(list)
    for entry in entries:
        point = get_ternary_coords(entry['percentages'], elements)
        for phase in entry['phases']:
            phase_points[phase].append(point)

    # temperature defaults to None in generate_phase_diagram_from_jsonl, and
    # int(None) used to raise here. Derive it from the data instead.
    if temperature is None:
        distinct_temps = {entry['temperature'] for entry in entries}
        if len(distinct_temps) == 1:
            temperature = next(iter(distinct_temps))
    title = f"{'-'.join(elements)} Phase Diagram"
    if temperature is not None:
        title += f" @ {int(temperature)}K"

    fig, ax = plt.subplots(figsize=(14, 14))
    try:
        setup_ternary_axes(ax, elements)
        ax.set_title(title, fontsize=40, fontweight='bold', pad=30)
        handles, labels = [], []
        for phase in sorted(phase_points):
            x_vals, y_vals = zip(*phase_points[phase])
            handles.append(ax.scatter(x_vals, y_vals,
                                      c=[phase_to_color[phase]],
                                      s=200,
                                      alpha=0.8,
                                      edgecolors='white',
                                      linewidths=1.5,
                                      zorder=10))
            labels.append(phase)
        if handles:
            ax.legend(handles, labels,
                      title="Phases",
                      title_fontsize=28,
                      fontsize=22,
                      loc='upper center',
                      bbox_to_anchor=(0.5, -0.05),
                      ncol=min(3, len(handles)),
                      frameon=True)
        fig.tight_layout()

        # Only create the directory the files are actually written to.
        output_dir = os.path.dirname(output_base_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
        base_name = os.path.splitext(output_base_path)[0]
        png_path = f"{base_name}.png"
        svg_path = f"{base_name}.svg"
        fig.savefig(png_path, dpi=150, bbox_inches='tight', facecolor='white')
        fig.savefig(svg_path, format='svg', bbox_inches='tight')
    finally:
        # Always release the figure, even if drawing or saving fails.
        plt.close(fig)
    return png_path, svg_path
def generate_phase_diagram_from_jsonl(jsonl_path, output_base_path, is_ternary=False, temperature=None):
    """Load batch answers from *jsonl_path* and render the matching diagram.

    Ternary data goes to create_ternary_phase_diagram (with *temperature*);
    everything else goes to create_binary_phase_diagram.

    Returns:
        (png_path, svg_path) from the plotter, or (None, None) when the file
        yields no entries.
    """
    entries = load_jsonl_data(jsonl_path)
    if not entries:
        return None, None
    if not is_ternary:
        return create_binary_phase_diagram(entries, output_base_path)
    return create_ternary_phase_diagram(entries, output_base_path, temperature)
| # --- TERNARY COMPOSITION GRID --- | |
def generate_ternary_compositions(step=20):
    """Enumerate integer percentage triples (c1, c2, c3) that sum to 100.

    c1 and c2 advance in increments of *step*, and c3 takes the remainder.
    Triples are ordered by c1, then c2.
    """
    return [
        (first, second, 100 - first - second)
        for first in range(0, 101, step)
        for second in range(0, 101 - first, step)
    ]
| # --- EXPORT CHAT HISTORY --- | |
def export_chat_history(history):
    """Export chat history to JSONL file.

    Each user message is paired with the next assistant message and written
    as {"user": ..., "assistant": ...}. The following are skipped:
    - messages equal to an EXAMPLE_CHAT entry;
    - non-dict messages;
    - image messages (dict content);
    - loading-dot placeholders.

    List content is flattened to its text parts joined by spaces. Assistant
    messages with no pending user message are dropped.

    Returns:
        "tmp/chat_history.jsonl", or None when there is nothing to export.
    """
    if not history:
        return None
    real_history = [message for message in history if message not in EXAMPLE_CHAT]
    if not real_history:
        return None

    os.makedirs("tmp", exist_ok=True)
    output_path = "tmp/chat_history.jsonl"
    with open(output_path, "w") as out:
        pending_question = None
        for message in real_history:
            if not isinstance(message, dict):
                continue
            role = message.get("role", "")
            content = message.get("content", "")
            if isinstance(content, dict):
                continue  # image message
            if isinstance(content, list):
                texts = []
                for part in content:
                    if isinstance(part, str):
                        texts.append(part)
                    elif isinstance(part, dict) and "text" in part:
                        texts.append(part["text"])
                content = " ".join(texts)
            if content in (".", "..", "..."):
                continue  # animated loading placeholder
            if role == "user":
                pending_question = content
            elif role == "assistant" and pending_question:
                out.write(json.dumps({"user": pending_question, "assistant": content}) + "\n")
                pending_question = None
    return output_path
| # --- INFERENCE HELPER (shared by UI and API) --- | |
def _run_single_inference(question, system_type="binary"):
    """Run a single inference and return the text answer.

    Sends *question* with an expert system prompt for *system_type* alloys.
    The reply is taken from the last element of the model response: a dict's
    'content', the second item of a pair (or the only item of a 1-tuple), or
    str() of anything else. Returns "No response received." for an empty
    response.
    """
    instruction = (
        "You are an expert in phase diagrams, thermodynamics, and materials science, "
        f"specializing in {system_type} alloy systems."
    )
    reply = model.run_inference(question, instruction, [])
    if not reply:
        return "No response received."
    final = reply[-1]
    if isinstance(final, dict):
        return final.get('content', str(final))
    if not isinstance(final, (list, tuple)):
        return str(final)
    return str(final[1] if len(final) > 1 else final[0])
def _run_batch_inference(questions, system_type="binary"):
    """Run multiple inferences back-to-back without delays.

    Each question is answered with _run_single_inference. A question whose
    inference raises yields "Error: <exception>" instead of aborting the batch.
    Returns a list of answer strings, one per question.
    """
    def answer_one(question):
        try:
            return _run_single_inference(question, system_type)
        except Exception as e:
            return f"Error: {e}"

    return [answer_one(question) for question in questions]
| # --- LOGIC FUNCTIONS --- | |
def generate_prompt_text(task_name, e1, e2, e3, p1, p2, p3, temp, temp_min, temp_max, ternary_temp, *phases):
    """Build the question text for the selected task from the UI inputs.

    Elements come from the pairs (e1, p1), (e2, p2), (e3, p3); blank names
    are ignored. A missing percentage is shown as 0 and a missing temperature
    as "___". The extra positional arguments are target phase names, used by
    "Experimental condition prediction". Returns a hint string when no
    element is set or the task is unknown.
    """
    target_phases = [p.strip() for p in phases if p and str(p).strip()]
    pairs = [
        (name, pct)
        for name, pct in ((e1, p1), (e2, p2), (e3, p3))
        if name and len(str(name).strip()) > 0
    ]
    if not pairs:
        return "Please define elements in Step 1."
    names = [name for name, _ in pairs]

    if task_name == "Phase name prediction":
        mixture_str = " + ".join(
            f"{name} ({pct if pct is not None else 0}%)" for name, pct in pairs
        )
        temp_val = "___" if temp is None else temp
        return f"What phases form when {mixture_str} are mixed at {temp_val} K?"

    if task_name == "Experimental condition prediction":
        ph_val = " + ".join(target_phases) if target_phases else "___"
        return f"Under what condition do {' + '.join(names)} form {ph_val}?"

    if task_name == "Phase diagram prediction":
        sys_str = "-".join(names)
        if len(pairs) == 3:
            t_val = "___" if ternary_temp is None else int(ternary_temp)
            return f"Draw the phase diagram for the {sys_str} system at {t_val} K."
        t_min = "___" if temp_min is None else int(temp_min)
        t_max = "___" if temp_max is None else int(temp_max)
        return f"Draw the phase diagram for the {sys_str} system in the temperature range {t_min}-{t_max} K."

    return "Please select a task in Step 2."
def _run_button_update(running):
    """Return the Gradio update for the run/stop button: stop style while running, play style otherwise."""
    # The scraped source showed these labels as mojibake ("β " / "βΆ");
    # those bytes decode to U+25A0 BLACK SQUARE and U+25B6 BLACK RIGHT-POINTING TRIANGLE.
    if running:
        return gr.update(value="■", elem_classes=["stop-btn"])
    return gr.update(value="▶", elem_classes=[])


def _response_text(response):
    """Return the content of the last message in a model.run_inference result, or None if it is empty."""
    if not response:
        return None
    last_msg = response[-1]
    if isinstance(last_msg, dict):
        return last_msg.get('content', str(last_msg))
    if isinstance(last_msg, (list, tuple)):
        return str(last_msg[1]) if len(last_msg) > 1 else str(last_msg[0])
    return str(last_msg)


def _expert_instruction(system_type):
    """Build the system prompt sent with every chat and batch inference."""
    return (
        f"You are an expert in phase diagrams, thermodynamics, and materials science, "
        f"specializing in {system_type} alloy systems."
    )


def _batch_mixture(active_elements, comp):
    """Format one grid composition (tuple of percentages) as the mixture text of a batch question."""
    count = len(active_elements)
    if count == 3:
        c1, c2, c3 = comp
        return f"{active_elements[0]} ({c1}%) + {active_elements[1]} ({c2}%) + {active_elements[2]} ({c3}%)"
    if count == 2:
        c1, c2 = comp
        if c2 == 0:
            return f"{active_elements[0]} (100%)"
        if c2 == 100:
            return f"{active_elements[1]} (100%)"
        return f"{active_elements[0]} ({c1}%) + {active_elements[1]} ({c2}%)"
    return f"{active_elements[0]} (100%)"


def _run_diagram_batch(active_elements, history, current_svg, temp_min, temp_max, ternary_temp):
    """Generator for "Phase diagram prediction": query a composition/temperature grid, then plot it.

    Temperatures:
    - binary/elemental: temp_min..temp_max in 50 K steps (defaults 300..1000 K);
    - ternary: a single temperature (default 800 K).

    Questions and answers are logged to tmp/q.jsonl and tmp/a.jsonl. The
    answers file is rendered with generate_phase_diagram_from_jsonl. The run
    stops after 5 consecutive failed points.

    Yields (history, svg_path, is_running, button_update) tuples.
    """
    count = len(active_elements)
    is_ternary = count == 3
    elements_str = "-".join(active_elements)
    os.makedirs("tmp", exist_ok=True)
    q_file_path = "tmp/q.jsonl"
    a_file_path = "tmp/a.jsonl"

    if is_ternary:
        t_single = int(ternary_temp) if ternary_temp else 800
        t_start = t_end = t_single
        diagram_base = f"tmp/{elements_str}_{t_single}K"
        compositions = generate_ternary_compositions(step=20)
    else:
        t_single = None
        t_start = int(temp_min) if temp_min else 300
        t_end = int(temp_max) if temp_max else 1000
        if t_start > t_end:
            t_start, t_end = t_end, t_start
        diagram_base = f"tmp/{elements_str}_{t_start}-{t_end}K"
        if count == 2:
            compositions = [(100 - i, i) for i in range(0, 101, 20)]
        else:
            compositions = [(100,)]
    system_type = "ternary" if is_ternary else ("binary" if count == 2 else "elemental")
    system_instruction = _expert_instruction(system_type)

    generated_count = 0
    error_count = 0  # consecutive failures; every successful answer resets it
    max_errors = 5
    dot_step = 0

    if is_ternary:
        banner = f"Starting {system_type} batch for **{elements_str}** at {t_single}K..."
    else:
        banner = f"Starting {system_type} batch for **{elements_str}** ({t_start}K - {t_end}K)..."
    history.append({"role": "assistant", "content": banner})
    yield history, current_svg, True, _run_button_update(True)

    with open(q_file_path, "w") as q_file, open(a_file_path, "w") as a_file:
        for t in range(t_start, t_end + 1, 50):
            for comp in compositions:
                try:
                    mix_str = _batch_mixture(active_elements, comp)
                    question = f"What phases form when {mix_str} are mixed at {t} K? Answer phase names only."
                    history.append({"role": "user", "content": question})
                    history.append({"role": "assistant", "content": get_dots(dot_step)})
                    dot_step += 1
                    yield history, current_svg, True, _run_button_update(True)
                    q_file.write(json.dumps({"user": question}) + "\n")
                    q_file.flush()
                    time.sleep(1.5)  # pause between grid points (pacing kept from the original)

                    answer = "Error: no response"
                    for attempt in range(3):
                        try:
                            history[-1] = {"role": "assistant", "content": get_dots(dot_step)}
                            dot_step += 1
                            yield history, current_svg, True, _run_button_update(True)
                            text = _response_text(model.run_inference(question, system_instruction, []))
                            if text is None:
                                # An empty response used to leave a bare "Error" answer that
                                # was never counted or retried; treat it as a failure.
                                raise RuntimeError("No response received.")
                            answer = text
                            error_count = 0
                            break
                        except Exception as e:
                            if attempt < 2:
                                history[-1] = {"role": "assistant", "content": f"Retry {attempt + 2}{get_dots(dot_step)}"}
                                dot_step += 1
                                yield history, current_svg, True, _run_button_update(True)
                                time.sleep(3 * (attempt + 1))  # linear back-off between retries
                            else:
                                answer = f"Error: {str(e)[:50]}"
                                error_count += 1

                    history[-1] = {"role": "assistant", "content": answer}
                    yield history, current_svg, True, _run_button_update(True)
                    a_file.write(json.dumps({"user": question, "answer": answer}) + "\n")
                    a_file.flush()
                    generated_count += 1
                except Exception as e:
                    error_count += 1
                    last = history[-1] if history else None
                    content = last.get("content", "") if isinstance(last, dict) else ""
                    # Replace a dangling loading placeholder with the error text.
                    if isinstance(content, str) and content.startswith("."):
                        history[-1] = {"role": "assistant", "content": f"Error: {str(e)[:50]}"}
                    yield history, current_svg, True, _run_button_update(True)
                    time.sleep(1)

                # One consistent abort path, checked right after every point.
                if error_count >= max_errors:
                    history.append({"role": "assistant", "content": f"Stopped (errors). Completed: {generated_count}"})
                    yield history, current_svg, False, _run_button_update(False)
                    return

    history.append({"role": "assistant", "content": f"Generating {system_type} diagram{get_dots(dot_step)}"})
    yield history, current_svg, True, _run_button_update(True)
    try:
        png_path, svg_path = generate_phase_diagram_from_jsonl(
            a_file_path, diagram_base,
            is_ternary=is_ternary,
            temperature=t_single,  # None for non-ternary systems
        )
    except Exception as e:
        print(f"Diagram error: {e}")
        history[-1] = {"role": "assistant", "content": f"Complete. Diagram error: {str(e)[:50]}"}
        yield history, current_svg, False, _run_button_update(False)
        return

    if png_path and os.path.exists(png_path):
        current_svg = svg_path
        history[-1] = {"role": "assistant", "content": (
            f"**Complete!** {generated_count} points generated.\n\n"
            f"System: {elements_str} ({system_type})"
        )}
        history.append({
            "role": "assistant",
            "content": {"path": png_path, "alt_text": f"{elements_str} Phase Diagram"}
        })
    else:
        history[-1] = {"role": "assistant", "content": f"Complete ({generated_count} points). Diagram failed."}
    yield history, current_svg, False, _run_button_update(False)


def _run_standard_chat(prompt_text, count, history, current_svg):
    """Generator: send *prompt_text* to the model once, animating a loading placeholder first."""
    system_type = "ternary" if count == 3 else "binary" if count == 2 else "elemental"
    system_instruction = _expert_instruction(system_type)
    history.append({"role": "user", "content": prompt_text})
    dot_step = 0
    history.append({"role": "assistant", "content": get_dots(dot_step)})
    yield history, current_svg, True, _run_button_update(True)
    for _ in range(2):
        time.sleep(0.3)
        dot_step += 1
        history[-1] = {"role": "assistant", "content": get_dots(dot_step)}
        yield history, current_svg, True, _run_button_update(True)
    try:
        text = _response_text(model.run_inference(prompt_text, system_instruction, []))
        answer = "No response received." if text is None else text
    except Exception as e:
        answer = f"Error: {str(e)[:100]}"
    history[-1] = {"role": "assistant", "content": answer}
    yield history, current_svg, False, _run_button_update(False)


def run_chat(prompt_text, e1, e2, e3, p1, p2, p3, history, task_name, temp_min, temp_max, ternary_temp, svg_state, is_running):
    """Main chat function. Returns (history, svg_path, is_running, btn_update).

    A generator that yields those 4-tuples as the conversation progresses.
    "Phase diagram prediction" with at least one element runs a grid batch
    and plots the result; any other task sends *prompt_text* once.
    The example chat (or None) is replaced by a fresh history. p1-p3 and
    is_running are accepted for the UI wiring but not used here.
    """
    active_elements = [x for x in (e1, e2, e3) if x and len(str(x).strip()) > 0]
    count = len(active_elements)
    # Clear example chat on first real input.
    if history is None or history == EXAMPLE_CHAT:
        history = []
    else:
        history = list(history)

    if task_name == "Phase diagram prediction" and count >= 1:
        yield from _run_diagram_batch(active_elements, history, svg_state, temp_min, temp_max, ternary_temp)
    else:
        yield from _run_standard_chat(prompt_text, count, history, svg_state)
| # --- UI Helpers --- | |
def add_next_phase(count):
    """Reveal one more phase input slot, never exceeding MAX_PHASES.

    Returns a list: the new slot count, then one visibility update per slot
    (the first `new count` slots are visible).
    """
    shown = min(count + 1, MAX_PHASES)
    return [shown, *(gr.update(visible=index < shown) for index in range(MAX_PHASES))]
def update_perc_visibility(e2_text, e3_text):
    """Update percentage field visibility and labels based on elements.

    Returns updates for (perc1, perc2, perc3). perc1 always stays editable.
    - Three elements: perc2 is editable and perc3 is shown as auto-filled.
    - Two elements: perc2 is auto-filled and perc3 is hidden.
    - One element: only perc1 is shown.
    """
    has_e2 = e2_text and len(str(e2_text).strip()) > 0
    has_e3 = e3_text and len(str(e3_text).strip()) > 0
    perc1_update = gr.update(interactive=True)

    if not (has_e2 or has_e3):
        # Single element: only perc1 visible.
        return perc1_update, gr.update(visible=False), gr.update(visible=False)

    if not has_e3:
        # Binary: perc2 is computed automatically.
        auto_label = f"{str(e2_text).upper()} % (AUTO)"
        return (perc1_update,
                gr.update(visible=True, interactive=False, label=auto_label),
                gr.update(visible=False))

    # Ternary: perc2 editable, perc3 computed automatically.
    perc2_label = f"{str(e2_text).upper()} %" if has_e2 else "ELEMENT 2 %"
    perc3_label = f"{str(e3_text).upper()} % (AUTO)"
    return (perc1_update,
            gr.update(visible=True, interactive=True, label=perc2_label),
            gr.update(visible=True, interactive=False, label=perc3_label))
def auto_calculate_balance(p1, p2, e2_text, e3_text):
    """Fill in the last percentage so the composition sums to 100.

    Returns updates for (perc2, perc3). Missing values count as 0.
    - Three elements: perc3 becomes max(0, 100 - p1 - p2).
    - Two elements: perc2 becomes max(0, 100 - p1).
    - Otherwise nothing changes.
    """
    first = p1 or 0
    second = p2 or 0
    if e3_text and len(str(e3_text).strip()) > 0:
        return gr.update(), gr.update(value=max(0, 100 - first - second))
    if e2_text and len(str(e2_text).strip()) > 0:
        return gr.update(value=max(0, 100 - first)), gr.update()
    return gr.update(), gr.update()
def update_label(element_name):
    """Label the first percentage input after the chosen element (or a generic label)."""
    if not element_name or len(str(element_name).strip()) == 0:
        return gr.update(label="ELEMENT 1 %")
    return gr.update(label=f"{str(element_name).upper()} %")
def update_elem2_enabled(e1):
    """Enable elem2 when elem1 has a value, disable and clear when elem1 is empty."""
    if not e1 or len(str(e1).strip()) == 0:
        return gr.update(interactive=False, value=None)
    return gr.update(interactive=True)
def update_elem3_enabled(e2):
    """Enable elem3 when elem2 has a value, disable and clear when elem2 is empty."""
    if not e2 or len(str(e2).strip()) == 0:
        return gr.update(interactive=False, value=None)
    return gr.update(interactive=True)
def update_diag_temp_visibility(e1, e2, e3):
    """Toggle between the two diagram-temperature controls based on element count.

    Returns (first update, second update). With exactly three elements the
    first control is hidden and the second shown; otherwise the reverse.
    Presumably these are the binary min/max range and the single ternary
    temperature; confirm against the UI wiring.
    """
    element_count = sum(1 for x in (e1, e2, e3) if x and len(str(x).strip()) > 0)
    is_ternary = element_count == 3
    return gr.update(visible=not is_ternary), gr.update(visible=is_ternary)
def activate_step2(e1, e2, e3):
    """Activate STEP 02 when at least one element is entered."""
    any_element = any(x and len(str(x).strip()) > 0 for x in (e1, e2, e3))
    frame_state = "step-active" if any_element else "step-inactive"
    return gr.update(elem_classes=["step-frame", frame_state])
def check_can_generate(temp, perc1, temp_min, temp_max, ternary_temp, task_name, e1, e2, e3, *phases):
    """Decide whether the generate control should be enabled.

    At least one element is always required. Per task:
    - "Phase name prediction": temp and perc1;
    - "Experimental condition prediction": at least one non-blank phase;
    - "Phase diagram prediction": ternary_temp with three elements,
      otherwise both temp_min and temp_max.
    Unknown tasks are never ready.

    Returns (interactivity update, step-frame class update).
    """
    element_count = sum(1 for x in (e1, e2, e3) if x and len(str(x).strip()) > 0)
    if element_count == 0:
        return (gr.update(interactive=False),
                gr.update(elem_classes=["step-frame", "step-inactive"]))

    if task_name == "Phase name prediction":
        ready = temp is not None and perc1 is not None
    elif task_name == "Experimental condition prediction":
        ready = any(p and str(p).strip() for p in phases)
    elif task_name == "Phase diagram prediction":
        if element_count == 3:
            ready = ternary_temp is not None
        else:
            ready = temp_min is not None and temp_max is not None
    else:
        ready = False

    frame_state = "step-active" if ready else "step-inactive"
    return (gr.update(interactive=ready),
            gr.update(elem_classes=["step-frame", frame_state]))
def select_task_phase():
    """Switch the UI to "Phase name prediction".

    Returns a 9-tuple: the task name, three visibility updates (only the
    first shown), three variant updates (only the first primary) and two
    no-op updates.
    """
    visibility = (gr.update(visible=True), gr.update(visible=False), gr.update(visible=False))
    variants = (gr.update(variant="primary"), gr.update(variant="secondary"), gr.update(variant="secondary"))
    return ("Phase name prediction", *visibility, *variants, gr.update(), gr.update())
def select_task_cond():
    """Switch the UI to "Experimental condition prediction".

    Returns a 9-tuple: the task name, three visibility updates (only the
    second shown), three variant updates (only the second primary) and two
    no-op updates.
    """
    visibility = (gr.update(visible=False), gr.update(visible=True), gr.update(visible=False))
    variants = (gr.update(variant="secondary"), gr.update(variant="primary"), gr.update(variant="secondary"))
    return ("Experimental condition prediction", *visibility, *variants, gr.update(), gr.update())
def select_task_diag(e1, e2, e3):
    """Switch the UI to "Phase diagram prediction".

    Returns a 9-tuple:
    - the task name;
    - three visibility updates (only the third shown);
    - three variant updates (only the third primary);
    - the two temperature-control visibilities (the second is shown only
      with exactly three elements, the first otherwise).
    """
    element_count = len([x for x in (e1, e2, e3) if x and len(str(x).strip()) > 0])
    is_ternary = element_count == 3
    visibility = (gr.update(visible=False), gr.update(visible=False), gr.update(visible=True))
    variants = (gr.update(variant="secondary"), gr.update(variant="secondary"), gr.update(variant="primary"))
    temperature_controls = (gr.update(visible=not is_ternary), gr.update(visible=is_ternary))
    return ("Phase diagram prediction", *visibility, *variants, *temperature_controls)
| # ===================================================================== | |
| # MCP API FUNCTIONS β clean endpoints for MCP tool exposure | |
| # ===================================================================== | |
def predict_phases_api(element1: str, element2: str = "", element3: str = "",
                       percent1: float = 100, percent2: float = 0, percent3: float = 0,
                       temperature_k: float = 800) -> str:
    """Predict which phases form for an alloy at a given composition and temperature.
    Provide 1-3 elements with their atomic percentages (must sum to 100) and a
    temperature in Kelvin. Returns the predicted equilibrium phase name(s).
    To build a phase diagram, call this tool many times (e.g. 50+ calls) across a
    grid of compositions and temperatures, then pass all results to the
    plot_phase_diagram tool to render the diagram.
    For a binary system (2 elements), vary element2 percent from 0 to 100 in steps
    of ~10-20, and temperature from min to max in steps of ~50K.
    For a ternary system (3 elements), fix temperature and vary compositions across
    the triangle in steps of ~20%.
    Returns a message starting with "Error:" if element1 is empty, a percentage
    is negative or not a number, or the percentages do not sum to 100 (within 0.5).
    Example: element1="Gold", percent1=50, element2="Silver", percent2=50, temperature_k=800
    """
    element1 = str(element1 or "").strip()
    if not element1:
        return "Error: element1 is required."
    elements = [(element1, percent1)]
    if element2 and str(element2).strip():
        elements.append((str(element2).strip(), percent2))
    if element3 and str(element3).strip():
        elements.append((str(element3).strip(), percent3))
    # Reject non-numeric or negative shares before summing; previously these
    # crashed (TypeError) or produced impossible compositions.
    if any(not isinstance(p, (int, float)) or p < 0 for _, p in elements):
        return "Error: Percentages must be non-negative numbers."
    total = sum(p for _, p in elements)
    if abs(total - 100) > 0.5:
        return f"Error: Percentages sum to {total}, expected 100."
    mix_parts = [f"{el} ({p}%)" for el, p in elements]
    question = f"What phases form when {' + '.join(mix_parts)} are mixed at {int(temperature_k)} K?"
    sys_type = "ternary" if len(elements) == 3 else "binary" if len(elements) == 2 else "elemental"
    return _run_single_inference(question, sys_type)
def predict_conditions_api(element1: str, element2: str = "", element3: str = "",
                           target_phases: str = "") -> str:
    """Predict experimental conditions for given elements to form target phases.
    Provide 1-3 elements and one or more target phase names (comma- or +-separated).
    Returns predicted conditions (composition, temperature, etc.).
    Returns a message starting with "Error:" if element1 is empty or no
    non-blank phase name is given.
    Example: element1="Gold", element2="Silver", target_phases="FCC_A1"
    """
    element1 = str(element1 or "").strip()
    if not element1:
        return "Error: element1 is required."
    elems = [element1] + [str(e).strip() for e in (element2, element3) if e and str(e).strip()]
    # Split first, then check: input made only of separators (e.g. " , + ")
    # used to pass the emptiness check and produce the question "... form ?".
    phases_clean = [p.strip() for p in re.split(r'[,+]', str(target_phases or "")) if p.strip()]
    if not phases_clean:
        return "Error: No target phases provided."
    elements_str = " + ".join(elems)
    phases_str = " + ".join(phases_clean)
    question = f"Under what condition do {elements_str} form {phases_str}?"
    sys_type = "ternary" if len(elems) == 3 else "binary" if len(elems) == 2 else "elemental"
    return _run_single_inference(question, sys_type)
def _plot_points_to_entries(points, elements):
    """Convert raw JSON data points into diagram entries, skipping unusable points."""
    entries = []
    for pt in points:
        if not isinstance(pt, dict):
            continue
        phase_str = pt.get("phases", "")
        if not phase_str or not str(phase_str).strip():
            continue
        # A missing composition means all-zero percentages (as before); a
        # composition that is not an object cannot be read, so skip the point.
        comp = pt.get("composition") or {}
        if not isinstance(comp, dict):
            continue
        try:
            temperature = float(pt.get("temperature_k", 800))
            # Build percentages dict — ensure all elements are present.
            percentages = {el: float(comp.get(el, 0)) for el in elements}
        except (TypeError, ValueError):
            continue
        # json.loads accepts NaN/Infinity, which would break the plot axes.
        if not np.isfinite([temperature, *percentages.values()]).all():
            continue
        # Parse phases from the string.
        phases = extract_phases(str(phase_str))
        if not phases:
            phases = [normalize_phase_name(str(phase_str))]
        phases = [normalize_phase_name(p) for p in phases]
        entries.append({
            'elements': [el for el in elements if percentages.get(el, 0) > 0] or [elements[0]],
            'percentages': percentages,
            'temperature': temperature,
            'phases': phases,
        })
    return entries


def plot_phase_diagram_api(data_points_json: str) -> str:
    """Generate a phase diagram image from pre-collected data points.

    This tool renders a phase diagram from data you have already collected by
    calling predict_phases multiple times. It does NOT run any model inference
    itself — it only plots.
    WORKFLOW to build a phase diagram:
    1. Call predict_phases ~50+ times across a grid of compositions/temperatures.
    2. Collect all results into the JSON format below.
    3. Call this tool ONCE with that JSON to render the diagram.
    Input: a JSON string with this structure:
    {
      "elements": ["Gold", "Silver"],
      "temperature_k": 800,
      "points": [
        {
          "composition": {"Gold": 80, "Silver": 20},
          "temperature_k": 800,
          "phases": "FCC_A1"
        },
        {
          "composition": {"Gold": 50, "Silver": 50},
          "temperature_k": 1000,
          "phases": "FCC_A1 + LIQUID"
        }
      ]
    }
    - "elements": list of 2 or 3 distinct element names (must match keys in
      composition dicts)
    - "points": list of data points, each with composition (dict of element->percent),
      temperature_k (defaults to 800 if omitted), and phases (string, use " + "
      to separate multi-phase regions). Points without phases or with
      non-numeric values are skipped.
    - "temperature_k" at top level is optional and is used only as the title of
      ternary diagrams (defaults to the first valid point's temperature).

    Args:
        data_points_json: The JSON document described above, as a string.

    Returns:
        A JSON string. On success it contains "status", "system", "type",
        "num_points", "png_path" and "svg_path"; otherwise a single "error" key.
    """
    try:
        data = json.loads(data_points_json)
    except (json.JSONDecodeError, TypeError) as e:
        return json.dumps({"error": f"Invalid JSON: {e}"})
    if not isinstance(data, dict):
        return json.dumps({"error": "Top-level JSON value must be an object."})

    elements = data.get("elements", [])
    points = data.get("points", [])
    # A bare string such as "AuAg" must not be treated as a list of names.
    if not isinstance(elements, list) or len(elements) < 2:
        return json.dumps({"error": "Provide at least 2 element names in 'elements' list."})
    if len(elements) > 3:
        return json.dumps({"error": "At most 3 element names are supported in 'elements' list."})
    if not all(isinstance(el, str) and el.strip() for el in elements):
        return json.dumps({"error": "Element names in 'elements' list must be non-empty strings."})
    if len(set(elements)) != len(elements):
        return json.dumps({"error": "Element names in 'elements' list must be distinct."})
    if not points or not isinstance(points, list):
        return json.dumps({"error": "No data points provided in 'points' list."})

    entries = _plot_points_to_entries(points, elements)
    if not entries:
        return json.dumps({"error": "No valid data points after parsing."})

    is_ternary = len(elements) == 3
    elements_str = "-".join(sorted(elements))
    # Element names come from the caller: keep only filename-safe characters so
    # a name such as "../x" cannot write outside tmp/.
    file_stem = re.sub(r"[^\w\-]+", "_", elements_str)
    os.makedirs("tmp", exist_ok=True)
    try:
        if is_ternary:
            try:
                temp_label = int(float(data.get("temperature_k", entries[0]['temperature'])))
            except (TypeError, ValueError, OverflowError):
                return json.dumps({"error": "Top-level 'temperature_k' must be a finite number."})
            diagram_base = f"tmp/{file_stem}_{temp_label}K"
            png_path, svg_path = create_ternary_phase_diagram(entries, diagram_base, temp_label)
        else:
            temps = [e['temperature'] for e in entries]
            t_min, t_max = int(min(temps)), int(max(temps))
            diagram_base = f"tmp/{file_stem}_{t_min}-{t_max}K"
            png_path, svg_path = create_binary_phase_diagram(entries, diagram_base)
    except Exception as e:  # tool boundary: report plotting failures to the caller as JSON
        return json.dumps({"error": f"Diagram generation failed: {e}"})

    if png_path and os.path.exists(png_path):
        result = {
            "status": "success",
            "system": elements_str,
            "type": "ternary" if is_ternary else "binary",
            "num_points": len(entries),
            "png_path": png_path,
            "svg_path": svg_path,
        }
        return json.dumps(result)
    return json.dumps({"error": "Diagram generation failed."})
# --- UI CONSTRUCTION ---
# Layout: a title area, then a main row with the step 1/2 input controls on the
# left and the step 3 chat on the right, followed by event wiring and a hidden
# row of components that back the MCP tool endpoints.
with gr.Blocks() as demo:
    # State. svg_state holds the latest SVG path for the download button
    # (see update_svg_download); is_running is threaded through run_chat,
    # presumably as an in-flight guard — confirm in run_chat.
    svg_state = gr.State(None)
    is_running = gr.State(False)

    with gr.Column(elem_classes=["title-area"]):
        gr.Markdown("# aLLoyM")
        gr.Markdown("Interactive Alloy Design & Phase Diagram Assistant", elem_classes=["subtitle"])
        gr.Markdown(
            "Ask questions to aLLoyM, our Mistral-based model fine-tuned on Computational Phase Diagram Database (CPDDB) "
            "and assessments based on CALPHAD (CALculation of PHAse Diagrams). "
            "Check out our [**paper**](https://www.nature.com/articles/s41524-026-01966-6) for more information. "
            "You can also download the [**weights**](https://huggingface.co/Playingyoyo/aLLoyM) and fine-tune, "
            "or download the [**data**](https://huggingface.co/datasets/Playingyoyo/aLLoyM-dataset). "
            "Like all AI models, aLLoyM may hallucinate outside its training scope—but this also means you can explore alloy compositions that have never been experimentally tested.",
            elem_classes=["description"]
        )

    with gr.Row(elem_classes=["main-row"]):
        # Left column: element selection (step 1) and task details (step 2).
        with gr.Column(scale=4):
            with gr.Column(elem_classes=["step-frame", "step-active"]) as step1_frame:
                gr.Markdown("STEP 01 // CHOOSE ELEMENTS", elem_classes=["step-label"])
                gr.Markdown("*Type, or choose from below for more reliable answers*", elem_classes=["step-hint"])
                with gr.Column():
                    elem1 = gr.Dropdown(label="ELEMENT 1", choices=ELEMENT_CHOICES, value=None, allow_custom_value=True)
                    # Elements 2 and 3 start disabled and are enabled one at a time
                    # by update_elem2_enabled / update_elem3_enabled (wired below).
                    elem2 = gr.Dropdown(label="ELEMENT 2 (optional)", choices=ELEMENT_CHOICES, value=None, allow_custom_value=True, interactive=False)
                    elem3 = gr.Dropdown(label="ELEMENT 3 (optional)", choices=ELEMENT_CHOICES, value=None, allow_custom_value=True, interactive=False)

            with gr.Column(elem_classes=["step-frame", "step-inactive"]) as step2_frame:
                gr.Markdown("STEP 02 // CHOOSE TASK AND INSERT DETAILS", elem_classes=["step-label"])
                current_task = gr.State("Phase name prediction")
                with gr.Column():
                    btn_phase = gr.Button("PHASE PREDICTION", elem_classes=["task-card"], variant="primary")
                    btn_cond = gr.Button("EXP. CONDITIONS", elem_classes=["task-card"], variant="secondary")
                    btn_diag = gr.Button("PHASE DIAGRAM", elem_classes=["task-card"], variant="secondary")

                # Task "phase prediction": temperature plus composition. perc2/perc3
                # are read-only and filled by auto_calculate_balance.
                with gr.Group(visible=True) as group_phase:
                    with gr.Column():
                        temp_input = gr.Number(label="TEMPERATURE (K)", value=800, step=50)
                        perc1 = gr.Number(label="ELEMENT 1 %", value=100)
                        perc2 = gr.Number(label="ELEMENT 2 % (AUTO)", value=0, interactive=False, visible=False)
                        perc3 = gr.Number(label="ELEMENT 3 % (AUTO)", value=0, interactive=False, visible=False)

                # Task "experimental conditions": MAX_PHASES target-phase dropdowns,
                # of which only the first starts visible; "+" calls add_next_phase.
                phase_inputs = []
                phase_count = gr.State(1)
                with gr.Group(visible=False) as group_cond:
                    gr.Markdown("*Type, or choose from below for more reliable answers*", elem_classes=["step-hint"])
                    for i in range(MAX_PHASES):
                        t = gr.Dropdown(choices=PHASE_CHOICES, value=None, allow_custom_value=True, show_label=False, visible=(i==0), container=False)
                        phase_inputs.append(t)
                    add_phase_btn = gr.Button("+", variant="secondary")

                # Task "phase diagram": a temperature range (binary) or a single
                # temperature (ternary); update_diag_temp_visibility toggles them.
                with gr.Group(visible=False) as group_diag:
                    with gr.Column(visible=True) as binary_temp_group:
                        temp_min = gr.Number(label="MIN TEMPERATURE (K)", value=300, step=50)
                        temp_max = gr.Number(label="MAX TEMPERATURE (K)", value=1000, step=50)
                    with gr.Column(visible=False) as ternary_temp_group:
                        ternary_temp = gr.Number(label="TEMPERATURE (K)", value=800, step=50)

        # Right column: chat transcript, downloads and the send control (step 3).
        with gr.Column(scale=6):
            with gr.Column(elem_classes=["step-frame", "step-inactive"], elem_id="right-panel") as chat_frame:
                gr.Markdown("STEP 03 // CHAT", elem_classes=["step-label"])
                chatbot = gr.Chatbot(label="", show_label=False, height=500, layout="panel", value=EXAMPLE_CHAT)
                with gr.Row():
                    download_svg_btn = gr.DownloadButton("📥 DOWNLOAD SVG", visible=True, scale=1)
                    download_chat_btn = gr.DownloadButton("📥 DOWNLOAD CHAT (JSONL)", visible=True, scale=1)
                with gr.Row(elem_classes=["input-area"]):
                    # The prompt is generated from the controls (generate_prompt_text)
                    # and is not user-editable.
                    prompt_preview = gr.Textbox(
                        label="USER INPUT (AUTO-GENERATED)",
                        value="Select a task...",
                        lines=2,
                        interactive=False,
                        scale=8
                    )
                    # Starts disabled; check_can_generate enables it once inputs are valid.
                    send_btn = gr.Button("▶", scale=1, interactive=False, elem_id="send-btn")

    # --- Events (all with api_name=False to hide from MCP) ---
    for elem in [elem1, elem2, elem3]:
        elem.change(fn=activate_step2, inputs=[elem1, elem2, elem3], outputs=step2_frame, show_progress="hidden", api_name=False)

    # Enable elem2 when elem1 is filled, enable elem3 when elem2 is filled
    elem1.change(fn=update_elem2_enabled, inputs=elem1, outputs=elem2, show_progress="hidden", api_name=False)
    elem2.change(fn=update_elem3_enabled, inputs=elem2, outputs=elem3, show_progress="hidden", api_name=False)
    elem1.change(update_label, elem1, perc1, show_progress="hidden", api_name=False)

    # Update percentage visibility when elem2 or elem3 changes
    elem2.change(fn=update_perc_visibility, inputs=[elem2, elem3], outputs=[perc1, perc2, perc3], api_name=False)
    elem3.change(fn=update_perc_visibility, inputs=[elem2, elem3], outputs=[perc1, perc2, perc3], api_name=False)

    for elem in [elem1, elem2, elem3]:
        elem.change(fn=update_diag_temp_visibility,
                    inputs=[elem1, elem2, elem3],
                    outputs=[binary_temp_group, ternary_temp_group],
                    show_progress="hidden", api_name=False)

    perc1.change(fn=auto_calculate_balance, inputs=[perc1, perc2, elem2, elem3], outputs=[perc2, perc3], api_name=False)
    perc2.change(fn=auto_calculate_balance, inputs=[perc1, perc2, elem2, elem3], outputs=[perc2, perc3], api_name=False)

    # Task buttons switch which detail group is shown and restyle the buttons.
    btn_phase.click(select_task_phase, None, [current_task, group_phase, group_cond, group_diag, btn_phase, btn_cond, btn_diag, binary_temp_group, ternary_temp_group], api_name=False)
    btn_cond.click(select_task_cond, None, [current_task, group_phase, group_cond, group_diag, btn_phase, btn_cond, btn_diag, binary_temp_group, ternary_temp_group], api_name=False)
    btn_diag.click(select_task_diag, [elem1, elem2, elem3], [current_task, group_phase, group_cond, group_diag, btn_phase, btn_cond, btn_diag, binary_temp_group, ternary_temp_group], api_name=False)
    add_phase_btn.click(fn=add_next_phase, inputs=phase_count, outputs=[phase_count] + phase_inputs, api_name=False)

    # Any input change regenerates the prompt preview and re-validates the send button.
    all_prompt_inputs = [current_task, elem1, elem2, elem3, perc1, perc2, perc3, temp_input, temp_min, temp_max, ternary_temp] + phase_inputs
    check_inputs = [temp_input, perc1, temp_min, temp_max, ternary_temp, current_task, elem1, elem2, elem3] + phase_inputs
    for inp in all_prompt_inputs:
        inp.change(fn=generate_prompt_text, inputs=all_prompt_inputs, outputs=prompt_preview, api_name=False)
        inp.change(fn=check_can_generate,
                   inputs=check_inputs,
                   outputs=[send_btn, chat_frame],
                   show_progress="hidden", api_name=False)

    # Main send button (UI only, not exposed as MCP)
    send_btn.click(
        fn=run_chat,
        inputs=[prompt_preview, elem1, elem2, elem3, perc1, perc2, perc3, chatbot, current_task, temp_min, temp_max, ternary_temp, svg_state, is_running],
        outputs=[chatbot, svg_state, is_running, send_btn],
        show_progress="hidden",
        api_name=False
    )

    # Update download button when SVG is available
    def update_svg_download(svg_path):
        """Point the SVG download button at svg_path if that file exists, else clear it."""
        if svg_path and os.path.exists(svg_path):
            return gr.DownloadButton(value=svg_path, visible=True)
        return gr.DownloadButton(value=None, visible=True)
    svg_state.change(fn=update_svg_download, inputs=svg_state, outputs=download_svg_btn, api_name=False)

    # Update chat download button when chat changes
    def update_chat_download(history):
        """Export the chat history and point the chat download button at the resulting file."""
        path = export_chat_history(history)
        if path and os.path.exists(path):
            return gr.DownloadButton(value=path, visible=True)
        return gr.DownloadButton(value=None, visible=True)
    chatbot.change(fn=update_chat_download, inputs=chatbot, outputs=download_chat_btn, api_name=False)

    # =================================================================
    # MCP API ENDPOINTS — these are the only tools exposed via MCP
    # =================================================================
    # Hidden components for API inputs/outputs
    with gr.Row(visible=False):
        api_out = gr.Textbox()
        # --- Task 1: Phase Prediction ---
        api_e1 = gr.Textbox()
        api_e2 = gr.Textbox()
        api_e3 = gr.Textbox()
        api_p1 = gr.Number()
        api_p2 = gr.Number()
        api_p3 = gr.Number()
        api_temp = gr.Number()
        api_btn_phase = gr.Button()
        # --- Task 2: Condition Prediction ---
        api_cond_e1 = gr.Textbox()
        api_cond_e2 = gr.Textbox()
        api_cond_e3 = gr.Textbox()
        api_cond_phases = gr.Textbox()
        api_btn_cond = gr.Button()
        # --- Task 3: Plot Phase Diagram (from pre-collected data) ---
        api_plot_data = gr.Textbox()
        api_btn_plot = gr.Button()
        api_out2 = gr.Textbox()
        api_out3 = gr.Textbox()

    # The api_name values below become the MCP tool names.
    api_btn_phase.click(
        fn=predict_phases_api,
        inputs=[api_e1, api_e2, api_e3, api_p1, api_p2, api_p3, api_temp],
        outputs=api_out,
        api_name="predict_phases"
    )
    api_btn_cond.click(
        fn=predict_conditions_api,
        inputs=[api_cond_e1, api_cond_e2, api_cond_e3, api_cond_phases],
        outputs=api_out2,
        api_name="predict_conditions"
    )
    api_btn_plot.click(
        fn=plot_phase_diagram_api,
        inputs=[api_plot_data],
        outputs=api_out3,
        api_name="plot_phase_diagram"
    )
if __name__ == "__main__":
    # Resolve the look-and-feel first, then start the app. mcp_server=True also
    # serves the endpoints registered with an api_name as MCP tools.
    app_theme = theme.get_theme()
    app_css = css.get_css()
    demo.launch(theme=app_theme, css=app_css, mcp_server=True)