| from fastapi import FastAPI, HTTPException |
| from pydantic import BaseModel |
| from typing import List, Optional, Dict, Any |
| import uvicorn |
| import os |
| import sys |
| import torch |
| import json |
| import logging |
| import networkx as nx |
| from networkx.readwrite import json_graph |
| import numpy as np |
|
|
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class NumpyEncoder(json.JSONEncoder):
    """JSON encoder that serializes NumPy scalars and arrays.

    Conversions:
    - ``np.bool_`` -> ``bool``
    - ``np.integer`` -> ``int``
    - ``np.floating`` -> ``float``
    - ``np.ndarray`` -> nested lists

    Non-finite floats (NaN, +/-Inf) are replaced with ``0.0``, whether they are
    NumPy scalars or elements of a floating-point array, because standard JSON
    cannot represent them. A warning is logged once per process the first time
    this happens.

    NOTE: ``default()`` is only consulted for objects ``json`` cannot already
    encode. ``np.float64`` subclasses Python ``float``, and plain Python floats
    are encoded directly, so non-finite values of those types are NOT
    sanitized here and still come out as ``NaN``/``Infinity``.
    """

    # Process-wide flag: the non-finite warning is emitted only once.
    _nan_warning_logged = False

    @staticmethod
    def _warn_non_finite(value: float) -> None:
        """Log the non-finite-value warning the first time it is needed."""
        if not NumpyEncoder._nan_warning_logged:
            logger.warning(f"NumpyEncoder: Converting non-finite value ({value}) to 0.0. "
                           "This may indicate numerical instability in LRP computation.")
            NumpyEncoder._nan_warning_logged = True

    def default(self, obj):
        if isinstance(obj, np.bool_):
            return bool(obj)
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            f = float(obj)
            if not np.isfinite(f):
                self._warn_non_finite(f)
                return 0.0
            return f
        if isinstance(obj, np.ndarray):
            if np.issubdtype(obj.dtype, np.floating):
                finite = np.isfinite(obj)
                if not finite.all():
                    # Apply the same sanitization as for scalars, so arrays never
                    # emit NaN/Infinity tokens (which are invalid JSON).
                    self._warn_non_finite(float(obj[~finite][0]))
                    obj = np.where(finite, obj, 0.0)
            return obj.tolist()
        return super().default(obj)
|
|
| |
# Absolute directory containing this file. It is put first on sys.path so that
# the local `backend` package imported below resolves regardless of the
# current working directory.
PROJECT_ROOT = os.path.abspath(os.path.dirname(__file__))
sys.path.insert(0, PROJECT_ROOT)
|
|
| from backend.models import ModelManager |
| from backend.core import AttributionEngine |
| from backend.circuit import CircuitAnalyzer |
| from backend.error_token_location import ErrorTokenLocator |
| from fastapi.middleware.cors import CORSMiddleware |
| from fastapi.staticfiles import StaticFiles |
| from fastapi.responses import RedirectResponse, StreamingResponse |
| from huggingface_hub import list_models, list_repo_refs |
|
|
app = FastAPI(title="NeuralPostmortem - Evaluation Backend (Attribution Comparison & Perturbation)")


# CORS: accept requests from any origin, with any method and header.
# NOTE(review): combining allow_origins=["*"] with allow_credentials=True makes
# Starlette reflect the caller's Origin on credentialed requests, so any site
# can make credentialed calls. Acceptable for a local research tool; restrict
# the origins before exposing this server publicly (verify against the
# installed Starlette version).
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
| |
# Serve the bundled static frontend under /ui, but only when a `frontend`
# directory exists next to this file; otherwise the API runs headless.
frontend_path = os.path.join(PROJECT_ROOT, 'frontend')
if os.path.exists(frontend_path):
    app.mount("/ui", StaticFiles(directory=frontend_path), name="ui")
|
|
@app.get("/")
async def read_root():
    """Send visitors of the bare root URL to the frontend's entry page."""
    landing_page = "/ui/index.html"
    return RedirectResponse(url=landing_page)
|
|
| |
# Process-wide state shared by all endpoints (single model per server process).
# The currently loaded model/tokenizer live in model_manager.
model_manager = ModelManager()
# AttributionEngine bound to model_manager; created by /api/load_model and
# replaced by /api/enable_lrp (and by the LRP auto-enable paths). None until a
# model is loaded.
attribution_engine = None
# ErrorTokenLocator built from the model/tokenizer in /api/load_model; None
# until then.
error_token_locator = None


