ai1 / app.py
Kleinpuki2's picture
Update app.py
51bb90e verified
import torch
from model import MiniTransformer, BPETokenizer
from fastapi import FastAPI, Request, Header, HTTPException
from huggingface_hub import hf_hub_download
import uvicorn
import os
app = FastAPI()
REPO_ID = "Kleinpuki2/madgamesai"
FILENAME = "madgames_gpt2_stable.pth"
API_KEY = "MG-ADMIN-1337"
model = None
tokenizer = BPETokenizer()
def load_model():
global model
try:
print(f"Lade Checkpoint von {REPO_ID}...")
path = hf_hub_download(repo_id=REPO_ID, filename=FILENAME)
model = MiniTransformer.load(path, device='cpu')
print("Modell erfolgreich geladen und initialisiert!")
return True
except Exception as e:
print(f"Fehler beim Laden: {e}")
return False
is_loaded = load_model()
@app.get("/")
def root():
return {"status": "online", "loaded": is_loaded, "repo": REPO_ID}
@app.post("/predict")
async def predict(request: Request, x_api_key: str = Header(None)):
if x_api_key != API_KEY:
raise HTTPException(status_code=401)
if not is_loaded or model is None:
return {"response": "Fehler: Modell konnte nicht korrekt initialisiert werden."}
try:
data = await request.json()
prompt = data.get("prompt", "")
if not prompt: return {"response": ""}
print(f"Anfrage empfangen: {prompt[:50]}...")
# WICHTIG: Das exakte Format aus dem Training nachbauen!
formatted_prompt = f"User: {prompt}\nKI: "
tokens = tokenizer.encode(formatted_prompt)
ctx_len = model.ctx_len if hasattr(model, 'ctx_len') else 1024
tokens = tokens[-ctx_len:]
idx = torch.tensor([tokens], dtype=torch.long)
# Perfekte Settings für Code & Chat
out = model.generate(idx, max_new_tokens=250, temperature=0.2, top_k=10, repetition_penalty=1.0)
# Nur den neu generierten Teil extrahieren
generated_tokens = out[0, len(tokens):].tolist()
response = tokenizer.decode(generated_tokens)
# <|endoftext|> entfernen, falls es im Text auftaucht
response = response.replace("<|endoftext|>", "").strip()
# Prüfen, ob die KI anfängt, den nächsten "User:" Text zu generieren
if "User:" in response:
response = response.split("User:")[0]
final_text = response.strip()
if not final_text:
final_text = "Die KI hat noch keine klare Antwort gefunden. Trainiere sie noch ein wenig weiter (Ziel: Loss unter 1.0) oder versuche einen anderen Prompt!"
print(f"Antwort generiert: {final_text[:50]}...")
return {"response": final_text}
except Exception as e:
print(f"Fehler bei Vorhersage: {e}")
return {"response": f"Runtime Fehler: {str(e)}"}
if __name__ == "__main__":
uvicorn.run(app, host="0.0.0.0", port=7860)