model_inference / app.py
MossaicMan's picture
Update app.py
4529c09 verified
#!/usr/bin/env python3
"""
NeuralMesh Backend - CID Metadata + Hugging Face Execution
Sepolia + Pinata + Hugging Face
"""
import asyncio
import logging
import os
import time
from concurrent.futures import ThreadPoolExecutor
from contextlib import asynccontextmanager
from typing import Optional

import httpx
from dotenv import load_dotenv
from eth_account import Account
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from sentence_transformers import SentenceTransformer
from web3 import Web3
# ==========================================
# LOAD ENV
# ==========================================
# Populate os.environ from a local .env file (if any) before CONFIG reads it.
load_dotenv()
CONFIG = {
    # Deployment settings: each can be overridden through an environment
    # variable of the same name; the defaults are the original hard-coded
    # Sepolia deployment values.
    "RPC_URL": os.getenv("RPC_URL", "https://rpc.sepolia.org"),
    "CONTRACT_ADDRESS": os.getenv(
        "CONTRACT_ADDRESS", "0xc76Bf13d48C61A68865aa16D91D2ECf86e7Fc773"
    ),
    # Secret: read from the environment only, validated below.
    "PRIVATE_KEY": os.getenv("ORACLE_PRIVATE_KEY"),
    "CHAIN_ID": 11155111,  # Sepolia
    "PINATA_GATEWAY": "https://gateway.pinata.cloud/ipfs/",
    "HF_TOKEN": os.getenv("HF_TOKEN"),
    "HOST": "0.0.0.0",
    "PORT": int(os.getenv("PORT", "8000")),
    # Comma-separated list, e.g. "http://localhost:5173,https://app.example".
    "CORS_ORIGINS": [
        origin.strip()
        for origin in os.getenv("CORS_ORIGINS", "http://localhost:5173").split(",")
        if origin.strip()
    ],
}
# Fail fast at import time rather than on the first request.
if not CONFIG["PRIVATE_KEY"]:
    raise ValueError("❌ ORACLE_PRIVATE_KEY not set")
# NOTE(review): HF_TOKEN is mandatory here, but in this file it is only
# reported by /health; the SentenceTransformer load does not receive it.
if not CONFIG["HF_TOKEN"]:
    raise ValueError("❌ HF_TOKEN not set")
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s | %(levelname)-8s | %(message)s",
)
logger = logging.getLogger("NeuralMesh")
# ==========================================
# CONTRACT ABI
# ==========================================
def _abi_params(*pairs):
    """Turn ``(name, type)`` pairs into ABI parameter dictionaries."""
    return [{"name": param_name, "type": param_type} for param_name, param_type in pairs]


# Minimal ABI covering only the two read-only (``view``) functions this
# backend calls: ``hasAccess`` and the public ``models`` getter.
CONTRACT_ABI = [
    {
        "inputs": _abi_params(("_modelId", "uint256"), ("_user", "address")),
        "name": "hasAccess",
        "outputs": _abi_params(("", "bool")),
        "stateMutability": "view",
        "type": "function",
    },
    {
        "inputs": _abi_params(("_modelId", "uint256")),
        "name": "models",
        # Order matters: the decoded result is a tuple indexed by position
        # (index 1 holds the CID, index 5 the existence flag).
        "outputs": _abi_params(
            ("model_id", "uint256"),
            ("model_cid", "string"),
            ("model_creator", "address"),
            ("model_price", "uint256"),
            ("model_usageCount", "uint256"),
            ("is_model_exists", "bool"),
        ),
        "stateMutability": "view",
        "type": "function",
    },
]
# ==========================================
# BLOCKCHAIN SERVICE
# ==========================================
class BlockchainService:
    """Read-only access to the NeuralMesh contract.

    Holds a Web3 HTTP connection to ``CONFIG["RPC_URL"]``, a contract handle
    built from ``CONFIG["CONTRACT_ADDRESS"]`` and ``CONTRACT_ABI``, and the
    oracle account derived from ``CONFIG["PRIVATE_KEY"]``.
    """

    # Positions inside the tuple returned by ``models(uint256)``; they follow
    # the output order declared in CONTRACT_ABI.
    _MODEL_CID_INDEX = 1
    _MODEL_EXISTS_INDEX = 5

    def __init__(self):
        """Connect to the RPC endpoint and build the contract handle.

        Raises:
            ConnectionError: if ``is_connected()`` reports the RPC unreachable.
        """
        self.w3 = Web3(Web3.HTTPProvider(CONFIG["RPC_URL"]))
        if not self.w3.is_connected():
            raise ConnectionError("❌ Cannot connect to Sepolia RPC")
        self.contract = self.w3.eth.contract(
            address=Web3.to_checksum_address(CONFIG["CONTRACT_ADDRESS"]),
            abi=CONTRACT_ABI,
        )
        self.account = Account.from_key(CONFIG["PRIVATE_KEY"])
        logger.info("🔗 Connected to Sepolia")
        logger.info("🔐 Oracle: %s", self.account.address)

    def verify_access(self, model_id: int, user_address: str) -> bool:
        """Return True if *user_address* may use *model_id*.

        Both the contract's ``hasAccess`` check and the model's existence flag
        must be true. Any exception raised along the way (for example a
        malformed address or an RPC failure) is logged and treated as
        "no access".
        """
        try:
            user = Web3.to_checksum_address(user_address)
            if not self.contract.functions.hasAccess(model_id, user).call():
                # Access already denied: skip the second RPC round trip.
                return False
            model = self.contract.functions.models(model_id).call()
            return bool(model[self._MODEL_EXISTS_INDEX])
        except Exception as e:
            # Deliberately broad: callers treat every failure as a denial.
            logger.error("Access verification failed: %s", e)
            return False

    def get_model_cid(self, model_id: int) -> Optional[str]:
        """Return the IPFS CID recorded for *model_id*, or None.

        None is returned when the model's existence flag is false, or when
        the contract call raises (the error is logged).
        """
        try:
            model = self.contract.functions.models(model_id).call()
        except Exception as e:
            logger.error("CID fetch failed: %s", e)
            return None
        if not model[self._MODEL_EXISTS_INDEX]:
            return None
        return model[self._MODEL_CID_INDEX]