# Memoized connection matrices for /api/compute_circuit, keyed by the string
# from get_config_hash(backprop_config, layers). /api/compute_logits clears it
# because the matrices depend on the current forward pass.
CACHED_CONNECTION_DATA = {
    "config_hash": None,
    "data": None
}
|
|
def get_config_hash(bp_config, layers):
    """Build a deterministic cache key for a (backprop config, layer list) pair.

    The key is canonical JSON: dict keys are sorted and the layer list is
    sorted, so equal configs produce equal keys regardless of ordering.

    Args:
        bp_config: JSON-serializable backprop configuration (a dict).
        layers: iterable of layer indices.

    Returns:
        The canonical JSON string, or None when the inputs cannot be sorted or
        serialized (callers must then treat the result as "no cache key").
    """
    try:
        return json.dumps({
            "bp": bp_config,
            "layers": sorted(layers)
        }, sort_keys=True)
    except (TypeError, ValueError):
        # TypeError: unsortable/None layers or non-JSON-serializable values.
        # ValueError: circular references inside bp_config.
        return None
|
|
def unescape_string(text: str) -> str:
    """
    Safely unescape string with escape sequences like \\n, \\t, \\r, etc.

    Characters that are not part of an escape sequence (including non-ASCII
    text such as "é" or "中") are preserved unchanged.

    Args:
        text: Input text that may contain escape sequences

    Returns:
        Text with escape sequences converted to actual characters. Falsy input
        ("" or None) is returned unchanged. If the text contains a malformed
        escape (e.g. a trailing backslash or a truncated \\x escape), only the
        common escapes \\n \\t \\r \\" \\' and \\\\ are decoded.
    """
    if not text:
        return text

    try:
        # latin-1 + backslashreplace turns every non-Latin-1 character into a
        # \uXXXX escape, and Latin-1 characters become single bytes that
        # unicode_escape decodes back to themselves. The round trip therefore
        # only interprets backslash escapes. Encoding to UTF-8 instead would
        # turn non-ASCII text into mojibake (e.g. "é" -> "Ã©").
        return text.encode("latin-1", "backslashreplace").decode("unicode_escape")
    except UnicodeDecodeError as e:
        import re

        logging.getLogger(__name__).warning(f"unicode_escape failed, using manual replacement: {e}")
        simple_escapes = {"n": "\n", "t": "\t", "r": "\r", '"': '"', "'": "'", "\\": "\\"}
        # Single left-to-right pass, so an escaped backslash followed by "n"
        # stays a backslash plus "n" instead of becoming a newline.
        return re.sub(r"""\\(["'\\ntr])""", lambda m: simple_escapes[m.group(1)], text)
|
|
|
|
| |
class LoadModelRequest(BaseModel):
    """Body of POST /api/load_model; fields are forwarded to ModelManager.load_model."""
    # Hugging Face model id (presumably a local path also works - confirm in ModelManager).
    model_path: str = "Qwen/Qwen3-0.6B"
    # Passed positionally as ModelManager.load_model's quantization flag.
    quantization_4bit: bool = False
    # dtype name forwarded as-is; accepted values are defined by ModelManager - TODO confirm.
    dtype: str = "float16"
    # Optional revision, presumably a branch/tag as listed by /api/list_model_revisions.
    revision: Optional[str] = None
| |
|
|
class ComputeLogitsRequest(BaseModel):
    """Body of POST /api/compute_logits; fields are forwarded to AttributionEngine.compute_logits."""
    prompt: str
    # Presumably controls whether a BOS token is prepended - semantics live in AttributionEngine.
    is_append_bos: bool = True
    # Size of the top-k token list returned by the engine (presumably).
    topk: int = 10
    # Extra tokens to report besides the top-k, given by id or by string (presumably).
    extra_token_ids: Optional[List[int]] = None
    extra_token_strs: Optional[List[str]] = None
    # Presumably captures intermediate activations during the forward pass - confirm.
    capture_mid: bool = False
|
|
class BackpropConfig(BaseModel):
    """Backward-pass configuration; endpoints pass it on as a plain dict via .dict().

    Field semantics are defined by AttributionEngine / CircuitAnalyzer; the
    notes below are inferred from names and should be confirmed there.
    """
    # How the backward-pass target is chosen (e.g. the max logit).
    mode: str = "max_logit"
    # Strategy name used by the attribution backend.
    strategy: Optional[str] = "by_topk_avg"
    # Optional reference token id (presumably for contrastive modes).
    ref_token_id: Optional[int] = None
    # Rank of the contrast token (presumably used when ref_token_id is unset).
    contrast_rank: Optional[int] = 2
    k: Optional[int] = 10
    node_threshold: Optional[float] = 0.01
    # Overwritten by the request's target_token_id in the input-attribution endpoints.
    target_token_id: Optional[int] = None
|
|
class ComputeCircuitRequest(BaseModel):
    """Body of POST /api/compute_circuit."""

    # Backward-pass settings; together with `layers` this forms the cache key
    # for the (expensive) connection matrices.
    backprop_config: BackpropConfig

    # Layer indices to include; sorted before use, so order does not matter.
    layers: List[int]

    # Pruning parameters, forwarded to CircuitAnalyzer.build_graph_from_matrices
    # (edge_threshold is passed as edge_rel_threshold). Changing only these does
    # not invalidate the cached matrices.
    pruning_mode: str = "by_per_layer_cum_mass_percentile"
    top_p: float = 0.9
    edge_threshold: float = 0.01
