# NOTE: "Spaces: Running" status header from the Hugging Face Space page removed (non-code residue).
| from fastapi import FastAPI | |
| from fastapi.responses import StreamingResponse | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel | |
| from llama_cpp import Llama | |
| from huggingface_hub import hf_hub_download | |
| from supabase import create_client | |
| import os, json, uvicorn, threading | |
| from contextlib import asynccontextmanager | |
| from fastapi import UploadFile, File, Form | |
| from PIL import Image | |
| import torch | |
| import base64 | |
| from io import BytesIO | |
| from transformers import BlipProcessor, BlipForConditionalGeneration | |
# =========================
# CONFIG
# =========================
HF_TOKEN = os.getenv("HF_TOKEN")          # Hugging Face token for model downloads
SUPABASE_URL = os.getenv("SUPABASE_URL")  # Supabase project URL
SUPABASE_KEY = os.getenv("SUPABASE_KEY")  # Supabase API key

# Client used for all chat persistence (the "messages" table).
# NOTE(review): create_client fails at import time if the env vars are unset —
# confirm the deployment always provides them.
supabase = create_client(SUPABASE_URL, SUPABASE_KEY)

# Loaded llama.cpp models keyed by short name ("llama", "coder"); filled by load_models().
models = {}
# Per-conversation cancellation flags, checked inside the streaming loops and set by stop().
stop_flags = {}
# BLIP captioning model/processor; populated during app startup in lifespan().
vision_model = None
vision_processor = None
# =========================
# REQUEST
# =========================
class ChatRequest(BaseModel):
    """Request body for the chat endpoints (see llama/coder routes)."""
    user_id: str
    conversation_id: str          # messages are stored/fetched under this id
    messages: list                # OpenAI-style [{"role": ..., "content": ...}] dicts
    temperature: float = 0.7      # used by the "llama" model; coder pins its own sampling
    stream: bool = False          # True -> SSE StreamingResponse instead of JSON
    branch: bool = False          # True -> save the reply as a branch of parent_id
    parent_id: str | None = None  # message id being branched from (branch mode only)
# =========================
# SYSTEM PROMPTS
# =========================
# System prompt for the general-purpose "llama" model.
LLAMA_PROMPT = """You are a friendly, capable, and reliable AI assistant.
- Answer clearly and accurately
- Use clean structure
- Avoid repetition
- Finish naturally
"""

# 🔥 STRONG CODER PROMPT (ANTI-HALLUCINATION)
# Deliberately strict: the coder model gets a single-turn prompt (see
# build_prompt), so the rules below fight dialogue continuation/hallucination.
CODER_PROMPT = """You are a strict and precise programming assistant.
Rules:
- Answer ONLY the user's latest question
- DO NOT continue previous conversations
- DO NOT invent new questions
- DO NOT simulate dialogue (no "Human:" or "Assistant:")
- DO NOT repeat yourself
- If code is needed → return clean code in ``` blocks
- If explanation is needed → keep it short and relevant
- Stop immediately after completing the answer
"""

# Maps model short-name -> system prompt; keys must match the keys of `models`.
MODEL_PROMPTS = {
    "llama": LLAMA_PROMPT,
    "coder": CODER_PROMPT,
}
# =========================
# CLEAN OUTPUT
# =========================
def clean_output(text):
    """Truncate *text* at the first special token or speaker tag, then strip whitespace."""
    markers = (
        "<|eot_id|>",
        "<|end_of_text|>",
        "<|eof|>",
        "Human:",
        "Assistant:",
        "User:",
    )
    for marker in markers:
        # partition() is a no-op when the marker is absent (head == text).
        head, _, _ = text.partition(marker)
        text = head
    return text.strip()
def process_image(image: Image.Image, user_prompt: str):
    """Generate a BLIP caption for *image* and combine it with the user's prompt."""
    encoded = vision_processor(images=image, return_tensors="pt")
    generated_ids = vision_model.generate(**encoded, max_new_tokens=50)
    caption = vision_processor.decode(generated_ids[0], skip_special_tokens=True)
    # Attach the user's question so the downstream LLM knows what to focus on.
    return f"{caption}. User focus: {user_prompt}"
# =========================
# CHAT STORAGE
# =========================
def get_messages(cid):
    """Return all stored (role, content) rows for conversation *cid*, oldest first."""
    query = (
        supabase.table("messages")
        .select("role,content")
        .eq("conversation_id", cid)
        .order("created_at")
    )
    result = query.execute()
    return result.data or []
