# Hugging Face Hub page residue (commented out so the module parses):
# Valtry — "Update app.py" — commit c0d5385 (verified)
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from llama_cpp import Llama
from huggingface_hub import hf_hub_download
from supabase import create_client
import os, json, uvicorn, threading
from contextlib import asynccontextmanager
from fastapi import UploadFile, File, Form
from PIL import Image
import torch
import base64
from io import BytesIO
from transformers import BlipProcessor, BlipForConditionalGeneration
# =========================
# CONFIG
# =========================
HF_TOKEN = os.getenv("HF_TOKEN")  # Hugging Face token used for hf_hub_download
SUPABASE_URL = os.getenv("SUPABASE_URL")
SUPABASE_KEY = os.getenv("SUPABASE_KEY")
# NOTE(review): presumably create_client fails at import time if the env vars
# are unset — verify deployment always provides SUPABASE_URL / SUPABASE_KEY.
supabase = create_client(SUPABASE_URL, SUPABASE_KEY)
# In-memory, single-process state:
models = {}       # model key ("llama"/"coder") -> llama_cpp.Llama, filled by load_models()
stop_flags = {}   # conversation_id -> True when a client hit /v1/stop
vision_model = None      # BLIP caption model, loaded in lifespan()
vision_processor = None  # BLIP processor, loaded in lifespan()
# =========================
# REQUEST
# =========================
class ChatRequest(BaseModel):
    """Request body shared by the /v1/chat/* endpoints."""

    user_id: str
    conversation_id: str
    messages: list                # chat messages as {"role": ..., "content": ...} dicts
    temperature: float = 0.7      # sampling temperature (overridden for the coder model)
    stream: bool = False          # True -> server-sent-events streaming response
    branch: bool = False          # True -> save the reply as a branch under parent_id
    parent_id: str | None = None  # parent message id; used when branch=True
# =========================
# SYSTEM PROMPTS
# =========================
# General-purpose assistant persona for the Llama model.
LLAMA_PROMPT = """You are a friendly, capable, and reliable AI assistant.
- Answer clearly and accurately
- Use clean structure
- Avoid repetition
- Finish naturally
"""
# πŸ”₯ STRONG CODER PROMPT (ANTI-HALLUCINATION)
CODER_PROMPT = """You are a strict and precise programming assistant.
Rules:
- Answer ONLY the user's latest question
- DO NOT continue previous conversations
- DO NOT invent new questions
- DO NOT simulate dialogue (no "Human:" or "Assistant:")
- DO NOT repeat yourself
- If code is needed β†’ return clean code in ``` blocks
- If explanation is needed β†’ keep it short and relevant
- Stop immediately after completing the answer
"""
# Maps the model keys used by handle_chat()/build_prompt() to their system prompts.
MODEL_PROMPTS = {
    "llama": LLAMA_PROMPT,
    "coder": CODER_PROMPT,
}
# =========================
# CLEAN OUTPUT
# =========================
def clean_output(text):
    """Truncate *text* at the earliest stop marker and strip whitespace.

    Stop markers are llama-3 special tokens plus dialogue-simulation
    prefixes; everything from the first marker onward is discarded.
    """
    markers = (
        "<|eot_id|>",
        "<|end_of_text|>",
        "<|eof|>",
        "Human:",
        "Assistant:",
        "User:",
    )
    # Cutting at the leftmost marker is equivalent to splitting on each
    # marker in sequence: each split only ever shortens the prefix.
    hits = [pos for pos in (text.find(m) for m in markers) if pos != -1]
    if hits:
        text = text[:min(hits)]
    return text.strip()
def process_image(image: Image.Image, user_prompt: str):
    """Caption *image* with BLIP and append the user's focus prompt.

    Returns a single text line combining the generated caption with
    *user_prompt* so the downstream text model sees both.
    """
    batch = vision_processor(images=image, return_tensors="pt")
    token_ids = vision_model.generate(**batch, max_new_tokens=50)
    description = vision_processor.decode(token_ids[0], skip_special_tokens=True)
    # πŸ”₯ combine with user prompt (VERY IMPORTANT)
    return f"{description}. User focus: {user_prompt}"
# =========================
# CHAT STORAGE
# =========================
def get_messages(cid):
    """Return the ordered role/content history rows for conversation *cid*."""
    query = (
        supabase.table("messages")
        .select("role,content")
        .eq("conversation_id", cid)
        .order("created_at")
    )
    rows = query.execute()
    return rows.data or []
def save_message(cid, role, content, parent_id=None, branch_id=None):
    """Insert one chat message row for conversation *cid* into Supabase."""
    record = {
        "conversation_id": cid,
        "role": role,
        "content": content,
        "parent_id": parent_id,
        "branch_id": branch_id,
    }
    supabase.table("messages").insert(record).execute()