|
|
|
|
class ComputeInputAttributionRequest(BaseModel):
    """Body of the /api/compute_input_attribution[_gradient] endpoints."""
    # Copied into backprop_config["target_token_id"] before attribution runs.
    target_token_id: int
    # NOTE(review): not read by the endpoints in this file; contrast appears to
    # be configured through backprop_config instead - confirm before relying on it.
    contrast_token_id: Optional[int] = None
    backprop_config: BackpropConfig
|
|
class GenerateRequest(BaseModel):
    """Body of POST /api/generate (greedy decoding)."""
    prompt: str
    # Upper bound on the number of newly generated tokens.
    max_new_tokens: int = 30
    # Optional token id appended after the encoded prompt before generating.
    append_token_id: Optional[int] = None
|
|
class LocateErrorTokenRequest(BaseModel):
    """Body of POST /api/locate_err_token; forwarded to ErrorTokenLocator.locate_error_token."""
    prompt: str
    completion: str
    # An empty string is treated like None by the endpoint.
    ground_truth: Optional[str] = None
    # Validator names (presumably None means "use the locator's defaults").
    validators: Optional[List[str]] = None
    use_llm: bool = True
    # Optional caller-supplied chunking of the completion (semantics in ErrorTokenLocator).
    manual_chunks: Optional[List[str]] = None
|
|
class EnableLRPRequest(BaseModel):
    """Body of POST /api/enable_lrp."""
    # LRP rule name passed to ModelManager.load_model; "Gradient" reloads the
    # model without LRP patches (plain gradient attribution).
    lrp_rule: str = "Attn-LRP"
    # Only echoed back in the response by /api/enable_lrp.
    capture_mid: bool = False
|
|
class ComputePerturbationRequest(BaseModel):
    """Body of POST /api/compute_perturbation."""
    # Presumably one attribution score per input token position - confirm in AttributionEngine.
    attribution_scores: List[float]
    # Numbers of top-attributed tokens to perturb. The list-literal default is
    # copied per instance by pydantic (unlike a plain function default) -
    # confirm for the pinned pydantic version.
    k_values: List[int] = [1, 3, 5, 10]
    target_token_id: int
|
|
class ComputePerturbationManualRequest(BaseModel):
    """Body of POST /api/compute_perturbation_manual."""
    # Token positions whose input embeddings are perturbed (zeroed, per the endpoint docstring).
    perturb_indices: List[int]
    target_token_id: int
|
|
@app.get("/api/list_hf_models")
async def list_hf_models(series: str = "Qwen2"):
    """
    List models from HuggingFace Hub filtered by series/author.

    Known series keys (case-insensitive: qwen2, qwen3, qwen, olmo3, olmo) map to
    a fixed (author, search) query returning up to 50 text-generation models
    sorted by downloads. Any other value is used as a free-text search (up to
    20 results).

    If the Hub query fails, a small hard-coded list is returned for known
    series; otherwise the response is an empty list plus an ``error`` message.
    """
    import asyncio

    key = series.lower()
    # series key -> (author, search term); search=None lists every
    # text-generation model published by that author.
    known_queries = {
        "qwen2": ("Qwen", "Qwen2"),
        "qwen3": ("Qwen", "Qwen3"),
        "olmo3": ("allenai", "Olmo-3"),
        "olmo": ("allenai", "OLMo"),
        "qwen": ("Qwen", None),
    }
    # Offline fallbacks used when the Hub cannot be reached.
    fallback_models = {
        "qwen2": ["Qwen/Qwen2.5-0.5B-Instruct", "Qwen/Qwen2.5-1.5B-Instruct", "Qwen/Qwen2.5-3B-Instruct", "Qwen/Qwen2.5-7B-Instruct", "Qwen/Qwen2-0.5B", "Qwen/Qwen2-1.5B", "Qwen/Qwen2-7B"],
        "qwen3": ["Qwen/Qwen3-0.6B"],
        "qwen": ["Qwen/Qwen2.5-0.5B-Instruct", "Qwen/Qwen3-0.6B"],
        "olmo3": ["allenai/Olmo-3-7B-Think"],
        "olmo": ["allenai/OLMo-7B", "allenai/OLMo-1B-0724", "allenai/Olmo-3-7B-Think"],
    }

    if key in known_queries:
        author, search = known_queries[key]
        limit = 50
    else:
        author, search, limit = None, series, 20

    def fetch_ids():
        # list_models pages through the Hub lazily, so the iteration (i.e. the
        # HTTP traffic) must also happen in the worker thread.
        return [m.id for m in list_models(author=author, search=search, filter="text-generation",
                                          sort="downloads", direction=-1, limit=limit)]

    try:
        # The Hub client is synchronous; run it in the default executor so this
        # async endpoint does not block the event loop (and every other request).
        loop = asyncio.get_running_loop()
        model_ids = await loop.run_in_executor(None, fetch_ids)
    except Exception as e:
        logger.warning(f"Error listing models: {e}")
        if key in fallback_models:
            return {"models": fallback_models[key]}
        return {"models": [], "error": str(e)}
    return {"models": model_ids}
