# app.py
#
# Minimal single-turn inference server: loads a causal LM once at startup and
# exposes POST /api/ask, which answers one prompt and returns {"reply": str}.
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
import uvicorn
import json
import re

# -------------------------------
# SETTINGS
# -------------------------------
MODEL_NAME = "TheDrummer/Gemmasutra-Mini-2B-v1"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MAX_TOKENS = 256

# FIX: make JSON preferred, not fragile
SYSTEM_PROMPT = """You are Acla, an AI assistant created by NC_1320.
You answer the user's question once and stop.
Do not write User:, AI:, or continue a conversation.
Prefer responding in valid JSON exactly like:
{"response":"your answer here"}
If JSON is not possible, respond with plain text only.
"""

# Compiled once at import time — avoids re-building the pattern per request.
# Non-greedy: grabs the first {...} span in the model output.
_JSON_RE = re.compile(r"\{[\s\S]*?\}")

# -------------------------------
# LOAD MODEL
# -------------------------------
print(f"Loading {MODEL_NAME} on {DEVICE}...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32
).to(DEVICE)
print("Model loaded!")

# -------------------------------
# CREATE API
# -------------------------------
app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.post("/api/ask")
async def ask_ai(request: Request):
    """Answer a single user prompt.

    Expects a JSON body ``{"prompt": "..."}``; always returns
    ``{"reply": "..."}`` with a non-empty string.
    """
    # FIX: a malformed or missing JSON body previously raised an unhandled
    # exception (HTTP 500); treat it the same as an empty prompt.
    try:
        data = await request.json()
    except Exception:
        data = {}

    # FIX: a non-string "prompt" (e.g. a number) used to crash on .strip().
    user_prompt = data.get("prompt", "")
    if not isinstance(user_prompt, str):
        user_prompt = ""
    user_prompt = user_prompt.strip()
    if not user_prompt:
        return {"reply": "No prompt provided."}

    # FIX: explicit answer anchor
    full_prompt = SYSTEM_PROMPT + "\n\nUser input:\n" + user_prompt + "\n\nResponse:\n"

    inputs = tokenizer(full_prompt, return_tensors="pt").to(DEVICE)

    # inference_mode: no autograd bookkeeping for a pure-inference call.
    with torch.inference_mode():
        outputs = model.generate(
            **inputs,
            max_new_tokens=MAX_TOKENS,
            do_sample=False,
            # FIX: temperature removed — it is a sampling-only parameter and
            # passing temperature=0.0 with do_sample=False is invalid/warned
            # in recent transformers releases (greedy decoding ignores it).
            repetition_penalty=1.1,
            eos_token_id=tokenizer.eos_token_id,
            # Silences the "pad_token_id not set" warning on each request.
            pad_token_id=tokenizer.pad_token_id
            if tokenizer.pad_token_id is not None
            else tokenizer.eos_token_id,
        )

    # FIX: decode only the newly generated tokens instead of string-splitting
    # on "Response:", which broke whenever the prompt or the model output
    # contained that marker itself.
    prompt_len = inputs["input_ids"].shape[1]
    text = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True).strip()

    # Try JSON first
    reply = ""
    match = _JSON_RE.search(text)
    if match:
        try:
            parsed = json.loads(match.group(0))
            value = parsed.get("response", "")
            # FIX: guard against a non-string "response" value, which
            # previously crashed .strip() and was silently swallowed.
            if isinstance(value, str):
                reply = value.strip()
        except Exception:
            reply = ""

    # FIX: plain-text fallback — cut at any conversation-continuation marker.
    if not reply:
        for stop in ["User:", "AI:", "Assistant:"]:
            text = text.split(stop)[0]
        reply = text.strip()

    # FIX: never empty
    if not reply:
        reply = "I could not generate a response."

    return {"reply": reply}


# -------------------------------
# RUN SERVER
# -------------------------------
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)