def save_message(cid, role, content, parent_id=None, branch_id=None):
    """Insert one chat message row for conversation *cid* into Supabase."""
    row = {
        "conversation_id": cid,
        "role": role,
        "content": content,
        "parent_id": parent_id,
        "branch_id": branch_id,
    }
    supabase.table("messages").insert(row).execute()
# =========================
# PROMPT BUILDER
# =========================
def build_prompt(messages, user_id, cid, model_name):
    """Build the raw completion prompt for *model_name*.

    Combines the model's system prompt, the filtered stored history for
    conversation *cid*, and the incoming *messages*.  ``user_id`` is
    currently unused but kept for interface stability.

    "llama" gets the full Llama-3 chat template with history; "coder"
    gets a single-turn prompt built from the latest user message only.
    """
    base = MODEL_PROMPTS[model_name]
    history = get_messages(cid)

    # Keep only plain-text turns: image uploads and BLIP artifacts stored by
    # the image pipeline would confuse the text models.
    filtered_history = []
    for msg in history:
        try:
            content = json.loads(msg["content"])
        except (TypeError, ValueError):
            # Legacy plain-text rows are not JSON — keep them as-is.
            # (Was a bare `except:`; narrowed to what json.loads can raise.)
            filtered_history.append(msg)
            continue
        if isinstance(content, dict):
            # Skip image uploads / anything produced by the image pipeline.
            if content.get("type") == "image" or content.get("model") == "user_upload":
                continue
            if content.get("type") == "text" and "data" in content:
                filtered_history.append({
                    "role": msg["role"],
                    "content": content["data"]
                })
                continue
        # Valid JSON but not a recognized structured message — keep the raw row.
        filtered_history.append(msg)

    # Limit context to the six most recent turns.
    filtered_history = filtered_history[-6:]

    if model_name == "llama":
        # Llama-3 chat template (special header/eot tokens).
        prompt = "<|begin_of_text|>\n"
        prompt += "<|start_header_id|>system<|end_header_id|>\n"
        prompt += base + "\n<|eot_id|>\n"
        for msg in (filtered_history + messages):
            prompt += f"<|start_header_id|>{msg['role']}<|end_header_id|>\n{msg['content']}\n<|eot_id|>\n"
        prompt += "<|start_header_id|>assistant<|end_header_id|>\n"
        return prompt

    # Coder: single-turn prompt using only the latest user message.
    last_user = ""
    for m in reversed(messages):
        if m["role"] == "user":
            last_user = m["content"]
            break
    return f"""System:
{base}
User:
{last_user}
Assistant:
"""
# =========================
# MODEL LOADING
# =========================
def load_model(repo, file, optimized=False):
    """Download *file* from HF repo *repo* and load it with llama.cpp.

    ``optimized=True`` trades one worker thread for a larger batch size
    (used for the smaller coder model).  Both variants share every other
    setting, so the previously duplicated Llama(...) construction is
    collapsed into one call.
    """
    model_path = hf_hub_download(
        repo_id=repo, filename=file, token=HF_TOKEN, cache_dir="/data"
    )
    n_threads, n_batch = (3, 1024) if optimized else (4, 512)
    return Llama(
        model_path=model_path,
        n_ctx=2048,
        n_threads=n_threads,
        n_batch=n_batch,
        use_mmap=True,
        use_mlock=True,
        f16_kv=True,
        verbose=False,
    )
def load_models():
    """Populate the global ``models`` registry with both GGUF models."""
    models.update({
        "llama": load_model(
            "Valtry/llama3.2-3b-q4-gguf",
            "llama3.2-3b-q4.gguf",
        ),
        "coder": load_model(
            "Qwen/Qwen2.5-Coder-1.5B-Instruct-GGUF",
            "qwen2.5-coder-1.5b-instruct-q4_k_m.gguf",
            True,
        ),
    })