|
|
@app.get("/api/list_model_revisions")
async def list_model_revisions(model_id: str):
    """
    List git branches/refs for a model.

    Returns:
        ``{"branches": [...], "tags": [...]}`` with ref names from the Hub. On
        failure both lists are empty and an ``error`` message is included.
    """
    import asyncio

    def fetch_refs():
        refs = list_repo_refs(model_id)
        return [b.name for b in refs.branches], [t.name for t in refs.tags]

    try:
        # list_repo_refs does blocking HTTP; keep it off the event loop.
        loop = asyncio.get_running_loop()
        branches, tags = await loop.run_in_executor(None, fetch_refs)
    except Exception as e:
        logger.warning(f"Error listing revisions for {model_id}: {e}")
        return {"branches": [], "tags": [], "error": str(e)}
    return {"branches": branches, "tags": tags}
|
|
@app.post("/api/cleanup")
async def cleanup_memory():
    """Release cached attribution state and free unused GPU memory.

    Steps, in order:
    1. Reset the attribution engine if one exists.
    2. Run the garbage collector.
    3. Empty the CUDA caching allocator.

    gc runs before empty_cache so that tensors which only become unreachable
    through collection are actually returned to the driver.

    Returns:
        ``{"status": "success", "message": ...}``.
    """
    import gc

    if attribution_engine:
        attribution_engine.reset()
    gc.collect()
    # Return cached-but-unused blocks to the driver. This runs even if reset()
    # already did so (harmless); it is skipped when CUDA is unavailable.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return {"status": "success", "message": "Memory cleanup complete"}
|
|
@app.post("/api/generate")
async def generate_continuation(request: GenerateRequest):
    """Greedily generate a continuation of ``request.prompt``.

    If ``append_token_id`` is set, that token is appended to the encoded prompt
    before generation. The model is switched to eval mode for the call, and its
    previous training mode is restored afterwards on both success and failure.

    Returns:
        ``{"generated_text": str}``: only the newly generated tokens, decoded
        with special tokens kept.

    Raises:
        HTTPException(400): no model loaded.
        HTTPException(500): encoding or generation failed.
    """
    if not model_manager.model:
        raise HTTPException(status_code=400, detail="Model not loaded")

    tokenizer = model_manager.tokenizer
    model = model_manager.model
    device = model_manager.device

    was_training = model.training
    model.eval()
    try:
        input_ids = tokenizer.encode(request.prompt, return_tensors="pt").to(device)

        if request.append_token_id is not None:
            token_tensor = torch.tensor([[request.append_token_id]], device=device)
            input_ids = torch.cat([input_ids, token_tensor], dim=1)

        with torch.no_grad():
            output_ids = model.generate(
                input_ids,
                max_new_tokens=request.max_new_tokens,
                do_sample=False,
                pad_token_id=tokenizer.eos_token_id
            )

        # Slice off the prompt (plus any appended token) to keep only new tokens.
        new_token_ids = output_ids[0][input_ids.shape[1]:]
        generated_text = tokenizer.decode(new_token_ids, skip_special_tokens=False)
        return {"generated_text": generated_text}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # Restore the caller-visible model state no matter how we exit.
        if was_training:
            model.train()