# =========================
# PROMPT BUILDER
# =========================
def build_prompt(messages, user_id, cid, model_name):
    """Assemble the raw completion prompt for *model_name*.

    Combines the model's system prompt, the stored conversation history
    (text entries only), and the incoming *messages*.

    Args:
        messages: Incoming chat messages ({"role": ..., "content": ...} dicts).
        user_id: Kept for interface compatibility; not used here.
        cid: Conversation id whose stored history is fetched.
        model_name: Key into MODEL_PROMPTS ("llama" or "coder").

    Returns:
        The full prompt string to pass to the GGUF model.
    """
    base = MODEL_PROMPTS[model_name]
    # βœ… fetch history
    history = get_messages(cid)
    # =========================
    # πŸ”₯ FILTER BAD HISTORY: drop image uploads (base64 payloads) so they
    # never leak into the text prompt; unwrap structured text envelopes.
    # =========================
    filtered_history = []
    for msg in history:
        try:
            content = json.loads(msg["content"])
            # ❌ skip image uploads (BLIP / user image)
            if isinstance(content, dict) and content.get("type") == "image":
                continue
            # ❌ skip if marked from image model
            if isinstance(content, dict) and content.get("model") == "user_upload":
                continue
            # βœ… keep only text messages
            if isinstance(content, dict) and content.get("type") == "text":
                filtered_history.append({
                    "role": msg["role"],
                    "content": content["data"]
                })
            else:
                # fallback (JSON that is not the structured envelope)
                filtered_history.append(msg)
        except (TypeError, ValueError, KeyError):
            # TypeError/ValueError: old plain-text rows that are not JSON;
            # KeyError: malformed envelope missing "content"/"data".
            # (Was a bare `except:` which also swallowed SystemExit etc.)
            filtered_history.append(msg)
    # limit history to the last six entries (prompt must fit n_ctx=2048)
    filtered_history = filtered_history[-6:]
    # =========================
    # LLAMA FORMAT: llama-3 chat template with header/eot special tokens
    # =========================
    if model_name == "llama":
        prompt = "<|begin_of_text|>\n"
        prompt += "<|start_header_id|>system<|end_header_id|>\n"
        prompt += base + "\n<|eot_id|>\n"
        for msg in (filtered_history + messages):
            prompt += f"<|start_header_id|>{msg['role']}<|end_header_id|>\n{msg['content']}\n<|eot_id|>\n"
        prompt += "<|start_header_id|>assistant<|end_header_id|>\n"
        return prompt
    # =========================
    # CODER: single-turn — only the latest user message is forwarded
    # =========================
    last_user = ""
    for m in reversed(messages):
        if m["role"] == "user":
            last_user = m["content"]
            break
    return f"""System:
{base}
User:
{last_user}
Assistant:
"""
# =========================
# MODEL LOADING
# =========================
def load_model(repo, file, optimized=False):
    """Download a GGUF file from the Hub and wrap it in a Llama instance.

    Args:
        repo: Hugging Face repo id containing the GGUF file.
        file: Quantized model filename inside the repo.
        optimized: When True, use the lower-thread / larger-batch profile
            (n_threads=3, n_batch=1024); otherwise n_threads=4, n_batch=512.

    Returns:
        A configured llama_cpp.Llama instance.
    """
    # The two original code paths differed only in thread/batch counts, so
    # the duplicated constructor call is collapsed into one.
    n_threads, n_batch = (3, 1024) if optimized else (4, 512)
    return Llama(
        model_path=hf_hub_download(repo_id=repo, filename=file, token=HF_TOKEN, cache_dir="/data"),
        n_ctx=2048,
        n_threads=n_threads,
        n_batch=n_batch,
        use_mmap=True,
        use_mlock=True,
        f16_kv=True,
        verbose=False
    )
def load_models():
    """Populate the global *models* registry with both chat models."""
    specs = {
        "llama": ("Valtry/llama3.2-3b-q4-gguf", "llama3.2-3b-q4.gguf", False),
        "coder": ("Qwen/Qwen2.5-Coder-1.5B-Instruct-GGUF",
                  "qwen2.5-coder-1.5b-instruct-q4_k_m.gguf", True),
    }
    for name, (repo, filename, optimized) in specs.items():
        models[name] = load_model(repo, filename, optimized)