# =========================
# APP
# =========================
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Startup hook: load the GGUF chat models and the BLIP vision model.

    FIX: the @asynccontextmanager decorator was missing — FastAPI requires
    it for a generator-style lifespan (the import at the top of the file
    was present but unused, so it was evidently dropped by accident).
    """
    global vision_model, vision_processor
    load_models()
    vision_processor = BlipProcessor.from_pretrained(
        "Salesforce/blip-image-captioning-base"
    )
    vision_model = BlipForConditionalGeneration.from_pretrained(
        "Salesforce/blip-image-captioning-base"
    )
    yield  # app serves requests here; no teardown needed
# FastAPI app; model loading happens in the lifespan handler above.
app = FastAPI(lifespan=lifespan)
# Open CORS: the API is consumed from a separate frontend origin.
# NOTE(review): wildcard origins are fine for a demo Space; tighten for production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
# =========================
# STOP
# =========================
@app.post("/stop")  # NOTE(review): route decorator restored (none existed); confirm the path
def stop(data: dict):
    """Ask the in-flight stream for data["conversation_id"] to stop at its next token."""
    stop_flags[data.get("conversation_id")] = True
    return {"status": "stopped"}
def get_next_branch(parent_id):
    """Return the next free branch_id (1-based) among the children of *parent_id*."""
    res = (
        supabase.table("messages")
        .select("branch_id")
        .eq("parent_id", parent_id)
        .execute()
    )
    # `or []` guards against a missing data payload — matches get_messages().
    existing = [m["branch_id"] for m in (res.data or []) if m.get("branch_id")]
    return max(existing, default=0) + 1
# =========================
# HANDLER
# =========================
def handle_chat(model_name, req: ChatRequest):
    """Run one chat completion and persist it asynchronously.

    Returns a text/event-stream StreamingResponse when req.stream is set,
    otherwise an OpenAI-shaped JSON dict.  Persistence happens on a
    background thread in both paths; in branch mode only the assistant
    variation is stored (under req.parent_id).
    """
    llm = models[model_name]
    prompt = build_prompt(req.messages, req.user_id, req.conversation_id, model_name)

    # Sampling settings: coder is pinned to conservative values for
    # stability; llama follows the request temperature.
    if model_name == "coder":
        temp, rp, tp = 0.3, 1.35, 0.85
        max_tokens = 2048
    else:
        temp, rp, tp = req.temperature, 1.15, 0.9
        max_tokens = 2048

    if req.stream:
        def generate():
            output = ""
            stream = llm(
                prompt,
                max_tokens=max_tokens,
                temperature=temp,
                top_p=tp,
                repeat_penalty=rp,
                stop=[
                    "<|eot_id|>",
                    "<|end_of_text|>",
                    "<|eof|>",
                    "\nUser:",
                    "\nHuman:",
                    "\nAssistant:"
                ],
                stream=True
            )
            for chunk in stream:
                # Cooperative cancellation: the stop() route flips this flag.
                if stop_flags.get(req.conversation_id):
                    stop_flags[req.conversation_id] = False
                    break
                token = chunk["choices"][0]["text"]
                output += token
                yield f"data: {json.dumps({'choices':[{'delta':{'content':token}}]})}\n\n"
            output_clean = clean_output(output)
            yield "event: done\ndata: {}\n\n"
            yield "data: [DONE]\n\n"

            # Persist after the stream finished, off the response path.
            # save_async closes over output_clean, so it must live inside generate().
            def save_async():
                if not req.branch:
                    # Normal mode: store the incoming user turn(s) and the reply.
                    for msg in req.messages:
                        save_message(req.conversation_id, msg["role"], msg["content"])
                    save_message(
                        req.conversation_id,
                        "assistant",
                        output_clean
                    )
                else:
                    # Branch mode: store only the assistant variation.
                    branch_id = get_next_branch(req.parent_id)
                    save_message(
                        req.conversation_id,
                        "assistant",
                        output_clean,
                        parent_id=req.parent_id,
                        branch_id=branch_id
                    )
            threading.Thread(target=save_async).start()
        return StreamingResponse(generate(), media_type="text/event-stream")

    # Non-streaming path: one blocking completion.
    output = llm(
        prompt,
        max_tokens=max_tokens,
        temperature=temp,
        top_p=tp,
        repeat_penalty=rp,
        stop=[
            "<|eot_id|>",
            "<|end_of_text|>",
            "<|eof|>",
            "\nUser:",
            "\nHuman:",
            "\nAssistant:"
        ]
    )
    text = clean_output(output["choices"][0]["text"])

    # Same persistence rules as the streaming path.
    def save_non_stream():
        if not req.branch:
            for m in req.messages:
                save_message(req.conversation_id, m["role"], m["content"])
            save_message(
                req.conversation_id,
                "assistant",
                text
            )
        else:
            branch_id = get_next_branch(req.parent_id)
            save_message(
                req.conversation_id,
                "assistant",
                text,
                parent_id=req.parent_id,
                branch_id=branch_id
            )
    threading.Thread(target=save_non_stream).start()

    return {
        "choices":[{"message":{"role":"assistant","content":text}}],
        "done":True
    }
# =========================
# ROUTES
# =========================
@app.post("/llama")  # NOTE(review): route decorator restored (none existed); confirm the path
async def llama(req: ChatRequest):
    """Chat endpoint backed by the general-purpose Llama model."""
    return handle_chat("llama", req)
@app.post("/coder")  # NOTE(review): route decorator restored (none existed); confirm the path
async def coder(req: ChatRequest):
    """Chat endpoint backed by the Qwen coder model."""
    return handle_chat("coder", req)
@app.post("/feedback")  # NOTE(review): route decorator restored (none existed); confirm the path
def feedback(data: dict):
    """Attach feedback to a stored message; best-effort, never raises."""
    try:
        supabase.table("messages").update({
            "feedback": data.get("feedback")
        }).eq("id", data.get("message_id")).execute()
        return {"status": "saved"}
    except Exception as e:
        # Deliberate boundary: surface the failure to the client instead of a 500.
        return {"error": str(e)}
@app.post("/image-chat-stream")  # NOTE(review): route decorator restored (none existed); confirm the path
async def image_chat_stream(
    user_id: str = Form(...),
    conversation_id: str = Form(...),
    prompt: str = Form(...),
    file: UploadFile = File(...),
    branch: bool = Form(False),
    parent_id: str = Form(None)
):
    """Caption an uploaded image with BLIP, answer with Llama, stream via SSE.

    Emits a `vision_done` SSE event once captioning finishes, then streams
    LLM tokens; the image (base64 PNG) and the reply are persisted in a
    unified JSON format on a background thread.
    """
    def generate():
        try:
            image = Image.open(file.file).convert("RGB")

            # Keep a base64 PNG copy so the upload can be re-rendered later.
            buffer = BytesIO()
            image.save(buffer, format="PNG")
            img_base64 = base64.b64encode(buffer.getvalue()).decode()

            # STEP 1: image analysis (BLIP caption + user focus).
            vision_output = process_image(image, prompt)
            # Notify the frontend that vision processing finished.
            yield f"event: vision_done\ndata: {json.dumps({'status': 'done'})}\n\n"

            # STEP 2: hand the caption to the Llama model.
            enhanced_prompt = f"""