# ==========================================
# MODEL SERVICE (HUGGING FACE EXECUTION)
# ==========================================
class ModelService:
    """Compute sentence embeddings with one locally loaded SentenceTransformer."""

    DEFAULT_MODEL = "sentence-transformers/all-MiniLM-L6-v2"

    def __init__(self, model_name: str = DEFAULT_MODEL):
        """Load *model_name* once; every request reuses the same instance."""
        self.model_name = model_name
        self.model = SentenceTransformer(model_name)
        # One worker keeps encode() calls serialized, as they effectively were
        # when encode ran directly on the event loop. NOTE(review): kept
        # conservative because the model/tokenizer is presumably not safe to
        # share across threads concurrently.
        self._executor = ThreadPoolExecutor(max_workers=1, thread_name_prefix="embed")

    async def run_inference(self, model_repo: str, text: str):
        """Return the embedding of *text* as a list of floats.

        ``encode`` is synchronous and potentially slow, so it runs on the
        service's worker thread instead of blocking the event loop.

        NOTE(review): *model_repo* does not select a model; the embedding
        always comes from the model loaded in ``__init__``. A mismatch is
        logged so the discrepancy is visible.
        """
        if model_repo != self.model_name:
            logger.warning(
                "Requested model %s but serving %s", model_repo, self.model_name
            )
        loop = asyncio.get_running_loop()
        embedding = await loop.run_in_executor(self._executor, self.model.encode, text)
        return embedding.tolist()
# ==========================================
# METADATA FETCH (IPFS)
# ==========================================
def build_metadata_url(cid: str, gateway: str) -> str:
    """Return the gateway URL for *cid*.

    Accepts a bare CID (optionally followed by a path) or an ``ipfs://`` URI,
    ignores surrounding whitespace, and works whether or not *gateway* ends
    with a slash.

    Raises:
        ValueError: if nothing is left of *cid* after normalisation.
    """
    path = cid.strip()
    if path.startswith("ipfs://"):
        path = path[len("ipfs://"):]
    path = path.lstrip("/")
    if not path:
        raise ValueError("CID must not be empty")
    return f"{gateway.rstrip('/')}/{path}"


async def fetch_metadata(
    cid: str, *, gateway: Optional[str] = None, timeout: float = 30.0
):
    """Download and decode the JSON metadata document stored under *cid*.

    Args:
        cid: IPFS content identifier, optionally ``ipfs://``-prefixed.
        gateway: Base gateway URL; defaults to ``CONFIG["PINATA_GATEWAY"]``.
        timeout: httpx client timeout, in seconds.

    Returns:
        The parsed JSON body.

    Raises:
        ValueError: if *cid* is blank; a body that is not valid JSON also
            surfaces as a ValueError subclass from ``response.json()``.
        httpx.HTTPError: on transport failures or a non-2xx response.
    """
    url = build_metadata_url(cid, gateway or CONFIG["PINATA_GATEWAY"])
    async with httpx.AsyncClient(timeout=timeout) as client:
        response = await client.get(url)
        response.raise_for_status()
        return response.json()
# ==========================================
# FASTAPI
# ==========================================
class InferenceRequest(BaseModel):
    """JSON body accepted by ``POST /embed``."""

    # On-chain model id. The contract parameter is a uint256, so values
    # outside [0, 2**256) are rejected here (422) instead of failing inside
    # the web3 call and coming back as a misleading 403.
    model_id: int = Field(..., ge=0, lt=2**256, description="On-chain model id (uint256)")
    # Wallet address checked against the contract's hasAccess function.
    user_address: str = Field(..., description="Wallet address to check for access")
    # Text to embed.
    text: str = Field(..., description="Text to embed")