|
|
@app.post("/api/locate_err_token")
async def locate_error_token_endpoint(request: LocateErrorTokenRequest):
    """
    Locate the first erroneous token in a completion by LLM-validator voting.

    Delegates to ``error_token_locator.locate_error_token``. An empty
    ``ground_truth`` is passed on as ``None``.

    Returns:
        dict with ``status`` ("success"), ``truncated_text``, ``explanation``,
        ``error_token_index`` (-1 when absent) and ``vote_details`` ({} when absent).

    Raises:
        HTTPException(400): no model/locator loaded yet.
        HTTPException(500): the locator reports ``status == "error"`` or raises.
    """
    locator = error_token_locator
    if not locator:
        raise HTTPException(status_code=400, detail="Model not loaded. Please call /api/load_model first.")

    try:
        located = locator.locate_error_token(
            prompt=request.prompt,
            completion=request.completion,
            ground_truth=request.ground_truth or None,
            validators=request.validators,
            use_llm=request.use_llm,
            manual_chunks=request.manual_chunks,
        )

        if located["status"] == "error":
            raise HTTPException(status_code=500, detail=located.get("message", "Unknown error"))

        return dict(
            status="success",
            truncated_text=located["truncated_text"],
            explanation=located["explanation"],
            error_token_index=located.get("error_token_index", -1),
            vote_details=located.get("vote_details", {}),
        )
    except HTTPException:
        # Already carries the intended status code; don't re-wrap it.
        raise
    except Exception as e:
        import traceback
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=str(e))
|
|
@app.post("/api/load_model")
async def load_model(request: LoadModelRequest):
    """Load a model (without LRP patches) and rebuild the per-model helpers.

    Replaces the module-level ``attribution_engine`` and ``error_token_locator``
    with instances bound to the freshly loaded model. It also drops any cached
    circuit matrices, which were computed for the previous model.

    Returns:
        dict with ``status``, ``message``, ``num_layers`` and ``vocab_size``.
        ``num_layers`` comes from the model config's ``num_hidden_layers`` and
        falls back to 28; ``vocab_size`` is ``len(tokenizer)``.

    Raises:
        HTTPException(500): loading or helper construction failed.
    """
    global attribution_engine
    global error_token_locator
    try:
        model_name = model_manager.load_model(
            request.model_path,
            request.quantization_4bit,
            dtype=request.dtype,
            revision=request.revision,
            lrp_rule=None
        )
        attribution_engine = AttributionEngine(model_manager)
        error_token_locator = ErrorTokenLocator(model_manager.model, model_manager.tokenizer)
        # Cached connection matrices belong to the previous model; never reuse them.
        CACHED_CONNECTION_DATA["config_hash"] = None
        CACHED_CONNECTION_DATA["data"] = None

        # Models without a config, or whose config lacks num_hidden_layers (or
        # sets it to None), keep the historical default of 28.
        config = getattr(model_manager.model, "config", None)
        n_layers = getattr(config, "num_hidden_layers", None) or 28

        vocab_size = len(model_manager.tokenizer)

        return {
            "status": "success",
            "message": f"Model {model_name} loaded successfully",
            "num_layers": n_layers,
            "vocab_size": vocab_size
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
@app.post("/api/enable_lrp")
async def enable_lrp(request: EnableLRPRequest):
    """
    Enable LRP functionality on the loaded model.
    This should be called before computing attribution or circuits.
    For "Gradient" rule, the model is loaded WITHOUT LRP patches (vanilla gradient).

    The current model is reloaded through ``ModelManager.load_model`` with its
    current path, quantization, dtype and revision plus the requested rule.
    The attribution engine and the error-token locator are then rebuilt on the
    reloaded model, so the locator no longer holds on to the previous model
    object, and the circuit cache is cleared. ``capture_mid`` is only echoed
    back in the response.

    Raises:
        HTTPException(400): no model loaded.
        HTTPException(500): reload or helper construction failed.
    """
    global attribution_engine
    global error_token_locator
    if not model_manager.model:
        raise HTTPException(status_code=400, detail="Model not loaded. Please call /api/load_model first.")

    try:
        # "Gradient" means: no LRP patches at all.
        lrp_rule_for_model = None if request.lrp_rule == "Gradient" else request.lrp_rule

        model_manager.load_model(
            model_path=model_manager.current_model_path,
            quantization_4bit=model_manager.current_quantization,
            dtype=model_manager.current_dtype,
            revision=model_manager.current_revision,
            lrp_rule=lrp_rule_for_model
        )

        attribution_engine = AttributionEngine(model_manager)
        # Mirror /api/load_model so the locator follows the reloaded model.
        error_token_locator = ErrorTokenLocator(model_manager.model, model_manager.tokenizer)
        # Cached matrices were computed under the previous rule/model.
        CACHED_CONNECTION_DATA["config_hash"] = None
        CACHED_CONNECTION_DATA["data"] = None

        return {
            "status": "success",
            "message": f"Attribution method ({request.lrp_rule}) enabled successfully",
            "lrp_rule": request.lrp_rule,
            "capture_mid": request.capture_mid
        }
    except Exception as e:
        import traceback
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
@app.post("/api/compute_logits")
async def compute_logits(request: ComputeLogitsRequest):
    """Run a forward pass on the prompt and return the engine's top-k data.

    Any cached circuit matrices are discarded first, since they belong to the
    previous forward pass.

    Returns:
        ``{"data": <top-k data>, "tokens": [{"token_str": str, "token_id": int}, ...]}``.
        NOTE(review): ``token_id`` is the token's position in the input (from
        ``enumerate``), not its vocabulary id.

    Raises:
        HTTPException(400): no model loaded.
        HTTPException(500): the forward pass failed.
    """
    engine = attribution_engine
    if not engine:
        raise HTTPException(status_code=400, detail="Model not loaded. Please call /api/load_model first.")

    try:
        CACHED_CONNECTION_DATA.update(config_hash=None, data=None)

        topk_data, _unused, input_tokens = engine.compute_logits(
            prompt=request.prompt,
            is_append_bos=request.is_append_bos,
            topk=request.topk,
            extra_token_ids=request.extra_token_ids,
            extra_token_strs=request.extra_token_strs,
            capture_mid=request.capture_mid
        )

        tokens_payload = [
            {"token_str": token_str, "token_id": position}
            for position, token_str in enumerate(input_tokens)
        ]
        return {"data": topk_data, "tokens": tokens_payload}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
|
@app.post("/api/compute_input_attribution")
async def compute_input_attribution_endpoint(request: ComputeInputAttributionRequest):
    """Compute per-input-token relevance with the active LRP rule.

    If no LRP rule is active, the model is first reloaded with 'Attn-LRP' and
    the attribution engine and error locator are rebuilt on it. The request's
    ``target_token_id`` overrides the one in ``backprop_config``.

    Returns:
        ``{"relevance": ...}`` as produced by
        ``AttributionEngine.compute_input_attribution``.

    Raises:
        HTTPException(400): no model loaded, LRP auto-enable failed, or no
            forward pass is available (run /api/compute_logits first).
        HTTPException(500): the attribution computation failed.
    """
    global attribution_engine
    global error_token_locator
    if not attribution_engine:
        raise HTTPException(status_code=400, detail="Model not loaded.")

    if not model_manager.current_lrp_rule:
        logger.info("LRP not enabled yet - auto-enabling with default rule 'Attn-LRP'...")
        try:
            model_name = model_manager.load_model(
                model_path=model_manager.current_model_path,
                quantization_4bit=model_manager.current_quantization,
                dtype=model_manager.current_dtype,
                revision=model_manager.current_revision,
                lrp_rule="Attn-LRP"
            )
            attribution_engine = AttributionEngine(model_manager)
            # Rebind the locator too (as /api/load_model does) so it does not
            # keep the previously loaded model object alive.
            error_token_locator = ErrorTokenLocator(model_manager.model, model_manager.tokenizer)
            logger.info(f"Auto-enabled LRP with Attn-LRP rule on {model_name}")
        except Exception as e:
            logger.error(f"Failed to auto-enable LRP: {e}")
            raise HTTPException(
                status_code=400,
                detail="LRP not enabled and auto-enable failed. Please call /api/enable_lrp before computing attribution."
            ) from e

    # Checked outside the try below so this 400 is not re-wrapped as a 500.
    if attribution_engine.outputs is None:
        raise HTTPException(status_code=400, detail="No forward pass found. Run compute_logits first.")

    try:
        bp_config = request.backprop_config.dict()
        bp_config["target_token_id"] = request.target_token_id

        relevance = attribution_engine.compute_input_attribution(bp_config)
        return {"relevance": relevance}
    except HTTPException:
        raise
    except Exception as e:
        import traceback
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
@app.post("/api/compute_input_attribution_gradient")
async def compute_input_attribution_gradient_endpoint(request: ComputeInputAttributionRequest):
    """
    Compute input attribution using vanilla gradient method (Input * Gradient).
    Does NOT require LRP to be enabled - uses standard PyTorch autograd.

    The request's ``target_token_id`` overrides the one in ``backprop_config``.

    Raises:
        HTTPException(400): no model loaded, or no forward pass available
            (run /api/compute_logits first).
        HTTPException(500): the attribution computation failed.
    """
    if not attribution_engine:
        raise HTTPException(status_code=400, detail="Model not loaded.")
    # Checked outside the try below so this 400 is not re-wrapped as a 500.
    if attribution_engine.outputs is None:
        raise HTTPException(status_code=400, detail="No forward pass found. Run compute_logits first.")

    try:
        bp_config = request.backprop_config.dict()
        bp_config["target_token_id"] = request.target_token_id

        relevance = attribution_engine.compute_input_attribution_gradient(bp_config)
        return {"relevance": relevance}
    except HTTPException:
        raise
    except Exception as e:
        import traceback
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
@app.post("/api/compute_perturbation")
async def compute_perturbation_endpoint(request: ComputePerturbationRequest):
    """
    Evaluate attribution quality by perturbing top-attributed tokens.
    Zero out top-k most attributed tokens and check if the error is fixed.

    Requires a prior forward pass, i.e. the engine must have ``input_ids`` and
    ``input_embeddings``.

    Returns:
        ``{"results": ...}`` from ``AttributionEngine.compute_perturbation_eval``.

    Raises:
        HTTPException(400): no model loaded or no forward pass yet.
        HTTPException(500): the evaluation failed.
    """
    if not attribution_engine:
        raise HTTPException(status_code=400, detail="Model not loaded.")
    # Checked outside the try below so this 400 is not re-wrapped as a 500.
    if attribution_engine.input_ids is None or attribution_engine.input_embeddings is None:
        raise HTTPException(status_code=400, detail="No forward pass found. Run compute_logits first.")

    try:
        results = attribution_engine.compute_perturbation_eval(
            attribution_scores=request.attribution_scores,
            k_values=request.k_values,
            target_token_id=request.target_token_id
        )
        return {"results": results}
    except HTTPException:
        raise
    except Exception as e:
        import traceback
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
@app.post("/api/compute_perturbation_manual")
async def compute_perturbation_manual_endpoint(request: ComputePerturbationManualRequest):
    """
    Evaluate attribution by perturbing manually selected token positions.
    Zero out the specified token embeddings and check if the error is fixed.

    Requires a prior forward pass, i.e. the engine must have ``input_ids`` and
    ``input_embeddings``.

    Returns:
        ``{"result": ...}`` from ``AttributionEngine.compute_perturbation_manual``.

    Raises:
        HTTPException(400): no model loaded or no forward pass yet.
        HTTPException(500): the evaluation failed.
    """
    if not attribution_engine:
        raise HTTPException(status_code=400, detail="Model not loaded.")
    # Checked outside the try below so this 400 is not re-wrapped as a 500.
    if attribution_engine.input_ids is None or attribution_engine.input_embeddings is None:
        raise HTTPException(status_code=400, detail="No forward pass found. Run compute_logits first.")

    try:
        result = attribution_engine.compute_perturbation_manual(
            perturb_indices=request.perturb_indices,
            target_token_id=request.target_token_id
        )
        return {"result": result}
    except HTTPException:
        raise
    except Exception as e:
        import traceback
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
@app.post("/api/compute_circuit")
async def compute_circuit(request: ComputeCircuitRequest):
    """Compute a pruned attribution circuit and stream progress as NDJSON.

    If no LRP rule is active, the model is first reloaded with 'Attn-LRP'.
    Connection matrices are reused from ``CACHED_CONNECTION_DATA`` when the
    backprop config and layer set match the cached key. Pruning and graph
    construction always rerun.

    Stream format, one JSON object per line:
    - ``progress`` records (``msg``, ``percent``);
    - one ``graph_data`` record (node-link graph plus ``pruning_details``);
    - a final ``complete`` record.
    If anything fails after streaming has started, an ``error`` record is
    emitted instead.

    Raises:
        HTTPException(400): no model loaded, LRP auto-enable failed, or no
            forward pass is available (run /api/compute_logits first).
    """
    global attribution_engine
    global error_token_locator
    if not attribution_engine:
        raise HTTPException(status_code=400, detail="Model not loaded.")

    if not model_manager.current_lrp_rule:
        logger.info("LRP not enabled yet - auto-enabling with default rule 'Attn-LRP' for circuit analysis...")
        try:
            model_name = model_manager.load_model(
                model_path=model_manager.current_model_path,
                quantization_4bit=model_manager.current_quantization,
                dtype=model_manager.current_dtype,
                revision=model_manager.current_revision,
                lrp_rule="Attn-LRP"
            )
            attribution_engine = AttributionEngine(model_manager)
            # Rebind the locator too (as /api/load_model does) so it does not
            # keep the previously loaded model object alive.
            error_token_locator = ErrorTokenLocator(model_manager.model, model_manager.tokenizer)
            logger.info(f"Auto-enabled LRP with Attn-LRP rule on {model_name}")
        except Exception as e:
            logger.error(f"Failed to auto-enable LRP: {e}")
            raise HTTPException(
                status_code=400,
                detail="LRP not enabled and auto-enable failed. Please call /api/enable_lrp before computing circuits."
            ) from e

    if attribution_engine.outputs is None:
        raise HTTPException(status_code=400, detail="No forward pass found. Run compute_logits first.")

    def progress(msg, percent):
        """Serialize one NDJSON progress record."""
        return json.dumps({"type": "progress", "msg": msg, "percent": percent}) + "\n"

    async def generate_response():
        try:
            yield progress("Initiating Backward Pass...", 0)

            analyzer = CircuitAnalyzer(attribution_engine)
            bp_config = request.backprop_config.dict()

            current_hash = get_config_hash(bp_config, request.layers)
            # A None hash means no cache key could be built. Treat it as a miss,
            # never as a match for some earlier entry stored under None.
            cache_hit = (
                current_hash is not None
                and CACHED_CONNECTION_DATA["config_hash"] == current_hash
                and CACHED_CONNECTION_DATA["data"] is not None
            )

            if cache_hit:
                yield progress("Using Cached Matrices (Fast)...", 50)
                connection_data = CACHED_CONNECTION_DATA["data"]
            else:
                yield progress("Computing Circuit (This may take a moment)...", 20)
                connection_data = analyzer.compute_connection_matrices(bp_config, sorted(request.layers))
                if current_hash is not None:
                    CACHED_CONNECTION_DATA["config_hash"] = current_hash
                    CACHED_CONNECTION_DATA["data"] = connection_data

            yield progress("Pruning & Building Graph...", 80)

            G, pruning_details = analyzer.build_graph_from_matrices(
                connection_data,
                edge_rel_threshold=request.edge_threshold,
                pruning_mode=request.pruning_mode,
                top_p=request.top_p
            )

            yield progress("Graph Constructed. Serializing...", 90)

            graph_data = nx.node_link_data(G)

            # NumpyEncoder handles NumPy scalars/arrays left in node/edge attributes.
            yield json.dumps({
                "type": "graph_data",
                "graph": graph_data,
                "pruning_details": pruning_details
            }, cls=NumpyEncoder) + "\n"

            yield progress("Complete!", 100)
            yield json.dumps({"type": "complete"}) + "\n"
        except Exception as e:
            # Headers are already sent once streaming starts, so report errors in-band.
            import traceback
            traceback.print_exc()
            yield json.dumps({"type": "error", "msg": str(e)}) + "\n"

    return StreamingResponse(generate_response(), media_type="application/x-ndjson")
|
|
| |
# Root of the trace datasets: each immediate subdirectory is one dataset, and
# its *.json files are the traces served by the endpoints below.
DATA_DIR = os.path.join(PROJECT_ROOT, "data")
|
|
@app.get("/api/datasets")
async def get_datasets():
    """Return the sorted names of the immediate subdirectories of DATA_DIR.

    A missing DATA_DIR yields an empty list rather than an error.
    """
    if not os.path.isdir(DATA_DIR):
        return {"datasets": []}
    entries = sorted(os.listdir(DATA_DIR))
    return {"datasets": [entry for entry in entries if os.path.isdir(os.path.join(DATA_DIR, entry))]}
|
|
@app.get("/api/traces/{dataset}")
async def get_traces(dataset: str):
    """List trace files (JSON) in a dataset directory.

    Args:
        dataset: name of a subdirectory directly under DATA_DIR.

    Returns:
        ``{"traces": [...]}``: sorted file names ending in ``.json``.

    Raises:
        HTTPException(404): ``dataset`` is not a directory directly inside
            DATA_DIR. Names that would escape DATA_DIR (``".."``, absolute
            paths) are rejected the same way.
    """
    data_root = os.path.normpath(DATA_DIR)
    dataset_dir = os.path.normpath(os.path.join(data_root, dataset))
    # normpath collapses "."/".." segments (and join discards data_root for an
    # absolute name). Requiring the parent to be DATA_DIR therefore blocks path
    # traversal while still allowing symlinked dataset directories.
    if os.path.dirname(dataset_dir) != data_root or not os.path.isdir(dataset_dir):
        raise HTTPException(status_code=404, detail=f"Dataset '{dataset}' not found")

    traces = sorted(name for name in os.listdir(dataset_dir) if name.endswith(".json"))
    return {"traces": traces}
|
|
@app.get("/api/trace_details/{dataset}/{trace_file}")
async def get_trace_details(dataset: str, trace_file: str):
    """Load and return trace file details.

    Reads ``DATA_DIR/<dataset>/<trace_file>`` as JSON and reshapes it for the
    frontend:
    - model settings come from the trace's ``metadata`` object;
    - the prompt/completion fields are copied through;
    - every top-level key ``topk_token_explore_<suffix>`` is collected into
      ``other_candidates``, keyed by ``<suffix>`` (omitted when there are none).

    Raises:
        HTTPException(404): the file does not exist directly inside
            ``DATA_DIR/<dataset>``, including path-traversal attempts.
        HTTPException(500): the file cannot be read or parsed as JSON, or its
            top-level value is not a JSON object.
    """
    data_root = os.path.normpath(DATA_DIR)
    dataset_dir = os.path.normpath(os.path.join(data_root, dataset))
    file_path = os.path.normpath(os.path.join(dataset_dir, trace_file))
    # Each path component must stay exactly one level deep; ".." or absolute
    # segments would otherwise let callers read files outside DATA_DIR.
    inside_data_dir = (
        os.path.dirname(dataset_dir) == data_root
        and os.path.dirname(file_path) == dataset_dir
    )
    if not inside_data_dir or not os.path.isfile(file_path):
        raise HTTPException(status_code=404, detail=f"Trace file '{trace_file}' not found in dataset '{dataset}'")

    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            trace_data = json.load(f)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to parse trace file: {str(e)}") from e
    if not isinstance(trace_data, dict):
        raise HTTPException(status_code=500, detail="Failed to parse trace file: top-level JSON value must be an object")

    metadata = trace_data.get("metadata")
    if not isinstance(metadata, dict):
        # Missing, null or malformed metadata falls back to the defaults below.
        metadata = {}
    result = {
        "model_path": metadata.get("model", ""),
        # Traces may store dtypes as "torch.float16"; the loader expects "float16".
        "dtype": str(metadata.get("dtype", "float16")).replace("torch.", ""),
        "quantization": False,
        "prompt": trace_data.get("prompt", ""),
        "raw_prompt": trace_data.get("prompt", ""),
        "completion": trace_data.get("completion", ""),
        "ground_truth": trace_data.get("ground_truth", ""),
        "eval_result": trace_data.get("eval_result", None),
        "topk_token_explore": trace_data.get("topk_token_explore", []),
    }

    prefix = "topk_token_explore_"
    # Strip only the leading prefix (str.replace would also remove later occurrences).
    other_candidates = {
        key[len(prefix):]: value
        for key, value in trace_data.items()
        if key.startswith(prefix)
    }
    if other_candidates:
        result["other_candidates"] = other_candidates

    return result
|
|
if __name__ == "__main__":
    # The PORT environment variable (e.g. set by a hosting platform) overrides
    # the default port; the server listens on all interfaces.
    _port_setting = os.environ.get("PORT", 7860)
    uvicorn.run(app, port=int(_port_setting), host="0.0.0.0")
|
|