Rafs-an09002's picture
Update app.py
82212d4 verified
import os
import chess
import numpy as np
import onnxruntime as ort
from fastapi import FastAPI, Request
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from huggingface_hub import hf_hub_download
from pydantic import BaseModel
app = FastAPI()
class ChessRequest(BaseModel):
fen: str
app.mount("/static", StaticFiles(directory="static"), name="static")
templates = Jinja2Templates(directory="templates")
MODEL_REPO = "GambitFlow/Synapse-Edge"
MODEL_FILENAME = "v1/synapse_edge_v1.onnx"
try:
print("📥 Downloading flagship model...")
model_path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILENAME)
session = ort.InferenceSession(model_path)
print("✅ Synapse-Edge v1 Loaded.")
except Exception as e:
print(f"❌ Model Load Error: {e}")
# [CRITICAL FIX]: ট্রেনিং কোডের সাথে ১০০% ম্যাচ করা টেনসর লজিক
def get_tensor(fen):
# ট্রেনিং কোড অনুযায়ী ১১৯ চ্যানেল, কিন্তু ডেটা শুধু ১২টিতে
tensor = np.zeros((1, 119, 8, 8), dtype=np.float32)
position = fen.split(' ')[0]
# পিস ম্যাপ (ট্রেনিং কোড অনুযায়ী)
piece_to_channel = {'P':0, 'N':1, 'B':2, 'R':3, 'Q':4, 'K':5, 'p':6, 'n':7, 'b':8, 'r':9, 'q':10, 'k':11}
rank = 0
file = 0
for char in position:
if char == '/':
rank += 1
file = 0
elif char.isdigit():
file += int(char)
elif char in piece_to_channel:
if rank < 8 and file < 8:
tensor[0, piece_to_channel[char], rank, file] = 1.0
file += 1
# বাকি ১০৭টি চ্যানেল সব জিরো থাকবে (ট্রেনিংয়ের সময় যেমন ছিল)
return tensor
def predict(fen):
tensor = get_tensor(fen)
return session.run(None, {"input": tensor})
@app.post("/get_move")
async def get_move(req: ChessRequest):
try:
board = chess.Board(req.fen)
if board.is_game_over():
return JSONResponse({"error": "Game over"}, status_code=400)
# ১. বর্তমান পজিশনে ইনফারেন্স
policy, value, tactical, phase = predict(req.fen)
legal_moves = list(board.legal_moves)
move_candidates = []
# ২. ভ্যালু হেড দিয়ে চেক করা (মডেলের নিজের নলেজ অনুযায়ী)
for move in legal_moves:
board.push(move)
# কালোর চাল হলে সাদার পয়েন্ট নেগেটিভ হওয়া মানে কালো জিতছে
_, next_v, _, _ = predict(board.fen())
board.pop()
v_score = float(next_v[0][0])
# সাদার চাল হলে বেশি স্কোর ভালো, কালোর চাল হলে কম স্কোর ভালো
actual_score = v_score if board.turn == chess.WHITE else -v_score
# ৩. সিম্পল পলিসি প্রায়োরিটি (ইন্ডেক্সিং মিসম্যাচ এড়াতে আমরা ভ্যালুকে গুরুত্ব দিব বেশি)
move_candidates.append((move, actual_score))
# সবচেয়ে ভালো ভ্যালুর চালটি বাছাই
move_candidates.sort(key=lambda x: x[1], reverse=True)
best_move = move_candidates[0][0]
# ৪. যদি সেরা চালেও কোনো ব্লান্ডার হওয়ার ভয় থাকে, তবে সেকেন্ড বেস্ট ট্রাই করা
# (v1 এর জন্য আমরা ১-লেভেল ভ্যালু সার্চকেই ফাইনাল রাখছি)
return {
"move": best_move.uci(),
"value": float(value[0][0]),
"tactical": float(tactical[0][0]),
"phase": int(np.argmax(phase[0]))
}
except Exception as e:
print(f"🔥 Server Error: {e}")
return JSONResponse({"error": str(e)}, status_code=400)
@app.get("/", response_class=HTMLResponse)
async def read_root(request: Request):
return templates.TemplateResponse("index.html", {"request": request})
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=7860)