# idk / app.py
# helloperson123's picture
# Update app.py
# bfc4c59 verified
# app.py
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
import uvicorn
import json
import re
# -------------------------------
# SETTINGS
# -------------------------------
# Hugging Face model repo to load (2B-parameter causal LM).
MODEL_NAME = "TheDrummer/Gemmasutra-Mini-2B-v1"
# Prefer GPU when available; fall back to CPU otherwise.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Upper bound on newly generated tokens per request.
MAX_TOKENS = 256
# FIX: make JSON preferred, not fragile
# System prompt prepended to every request: asks the model to answer once,
# preferably as {"response": "..."} JSON (parsed later by the /api/ask handler).
SYSTEM_PROMPT = """You are Acla, an AI assistant created by NC_1320.
You answer the user's question once and stop.
Do not write User:, AI:, or continue a conversation.
Prefer responding in valid JSON exactly like:
{"response":"your answer here"}
If JSON is not possible, respond with plain text only.
"""
# -------------------------------
# LOAD MODEL
# -------------------------------
print(f"Loading {MODEL_NAME} on {DEVICE}...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# fp16 on GPU halves memory; fp32 is the safe default for CPU inference.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
).to(DEVICE)
# FIX: put the model in inference mode so dropout (and any other
# train-only layers) are disabled — this process only ever serves requests.
model.eval()
print("Model loaded!")
# -------------------------------
# CREATE API
# -------------------------------
app = FastAPI()
# NOTE(review): wildcard CORS lets any origin call this API — acceptable
# for a public demo Space, but should be tightened for production use.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
def _extract_reply(text: str) -> str:
    """Parse the model's raw completion into a clean reply string.

    Tries the JSON contract from SYSTEM_PROMPT first ({"response": "..."});
    falls back to plain text truncated at the first fabricated
    conversational turn marker. Returns "" when nothing usable is found.
    """
    # Try JSON first: grab the first brace-delimited span, if any.
    match = re.search(r"\{[\s\S]*?\}", text)
    if match:
        # FIX: narrowed the bare `except Exception` to the errors this
        # parse can actually raise; tolerate non-string "response" values.
        try:
            parsed = json.loads(match.group(0))
            reply = str(parsed.get("response", "")).strip()
            if reply:
                return reply
        except (json.JSONDecodeError, AttributeError, TypeError):
            pass
    # Plain-text fallback: cut at the first sign of an invented turn.
    for stop in ("User:", "AI:", "Assistant:"):
        text = text.split(stop)[0]
    return text.strip()


@app.post("/api/ask")
async def ask_ai(request: Request):
    """Answer a single prompt with the loaded model.

    Expects a JSON body like {"prompt": "..."} and always returns
    {"reply": "..."} with a non-empty reply string.
    """
    # FIX: a malformed or non-JSON body previously raised inside
    # request.json() and surfaced as an unhandled 500.
    try:
        data = await request.json()
    except Exception:
        return {"reply": "Invalid JSON body."}
    # FIX: tolerate non-string prompt values (null, numbers, ...) —
    # .strip() on a non-str used to crash the handler.
    user_prompt = str(data.get("prompt") or "").strip()
    if not user_prompt:
        return {"reply": "No prompt provided."}
    # FIX: explicit answer anchor
    full_prompt = SYSTEM_PROMPT + "\n\nUser input:\n" + user_prompt + "\n\nResponse:\n"
    inputs = tokenizer(full_prompt, return_tensors="pt").to(DEVICE)
    # FIX: no_grad avoids building an autograd graph during generation
    # (memory + speed). temperature=0.0 removed — it is ignored (and
    # warned about by transformers) when do_sample=False. pad_token_id
    # set explicitly to silence the missing-pad-token warning.
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=MAX_TOKENS,
            do_sample=False,
            repetition_penalty=1.1,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.eos_token_id,
        )
    # FIX: decode only the newly generated tokens instead of splitting the
    # full decoded text on "Response:", which broke whenever the user's
    # prompt itself contained that marker.
    new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
    text = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
    reply = _extract_reply(text)
    # FIX: never return an empty reply.
    if not reply:
        reply = "I could not generate a response."
    return {"reply": reply}
# -------------------------------
# RUN SERVER
# -------------------------------
# Bind to all interfaces; 7860 is the conventional Hugging Face Spaces port.
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)