#!/usr/bin/env python3
import os, re, json, textwrap, traceback
from decimal import Decimal
from typing import List, Tuple

from flask import Flask, request, render_template
from flask_cors import CORS

from rag_mini import (
    search,
    ensure_ready,
    DEFAULT_TOPK,
    rag_debug_info,  # for /debug/rag
)

# ------------ LLM config ------------
LLM_PROVIDER = (os.getenv("LLM_PROVIDER") or "openai").strip().lower()
LLM_MODEL = (os.getenv("LLM_MODEL") or "gpt-4o-mini").strip()
LLM_API_KEY = os.getenv("OPENAI_API_KEY") or os.getenv("LLM_API_KEY")
OPENAI_BASE_URL = os.getenv("OPENAI_BASE_URL")  # optional (Azure/proxy)

app = Flask(__name__)
app.secret_key = os.getenv("FLASK_SECRET_KEY", "change-me-please")
CORS(app)
| SYSTEM_RULES = """You are MaterialMind, a materials-selection assistant. | |
| Return two things: | |
| 1) JSON with a ranked shortlist: | |
| { | |
| "candidates": [ | |
| { | |
| "name": "string", | |
| "score": 0, // 0..400 (sum of 4 independent 0..100 utilities) | |
| "score_pct": 0, // 0..100 normalized display | |
| "reasons": ["..."], | |
| "tradeoffs": ["..."], | |
| "citations": ["[1]", "[2]"] | |
| } | |
| ] | |
| } | |
| 2) After the JSON, provide 3–6 concise bullets on trade-offs. | |
| Rules: | |
| - Use only the provided context; cite with [1], [2]. No fabrication. | |
| - Utilities per criterion are in [0,1]. Cost utility increases as cost decreases. | |
| - Weights (performance, stability, cost, availability) are independent 0..100 (NOT normalized). | |
| """ | |
| ANSWER_TEMPLATE = """User constraints: | |
| - Environment: {environment} | |
| - Temperature: {temperature} | |
| - Min UTS (MPa): {min_uts} | |
| - Max density (g/cm^3): {max_density} | |
| - Budget: {budget} | |
| - Process: {process} | |
| Independent priorities (0..100 each): | |
| - performance={w_perf}, stability={w_stab}, cost={w_cost}, availability={w_avail} | |
| Question: | |
| For {environment} at {temperature}, shortlist materials that meet UTS ≥ {min_uts} MPa and density ≤ {max_density} g/cm^3. | |
| Consider budget={budget} and process={process}. Rank by performance, stability, cost, and availability. | |
| Context snippets (numbered): | |
| {context} | |
| Citations: | |
| {citations} | |
| Now, first output ONLY the JSON block (no preamble). Then the short narrative. | |
| """ | |

# ---------- helpers ----------
def to_dec(x, default: int) -> Decimal:
    try:
        s = (x or "").strip()
        return Decimal(s if s else str(default))
    except Exception:
        return Decimal(default)
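
# Illustrative behavior: to_dec("12.5", 75) -> Decimal("12.5"); to_dec("", 75) and
# to_dec("abc", 75) both fall back to Decimal(75).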

def format_context(hits: List[Tuple[str, str]]) -> Tuple[str, str]:
    blocks, cites = [], []
    for i, (text, cite) in enumerate(hits, 1):
        snippet = textwrap.shorten((text or "").replace("\n", " "), width=450, placeholder=" …")
        blocks.append(f"[{i}] {snippet}")
        cites.append(f"[{i}] {cite}")
    return "\n".join(blocks), "\n".join(cites)

def extract_json_block(text: str):
    if not text:
        return None
    m = re.search(r"```json\s*(\{.*?\})\s*```", text, flags=re.S | re.I)
    blob = m.group(1) if m else None
    if not blob:
        s = text
        start = s.find("{")
        while start != -1:
            depth = 0
            for j in range(start, len(s)):
                ch = s[j]
                if ch == "{":
                    depth += 1
                elif ch == "}":
                    depth -= 1
                    if depth == 0:
                        blob = s[start:j + 1]
                        break
            if blob:
                break
            start = s.find("{", start + 1)
    if not blob:
        return None
    try:
        return json.loads(blob)
    except Exception:
        return None
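
# Illustrative inputs this handles (both parse to {"candidates": []}):
#   extract_json_block('```json\n{"candidates": []}\n```')    # fenced-block path
#   extract_json_block('text {"candidates": []} more text')   # balanced-brace scan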

# ---------- LLM caller ----------
def call_llm_cloud(system: str, user: str) -> str:
    prov = LLM_PROVIDER
    model = LLM_MODEL
    if prov in ("openai", "oai"):
        from openai import OpenAI
        client = OpenAI(api_key=LLM_API_KEY, base_url=OPENAI_BASE_URL or None)
        r = client.chat.completions.create(
            model=model,
            temperature=0.2,
            max_tokens=1200,
            messages=[{"role": "system", "content": system},
                      {"role": "user", "content": user}],
        )
        return r.choices[0].message.content
    elif prov in ("together", "tg"):
        from together import Together
        client = Together(api_key=LLM_API_KEY)
        r = client.chat.completions.create(
            model=model, temperature=0.2, max_tokens=1200,
            messages=[{"role": "system", "content": system},
                      {"role": "user", "content": user}],
        )
        return r.choices[0].message.content
    else:
        from huggingface_hub import InferenceClient
        hf_token = LLM_API_KEY or os.getenv("HUGGINGFACEHUB_API_TOKEN")
        client = InferenceClient(model=model, token=hf_token)
        try:
            out = client.chat_completion(
                messages=[{"role": "system", "content": system},
                          {"role": "user", "content": user}],
                max_tokens=1200, temperature=0.2,
            )
            return out.choices[0].message["content"]
        except Exception:
            # fall back to plain text generation for models without a chat endpoint
            return client.text_generation(
                prompt=f"{system}\n\n{user}\n", max_new_tokens=1200, temperature=0.2
            )
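
# Quick smoke test (illustrative; assumes this file is saved as app.py and an API
# key for the chosen provider is exported):
#   python -c "from app import call_llm_cloud; print(call_llm_cloud('Be terse.', 'Say ok.'))"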

# ---------- routes ----------
@app.route("/healthz")
def healthz():
    return {
        "ok": True,
        "provider": LLM_PROVIDER,
        "model": LLM_MODEL,
        "has_api_key": bool(LLM_API_KEY),
    }, 200

@app.route("/debug/rag")
def debug_rag():
    return rag_debug_info(), 200

@app.route("/")
def index():
    return render_template("index.html", default_k=DEFAULT_TOPK)

@app.route("/recommend", methods=["POST"])
def recommend():
    try:
        environment = request.form.get("environment", "").strip() or "seawater"
        temperature = request.form.get("temperature", "").strip() or "20–25 °C"
        min_uts = request.form.get("min_uts", "").strip() or "0"
        max_density = request.form.get("max_density", "").strip() or "100"
        budget = request.form.get("budget", "").strip() or "open"
        process = request.form.get("process", "").strip() or "any"
        w_perf = to_dec(request.form.get("w_perf"), 75)
        w_stab = to_dec(request.form.get("w_stab"), 100)
        w_cost = to_dec(request.form.get("w_cost"), 75)
        w_avail = to_dec(request.form.get("w_avail"), 75)
        try:
            k = int(request.form.get("k", DEFAULT_TOPK))
        except (TypeError, ValueError):
            k = DEFAULT_TOPK
| question = (f"For {environment} at {temperature}, shortlist materials that meet " | |
| f"UTS ≥ {min_uts} MPa and density ≤ {max_density} g/cm^3. " | |
| f"Consider budget={budget} and process={process}. " | |
| f"Rank by performance, stability, cost, and availability.") | |
| # RAG search (never crash UI) | |
| try: | |
| hits = search(question, k=k) | |
| rag_error = "" | |
| except Exception as e: | |
| app.logger.exception("RAG search failed") | |
| hits = [] | |
| rag_error = f"RAG error: {type(e).__name__}: {e}" | |
| ctx, cites = format_context(hits) | |
| user_prompt = ANSWER_TEMPLATE.format( | |
| environment=environment, temperature=temperature, | |
| min_uts=min_uts, max_density=max_density, budget=budget, process=process, | |
| w_perf=str(int(w_perf)), w_stab=str(int(w_stab)), | |
| w_cost=str(int(w_cost)), w_avail=str(int(w_avail)), | |
| context=ctx, citations=cites | |
| ) | |
| # LLM call (never crash UI) | |
| raw = "" | |
| try: | |
| raw = call_llm_cloud(SYSTEM_RULES, user_prompt) | |
| except Exception as e: | |
| app.logger.exception("LLM call failed") | |
| raw = f"ERROR calling LLM ({LLM_PROVIDER}:{LLM_MODEL}): {type(e).__name__}: {e}" | |
| parsed = extract_json_block(raw) if raw else None | |
| candidates = (parsed or {}).get("candidates", []) if parsed else [] | |
| if rag_error: | |
| raw = f"{rag_error}\n\n{raw}" | |
| return render_template( | |
| "results.html", | |
| candidates=candidates, | |
| citations=(cites.splitlines() if cites else []), | |
| environment=environment, | |
| temperature=temperature, | |
| raw_output=raw or "", | |
| ) | |
    except Exception as e:
        app.logger.exception("recommend() hard failure")
        tb = traceback.format_exc()
        return render_template(
            "results.html",
            candidates=[],
            citations=[],
            environment="(unknown)",
            temperature="(unknown)",
            raw_output=f"FATAL: {type(e).__name__}: {e}\n\n{tb}",
        ), 200  # deliberate 200 so the results page still renders instead of an error page
| if __name__ == "__main__": | |
| ensure_ready() | |
| port = int(os.getenv("PORT", "7860")) | |
| app.run(host="0.0.0.0", port=port, debug=False) | |
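
# Example request against the /recommend route above (illustrative values; field
# names match what recommend() parses):
#   curl -X POST http://localhost:7860/recommend \
#     -d environment=seawater -d "temperature=20-25 C" -d min_uts=300 \
#     -d max_density=5 -d budget=open -d process=any \
#     -d w_perf=75 -d w_stab=100 -d w_cost=75 -d w_avail=75 -d k=6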