import logging
import os
import shutil  # noqa: F401  NOTE(review): not used in this module
import tempfile
import xml.etree.ElementTree as ET  # noqa: F401  NOTE(review): not used in this module

from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field

from Decipher.extract import extract_structured_data
from Decipher.generator import generate_network
from Decipher.models import NetworkTopology
from Decipher.pt_crypto import decrypt_pkt

logger = logging.getLogger(__name__)

# File extensions accepted by /analyze-pkt (compared case-insensitively).
SUPPORTED_PKT_EXTENSIONS = (".pkt", ".pka")


class PromptRequest(BaseModel):
    """Request body for ``POST /generate-network``."""

    prompt: str = Field(
        ...,
        description="The description of the network to generate",
        json_schema_extra={"example": "A network with 1 router and 2 PCs"},
    )


app = FastAPI(
    title="TopoScan API",
    description="API for decrypting and extracting data from Cisco Packet Tracer files",
)

# Enable CORS
# NOTE(review): a wildcard origin combined with allow_credentials=True is very
# permissive; restrict allow_origins to the known frontend(s) if cookies or
# auth headers are ever relied upon.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.get("/")
async def root():
    """Health check: report that the API is running."""
    return {"message": "TopoScan API is running"}


def _extract_from_xml_bytes(xml_data: bytes):
    """Write decrypted XML to a temp file, extract structured data from it, then delete it.

    ``extract_structured_data`` takes a file path, hence the temporary file.
    The path is recorded before writing so the file is removed even if the
    write itself fails.

    Returns:
        Whatever ``extract_structured_data`` returns (``None`` signals failure).
    """
    temp_xml_path = None
    try:
        with tempfile.NamedTemporaryFile(delete=False, suffix=".xml") as temp_xml:
            temp_xml_path = temp_xml.name
            temp_xml.write(xml_data)
        # The file is closed (and flushed) here, before it is re-opened by path.
        return extract_structured_data(temp_xml_path)
    finally:
        if temp_xml_path is not None and os.path.exists(temp_xml_path):
            os.remove(temp_xml_path)


@app.post("/analyze-pkt")
async def analyze_pkt(file: UploadFile = File(...)):
    """Decrypt an uploaded Packet Tracer file and return its extracted structured data.

    Args:
        file: The uploaded ``.pkt`` / ``.pka`` file (extension matched case-insensitively).

    Returns:
        The structured data produced by ``extract_structured_data``.

    Raises:
        HTTPException: 400 if the filename is missing or has an unsupported
            extension; 500 if decryption fails, extraction yields nothing, or
            any other error occurs.
    """
    # UploadFile.filename may be None, so guard before calling string methods.
    filename = file.filename or ""
    if not filename.lower().endswith(SUPPORTED_PKT_EXTENSIONS):
        raise HTTPException(status_code=400, detail="Only .pkt or .pka files are supported")

    try:
        # Read the uploaded file content
        content = await file.read()

        # Decryption and extraction are synchronous calls; run them in the
        # threadpool so a large upload cannot stall the event loop.
        try:
            xml_data = await run_in_threadpool(decrypt_pkt, content)
        except Exception as e:
            logger.exception("Decryption failed for upload %r", filename)
            raise HTTPException(status_code=500, detail=f"Decryption failed: {str(e)}") from e

        structured_data = await run_in_threadpool(_extract_from_xml_bytes, xml_data)
        if structured_data is None:
            raise HTTPException(
                status_code=500,
                detail="Failed to extract structured data from decrypted XML",
            )
        return structured_data
    except HTTPException:
        # Already a client-facing error with a specific detail; do not re-wrap it.
        raise
    except Exception as e:
        logger.exception("Unexpected error while analyzing upload %r", filename)
        raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}") from e


@app.post("/generate-network", response_model=NetworkTopology)
async def api_generate_network(request: PromptRequest):
    """Generate a network topology from a natural-language prompt.

    Raises:
        HTTPException: 500 carrying the underlying error message if generation fails.
    """
    try:
        # generate_network is called synchronously (its result is returned
        # directly); run it in the threadpool so it cannot block the event loop.
        return await run_in_threadpool(generate_network, request.prompt)
    except Exception as e:
        logger.exception("Network generation failed")
        raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}") from e


if __name__ == "__main__":
    import uvicorn

    # Hugging Face Spaces and other cloud providers often use the PORT environment variable
    port = int(os.environ.get("PORT", 8000))
    uvicorn.run(app, host="0.0.0.0", port=port)