User question:
{prompt}
Image understanding:
{vision_output}
Answer clearly based on the image.
"""
            req = ChatRequest(
                user_id=user_id,
                conversation_id=conversation_id,
                messages=[{"role": "user", "content": enhanced_prompt}],
                stream=True
            )
            llm = models["llama"]
            final_prompt = build_prompt(
                req.messages,
                req.user_id,
                req.conversation_id,
                "llama"
            )
            stream = llm(
                final_prompt,
                max_tokens=2048,
                temperature=0.7,
                top_p=0.9,
                repeat_penalty=1.15,
                stop=[
                    "<|eot_id|>",
                    "<|end_of_text|>",
                    "<|eof|>"
                ],
                stream=True
            )
            output = ""
            for chunk in stream:
                # Cooperative cancellation via the stop() flag.
                if stop_flags.get(conversation_id):
                    stop_flags[conversation_id] = False
                    break
                token = chunk["choices"][0]["text"]
                output += token
                yield f"data: {json.dumps({'choices':[{'delta':{'content':token}}]})}\n\n"
            output_clean = clean_output(output)
            yield "event: done\ndata: {}\n\n"
            yield "data: [DONE]\n\n"

            # Persist in the unified JSON format, off the response path.
            def save_async():
                if not branch:
                    # Normal mode: store the user's image, then the reply.
                    save_message(
                        conversation_id,
                        "user",
                        json.dumps({
                            "type": "image",
                            "data": img_base64,
                            "prompt": prompt,
                            "model": "user_upload"
                        })
                    )
                    save_message(
                        conversation_id,
                        "assistant",
                        json.dumps({
                            "type": "text",
                            "data": output_clean
                        })
                    )
                else:
                    # Branch mode: do NOT re-save the user turn — only the
                    # assistant variation under parent_id.
                    branch_id = get_next_branch(parent_id)
                    save_message(
                        conversation_id,
                        "assistant",
                        json.dumps({
                            "type": "text",
                            "data": output_clean
                        }),
                        parent_id=parent_id,
                        branch_id=branch_id
                    )
            threading.Thread(target=save_async).start()
        except Exception as e:
            # Surface errors to the SSE client instead of silently dropping the stream.
            yield f"data: {json.dumps({'error': str(e)})}\n\n"
    return StreamingResponse(generate(), media_type="text/event-stream")
@app.get("/")  # NOTE(review): route decorator restored (none existed); confirm the path
def root():
    """Health-check endpoint."""
    # NOTE(review): trailing emoji restored from mojibake in the scraped source — confirm.
    return {"status": "Llama API running🚀"}
# =========================
# RUN
# =========================
if __name__ == "__main__":
    # Port 7860 — presumably the Hugging Face Spaces default port; confirm for other hosts.
    uvicorn.run("app:app", host="0.0.0.0", port=7860)