# topo/main.py
# Uploaded by: sae8d
# Commit message: "Upload 16 files"
# Commit: 55e2289 (verified)
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.middleware.cors import CORSMiddleware
import os
import shutil
import tempfile
import xml.etree.ElementTree as ET
from Decipher.pt_crypto import decrypt_pkt
from Decipher.extract import extract_structured_data
from Decipher.generator import generate_network
from Decipher.models import NetworkTopology
from pydantic import Field, BaseModel
class PromptRequest(BaseModel):
    # Request body for POST /generate-network: a single free-text prompt
    # describing the network that should be generated. The field is required
    # (default=...), and the example is surfaced in the generated JSON schema.
    prompt: str = Field(
        default=...,
        description="The description of the network to generate",
        json_schema_extra={"example": "A network with 1 router and 2 PCs"},
    )
app = FastAPI(title="TopoScan API", description="API for decrypting and extracting data from Cisco Packet Tracer files")

# Enable CORS.
# Allowed origins come from the comma-separated CORS_ALLOW_ORIGINS environment
# variable (e.g. "https://a.example,https://b.example"). When it is unset or
# blank, every origin is allowed ("*"), which matches the previous behaviour.
# NOTE(review): the Starlette CORS docs discourage combining wildcard origins
# with allow_credentials=True; set CORS_ALLOW_ORIGINS explicitly in production.
_cors_origins = [
    origin.strip()
    for origin in os.environ.get("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
@app.get("/")
async def root():
    """Health-check endpoint: always answers with a fixed status message."""
    payload = {"message": "TopoScan API is running"}
    return payload
def _extract_from_xml_bytes(xml_data):
    """Run ``extract_structured_data`` on *xml_data* via a temporary .xml file.

    ``extract_structured_data`` is called with a filesystem path, so the
    decrypted XML is written to disk first. The temporary file is removed
    afterwards even if writing or extraction raises.

    Returns:
        The value returned by ``extract_structured_data``. ``None`` signals
        that extraction failed; the caller turns that into an HTTP 500.
    """
    # NamedTemporaryFile opens in binary mode by default, so this assumes
    # decrypt_pkt returns bytes -- TODO confirm against Decipher.pt_crypto.
    temp_xml = tempfile.NamedTemporaryFile(delete=False, suffix=".xml")
    # Capture the path before writing so a failed write is still cleaned up.
    temp_xml_path = temp_xml.name
    try:
        # Closing the handle flushes the data before the extractor reopens
        # the file by path.
        with temp_xml:
            temp_xml.write(xml_data)
        return extract_structured_data(temp_xml_path)
    finally:
        if os.path.exists(temp_xml_path):
            os.remove(temp_xml_path)


@app.post("/analyze-pkt")
async def analyze_pkt(file: UploadFile = File(...)):
    """Decrypt an uploaded Packet Tracer file and return its extracted data.

    Args:
        file: The uploaded ``.pkt`` or ``.pka`` file. The extension is
            checked case-insensitively.

    Returns:
        Whatever ``extract_structured_data`` produced for the decrypted XML.

    Raises:
        HTTPException: 400 if the filename is missing or lacks a supported
            extension; 500 with detail "Decryption failed: ..." if decryption
            fails, "Failed to extract structured data from decrypted XML" if
            extraction yields ``None``, or "An error occurred: ..." for any
            other unexpected error.
    """
    # UploadFile.filename can be None; also accept upper-case extensions.
    filename = (file.filename or "").lower()
    if not filename.endswith((".pkt", ".pka")):
        raise HTTPException(status_code=400, detail="Only .pkt or .pka files are supported")

    # NOTE(review): decrypt_pkt and extract_structured_data are called
    # synchronously inside this async handler, so they run on the event-loop
    # thread; if they are slow, consider fastapi.concurrency.run_in_threadpool.
    try:
        content = await file.read()

        try:
            xml_data = decrypt_pkt(content)
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Decryption failed: {str(e)}") from e

        structured_data = _extract_from_xml_bytes(xml_data)
        if structured_data is None:
            raise HTTPException(
                status_code=500,
                detail="Failed to extract structured data from decrypted XML",
            )
        return structured_data
    except HTTPException:
        # Already carries the intended status and detail; letting the generic
        # handler below catch it would re-wrap it as
        # "An error occurred: 500: Decryption failed: ...".
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}") from e
@app.post("/generate-network", response_model=NetworkTopology)
async def api_generate_network(request: PromptRequest):
    """Generate a network topology from a natural-language prompt.

    The prompt text is handed to ``generate_network``, and its result is
    returned through the ``NetworkTopology`` response model. Any exception
    raised during generation is reported as HTTP 500 with the detail
    ``"Generation failed: <error>"``.
    """
    try:
        generated = generate_network(request.prompt)
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"Generation failed: {str(exc)}")
    return generated
if __name__ == "__main__":
    import uvicorn

    # Hosts such as Hugging Face Spaces supply the listening port through the
    # PORT environment variable; fall back to 8000 for local runs.
    listen_port = int(os.getenv("PORT", "8000"))
    uvicorn.run(app, port=listen_port, host="0.0.0.0")