File size: 2,829 Bytes
ff0a624
cb656b4
ff0a624
 
 
 
a9ff87c
a450cd9
ca791bd
ff0a624
 
 
cb656b4
ff0a624
bfc4c59
 
 
 
 
 
 
 
 
 
 
a9ff87c
ca791bd
ff0a624
 
 
 
 
 
 
 
 
a9ff87c
806c29f
ff0a624
 
 
 
cb656b4
ff0a624
 
a9ff87c
ff0a624
 
 
6b1c5b7
ff0a624
 
 
a9ff87c
 
ff0a624
a9ff87c
bfc4c59
 
a9ff87c
ff0a624
a9ff87c
ff0a624
 
a9ff87c
ff0a624
 
 
 
54536f6
a9ff87c
ff0a624
a9ff87c
bfc4c59
 
 
 
 
a450cd9
 
 
bfc4c59
a450cd9
 
 
 
a9ff87c
bfc4c59
 
 
 
 
 
 
a450cd9
bfc4c59
fec64bd
a450cd9
6b1c5b7
ff0a624
 
 
6b1c5b7
a450cd9
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
# app.py
import json
import re

import torch
import uvicorn
from fastapi import FastAPI, Request
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from transformers import AutoModelForCausalLM, AutoTokenizer

# -------------------------------
# SETTINGS
# -------------------------------
# Hugging Face Hub repository id of the causal language model to serve.
MODEL_NAME = "TheDrummer/Gemmasutra-Mini-2B-v1"

# Prefer the GPU whenever PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Cap on newly generated tokens per request (forwarded as max_new_tokens).
MAX_TOKENS = 256

# Instructions prepended to every user prompt. JSON output is requested but
# not required: the /api/ask handler falls back to the model's plain text
# when no usable JSON "response" field is found.
SYSTEM_PROMPT = """You are Acla, an AI assistant created by NC_1320.
You answer the user's question once and stop.
Do not write User:, AI:, or continue a conversation.

Prefer responding in valid JSON exactly like:
{"response":"your answer here"}

If JSON is not possible, respond with plain text only.
"""

# -------------------------------
# LOAD MODEL
# -------------------------------
# Tokenizer and weights are loaded once, at import time, and then shared by
# every request handled by this process.
print(f"Loading {MODEL_NAME} on {DEVICE}...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Half precision on the GPU, full precision on the CPU.
# NOTE(review): this checkpoint appears to be Gemma-2-based, and Gemma 2 is
# generally recommended to run in bfloat16 rather than float16 — confirm that
# float16 does not produce NaNs/garbage output on the target GPU.
if DEVICE == "cuda":
    _load_dtype = torch.float16
else:
    _load_dtype = torch.float32

model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=_load_dtype)
model = model.to(DEVICE)
print("Model loaded!")

# -------------------------------
# CREATE API
# -------------------------------
app = FastAPI()

# Fully permissive CORS: browsers may call the API from any origin, with any
# HTTP method and any request header. allow_credentials is not passed, so it
# keeps CORSMiddleware's default.
_cors_settings = {
    "allow_origins": ["*"],
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_settings)

# Markers at which the plain-text fallback truncates the model output, for
# when the model starts writing a transcript instead of stopping. "User input:"
# is the label the prompt itself uses, so it is a likely continuation.
_STOP_MARKERS = ("User:", "AI:", "Assistant:", "User input:")


def _build_prompt(user_prompt: str) -> str:
    """Return the full model prompt: system prompt, user input, answer anchor."""
    return SYSTEM_PROMPT + "\n\nUser input:\n" + user_prompt + "\n\nResponse:\n"


def _generate(full_prompt: str) -> str:
    """Greedy-decode a continuation of *full_prompt* and return only the new text.

    Blocking: tokenizes and runs ``model.generate`` synchronously, so callers
    on the event loop should run it in a worker thread.
    """
    inputs = tokenizer(full_prompt, return_tensors="pt").to(DEVICE)
    outputs = model.generate(
        **inputs,
        max_new_tokens=MAX_TOKENS,
        # Greedy decoding. No temperature is passed: it only affects sampling,
        # and temperature=0.0 alongside do_sample=False just raises a warning.
        do_sample=False,
        repetition_penalty=1.1,
        eos_token_id=tokenizer.eos_token_id,
    )
    # The returned sequence echoes the prompt tokens first; decode only what
    # follows them. This stays correct even when the user's prompt itself
    # contains "Response:" (splitting the decoded text on it does not).
    prompt_length = inputs["input_ids"].shape[-1]
    new_tokens = outputs[0][prompt_length:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()


def _json_reply(text: str) -> str:
    """Return the stripped ``"response"`` string of the first JSON object in *text*.

    Tries ``json.JSONDecoder.raw_decode`` from each ``{`` in turn, so braces
    inside the answer string (e.g. ``{"response": "use {x}"}``) do not cut the
    object short. Returns "" when no object decodes, or when the first one
    that does lacks a string ``"response"`` field.
    """
    decoder = json.JSONDecoder()
    start = text.find("{")
    while start != -1:
        try:
            parsed, _ = decoder.raw_decode(text, start)
        except json.JSONDecodeError:
            start = text.find("{", start + 1)
            continue
        value = parsed.get("response") if isinstance(parsed, dict) else None
        return value.strip() if isinstance(value, str) else ""
    return ""


def _plain_text_reply(text: str) -> str:
    """Return *text* cut at the earliest stop marker, stripped."""
    for marker in _STOP_MARKERS:
        text = text.split(marker, 1)[0]
    return text.strip()


@app.post("/api/ask")
async def ask_ai(request: Request):
    """Answer a single prompt with the loaded model.

    Expects a JSON body such as ``{"prompt": "..."}`` and always responds with
    ``{"reply": <str>}``. A missing, blank, non-string or unparseable prompt
    yields ``"No prompt provided."``. Otherwise the reply is the model's JSON
    ``"response"`` field when present, else its plain-text output cut at the
    first stop marker, and it is never empty.
    """
    try:
        data = await request.json()
    except ValueError:
        # Malformed JSON or undecodable bytes: answer like a missing prompt
        # instead of failing with a 500.
        data = None
    prompt = data.get("prompt") if isinstance(data, dict) else None
    user_prompt = prompt.strip() if isinstance(prompt, str) else ""
    if not user_prompt:
        return {"reply": "No prompt provided."}

    # model.generate is synchronous and slow; running it in a worker thread
    # keeps the event loop free to serve other requests meanwhile.
    # NOTE(review): concurrent requests now run generate() in parallel on the
    # shared model; add a lock/queue if GPU memory becomes a constraint.
    text = await run_in_threadpool(_generate, _build_prompt(user_prompt))

    reply = _json_reply(text) or _plain_text_reply(text)
    # Never hand the client an empty string.
    return {"reply": reply or "I could not generate a response."}

# -------------------------------
# RUN SERVER
# -------------------------------
def _serve() -> None:
    """Run the API under uvicorn, listening on all interfaces at port 7860."""
    uvicorn.run(app, host="0.0.0.0", port=7860)


if __name__ == "__main__":
    _serve()