# Built at import time, so the embedding model is loaded before serving starts.
model_service = ModelService()
# Assigned by lifespan() during application startup; None before that.
blockchain_service: "BlockchainService | None" = None


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan hook: create the blockchain service on startup."""
    global blockchain_service
    logger.info("🚀 Starting NeuralMesh Backend...")
    service = BlockchainService()
    blockchain_service = service
    yield
# Cross-origin browser access is limited to the origins in CONFIG["CORS_ORIGINS"].
_cors_settings = {
    "allow_origins": CONFIG["CORS_ORIGINS"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}

app = FastAPI(
    title="NeuralMesh",
    version="2.0.0",
    lifespan=lifespan,
)
app.add_middleware(CORSMiddleware, **_cors_settings)
# ==========================================
# ROUTES
# ==========================================
@app.get("/")
def root():
    """Describe the service: name, network, contract and oracle address.

    ``oracle`` is None while the blockchain service has not been created.
    """
    oracle_address = None
    if blockchain_service:
        oracle_address = blockchain_service.account.address
    info = {
        "service": "NeuralMesh",
        "network": "Sepolia",
        "contract": CONFIG["CONTRACT_ADDRESS"],
        "oracle": oracle_address,
    }
    return info
@app.get("/health")
def health():
    """Report service health and blockchain connectivity.

    ``status`` is ``"healthy"`` only when the blockchain service exists and
    its RPC connection check passes; otherwise ``"degraded"`` is returned
    (instead of raising when the service has not been created yet).
    """
    if blockchain_service is None:
        connected = False
        oracle = None
    else:
        connected = blockchain_service.w3.is_connected()
        oracle = blockchain_service.account.address
    return {
        "status": "healthy" if connected else "degraded",
        "blockchain_connected": connected,
        "oracle": oracle,
        "hf_token_configured": bool(CONFIG["HF_TOKEN"]),
    }
@app.post("/embed")
async def embed(req: InferenceRequest):
    """Return an embedding of ``req.text`` for a purchased on-chain model.

    Steps: verify on-chain access, resolve the model's CID, fetch its IPFS
    metadata, then compute the embedding.

    Raises:
        HTTPException(403): the user has no access (or the check failed).
        HTTPException(404): no CID is recorded for the model.
        HTTPException(502): the metadata could not be fetched or decoded.
        HTTPException(500): no model repository could be determined.
    """
    start = time.perf_counter()  # monotonic clock for elapsed-time measurement
    loop = asyncio.get_running_loop()

    # 1️⃣ Verify ownership. The web3 calls are synchronous HTTP requests, so
    # they run in a worker thread instead of blocking the event loop.
    has_access = await loop.run_in_executor(
        None, blockchain_service.verify_access, req.model_id, req.user_address
    )
    if not has_access:
        raise HTTPException(status_code=403, detail="Purchase required")

    # 2️⃣ Get CID from contract
    cid = await loop.run_in_executor(
        None, blockchain_service.get_model_cid, req.model_id
    )
    logger.debug("Model %s CID: %s", req.model_id, cid)
    if not cid:
        raise HTTPException(status_code=404, detail="Model not found")

    # 3️⃣ Fetch metadata from IPFS. Gateway, network and JSON failures are
    # upstream problems, so report them as 502 rather than an unhandled 500.
    try:
        metadata = await fetch_metadata(cid)
    except (httpx.HTTPError, ValueError) as err:
        logger.warning("Metadata fetch failed for CID %s: %s", cid, err)
        raise HTTPException(
            status_code=502, detail="Failed to fetch model metadata"
        ) from err
    logger.debug("Metadata for model %s: %s", req.model_id, metadata)

    # NOTE(review): the repository is pinned instead of being read from the
    # metadata (metadata.get("huggingFaceUrl")); ModelService currently serves
    # a single fixed model regardless of the repo it is given.
    hf_repo = "sentence-transformers/all-MiniLM-L6-v2"
    if not hf_repo:
        raise HTTPException(status_code=500, detail="Invalid metadata")

    # 4️⃣ Compute the embedding
    output = await model_service.run_inference(hf_repo, req.text)

    return {
        "model_id": req.model_id,
        "output": output,
        "processing_time_ms": round((time.perf_counter() - start) * 1000, 2),
    }
# ==========================================
# RUN SERVER
# ==========================================
if __name__ == "__main__":
    # uvicorn is only needed when this file is executed as a script.
    import uvicorn

    bind_host, bind_port = CONFIG["HOST"], CONFIG["PORT"]
    uvicorn.run(app, host=bind_host, port=bind_port)