# =========================
# APP
# =========================
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook: load every model once before serving.

    Loads both GGUF chat models and the BLIP captioning model/processor
    into module globals; nothing is torn down after the yield.
    """
    load_models()
    global vision_model, vision_processor
    vision_processor = BlipProcessor.from_pretrained(
        "Salesforce/blip-image-captioning-base"
    )
    vision_model = BlipForConditionalGeneration.from_pretrained(
        "Salesforce/blip-image-captioning-base"
    )
    yield
app = FastAPI(lifespan=lifespan)
# NOTE(review): allow_origins=["*"] permits any origin — tighten for production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
# =========================
# STOP
# =========================
@app.post("/v1/stop")
def stop(data: dict):
    """Flag a conversation so any in-flight generation loop breaks out."""
    conversation_id = data.get("conversation_id")
    stop_flags[conversation_id] = True
    return {"status": "stopped"}
def get_next_branch(parent_id):
    """Return the next unused branch number under *parent_id* (1-based)."""
    res = (
        supabase.table("messages")
        .select("branch_id")
        .eq("parent_id", parent_id)
        .execute()
    )
    # Falsy branch ids (None/0) are ignored, matching a 1-based numbering.
    taken = [row["branch_id"] for row in res.data if row.get("branch_id")]
    highest = max(taken, default=0)
    return highest + 1
# =========================
# HANDLER
# =========================
def handle_chat(model_name, req: ChatRequest):
    """Run one chat completion against the selected local GGUF model.

    Builds the prompt from stored history plus the incoming messages, then
    either streams SSE token deltas or returns a single JSON payload.
    Either way, the exchange is persisted to Supabase on a background thread.
    """
    llm = models[model_name]
    prompt = build_prompt(req.messages, req.user_id, req.conversation_id, model_name)
    # πŸ”₯ tuned for coder stability: lower temperature + stronger repeat penalty
    if model_name == "coder":
        temp, rp, tp = 0.3, 1.35, 0.85
        max_tokens = 2048
    else:
        temp, rp, tp = req.temperature, 1.15, 0.9
        max_tokens = 2048
    if req.stream:
        def generate():
            # Accumulates the full completion so it can be cleaned and saved.
            output = ""
            stream = llm(
                prompt,
                max_tokens=max_tokens,
                temperature=temp,
                top_p=tp,
                repeat_penalty=rp,
                stop=[
                    "<|eot_id|>",
                    "<|end_of_text|>",
                    "<|eof|>",
                    "\nUser:",
                    "\nHuman:",
                    "\nAssistant:"
                ],
                stream=True
            )
            for chunk in stream:
                # One-shot cancellation flag set via /v1/stop.
                if stop_flags.get(req.conversation_id):
                    stop_flags[req.conversation_id] = False
                    break
                token = chunk["choices"][0]["text"]
                output += token
                # OpenAI-style SSE delta frame.
                yield f"data: {json.dumps({'choices':[{'delta':{'content':token}}]})}\n\n"
            output_clean = clean_output(output)
            yield "event: done\ndata: {}\n\n"
            yield "data: [DONE]\n\n"
            def save_async():
                # πŸ”₯ NORMAL MODE (NO BRANCH): save the user turn(s) + reply
                if not req.branch:
                    for msg in req.messages:
                        save_message(req.conversation_id, msg["role"], msg["content"])
                    save_message(
                        req.conversation_id,
                        "assistant",
                        output_clean
                    )
                # πŸ”₯ BRANCH MODE: only a new assistant variation is stored
                else:
                    branch_id = get_next_branch(req.parent_id)
                    save_message(
                        req.conversation_id,
                        "assistant",
                        output_clean,
                        parent_id=req.parent_id,
                        branch_id=branch_id
                    )
            # Persist off-thread so the stream consumer is not blocked.
            threading.Thread(target=save_async).start()
        return StreamingResponse(generate(), media_type="text/event-stream")
    # Non-streaming path: one blocking completion.
    output = llm(
        prompt,
        max_tokens=max_tokens,
        temperature=temp,
        top_p=tp,
        repeat_penalty=rp,
        stop=[
            "<|eot_id|>",
            "<|end_of_text|>",
            "<|eof|>",
            "\nUser:",
            "\nHuman:",
            "\nAssistant:"
        ]
    )
    text = clean_output(output["choices"][0]["text"])
    def save_non_stream():
        if not req.branch:
            for m in req.messages:
                save_message(req.conversation_id, m["role"], m["content"])
            save_message(
                req.conversation_id,
                "assistant",
                text
            )
        else:
            branch_id = get_next_branch(req.parent_id)
            save_message(
                req.conversation_id,
                "assistant",
                text,
                parent_id=req.parent_id,
                branch_id=branch_id
            )
    threading.Thread(target=save_non_stream).start()
    # OpenAI-compatible response shape.
    return {
        "choices":[{"message":{"role":"assistant","content":text}}],
        "done":True
    }
# =========================
# ROUTES
# =========================
@app.post("/v1/chat/llama")
async def llama(req: ChatRequest):
    """Chat endpoint backed by the general-purpose Llama 3.2 model."""
    return handle_chat("llama", req)
@app.post("/v1/chat/coder")
async def coder(req: ChatRequest):
    """Chat endpoint backed by the Qwen2.5 coder model."""
    return handle_chat("coder", req)
@app.post("/v1/feedback")
def feedback(data: dict):
    """Attach user feedback to a stored message.

    Best-effort: failures are returned as an error payload, never raised.
    """
    try:
        (
            supabase.table("messages")
            .update({"feedback": data.get("feedback")})
            .eq("id", data.get("message_id"))
            .execute()
        )
        return {"status": "saved"}
    except Exception as e:
        return {"error": str(e)}
@app.post("/v1/chat/image/stream")
async def image_chat_stream(
    user_id: str = Form(...),
    conversation_id: str = Form(...),
    prompt: str = Form(...),
    file: UploadFile = File(...),
    branch: bool = Form(False),
    parent_id: str = Form(None)
):
    """Two-stage image chat: BLIP captions the upload, Llama answers about it.

    Streams SSE events in order: a `vision_done` event once captioning
    finishes, then OpenAI-style token deltas, then `done` / `[DONE]`.
    The exchange is persisted to Supabase on a background thread.
    """
    def generate():
        try:
            image = Image.open(file.file).convert("RGB")
            # πŸ”₯ convert image β†’ base64 so the upload can be stored inline in Supabase
            buffer = BytesIO()
            image.save(buffer, format="PNG")
            img_base64 = base64.b64encode(buffer.getvalue()).decode()
            # πŸ”₯ STEP 1: IMAGE ANALYSIS (BLIP caption combined with user focus)
            vision_output = process_image(image, prompt)
            # πŸ”₯ notify frontend that captioning is done
            yield f"event: vision_done\ndata: {json.dumps({'status': 'done'})}\n\n"
            # πŸ”₯ STEP 2: SEND TO LLAMA — caption is embedded in the text prompt
            enhanced_prompt = f"""
User question:
{prompt}
Image understanding:
{vision_output}
Answer clearly based on the image.
"""
            req = ChatRequest(
                user_id=user_id,
                conversation_id=conversation_id,
                messages=[{"role": "user", "content": enhanced_prompt}],
                stream=True
            )
            llm = models["llama"]
            final_prompt = build_prompt(
                req.messages,
                req.user_id,
                req.conversation_id,
                "llama"
            )
            stream = llm(
                final_prompt,
                max_tokens=2048,
                temperature=0.7,
                top_p=0.9,
                repeat_penalty=1.15,
                stop=[
                    "<|eot_id|>",
                    "<|end_of_text|>",
                    "<|eof|>"
                ],
                stream=True
            )
            output = ""
            for chunk in stream:
                # One-shot cancellation flag set via /v1/stop.
                if stop_flags.get(conversation_id):
                    stop_flags[conversation_id] = False
                    break
                token = chunk["choices"][0]["text"]
                output += token
                yield f"data: {json.dumps({'choices':[{'delta':{'content':token}}]})}\n\n"
            output_clean = clean_output(output)
            yield "event: done\ndata: {}\n\n"
            yield "data: [DONE]\n\n"
            # =========================
            # πŸ”₯ SAVE UNIFIED FORMAT (JSON envelopes: "image" vs "text")
            # =========================
            def save_async():
                # =========================
                # 🟒 NORMAL MODE: store both the image upload and the reply
                # =========================
                if not branch:
                    # save user image
                    save_message(
                        conversation_id,
                        "user",
                        json.dumps({
                            "type": "image",
                            "data": img_base64,
                            "prompt": prompt,
                            "model": "user_upload"
                        })
                    )
                    # save assistant response
                    save_message(
                        conversation_id,
                        "assistant",
                        json.dumps({
                            "type": "text",
                            "data": output_clean
                        })
                    )
                # =========================
                # πŸ”΅ BRANCH MODE
                # =========================
                else:
                    branch_id = get_next_branch(parent_id)
                    # ❗ DO NOT save user again — only the assistant variation
                    save_message(
                        conversation_id,
                        "assistant",
                        json.dumps({
                            "type": "text",
                            "data": output_clean
                        }),
                        parent_id=parent_id,
                        branch_id=branch_id
                    )
            # Persist off-thread so the stream consumer is not blocked.
            threading.Thread(target=save_async).start()
        except Exception as e:
            # Surface any failure to the client as a final SSE error frame.
            yield f"data: {json.dumps({'error': str(e)})}\n\n"
    return StreamingResponse(generate(), media_type="text/event-stream")
@app.get("/")
def root():
    """Health-check endpoint."""
    # The literal had a pointless f-prefix (no placeholders, ruff F541);
    # the string bytes are unchanged.
    return {"status": "Llama API runningπŸš€"}
# =========================
# RUN
# =========================
if __name__ == "__main__":
    # Port 7860 is the conventional exposed port on Hugging Face Spaces.
    uvicorn.run("app:app", host="0.0.0.0", port=7860)