# YAH_Tech_Ai / app.py — Hugging Face Space (commit d8f5f0b, author: Adedoyinjames)
# --------------------------------------------------------------
# app.py – A self‑contained Gradio + FastAPI chatbot
# --------------------------------------------------------------
import os
import threading
import torch
import gradio as gr
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer
# ------------------- 1️⃣ GLOBAL SETTINGS ----------------------
# Model identifier (change only if you move to another model)
MODEL_ID = "Adedoyinjames/YAH_Tech_Ai"
# Read token from Space secrets (will be None for public models)
HF_TOKEN = os.getenv("HF_TOKEN") # <-- automatically set by Secrets
# FastAPI app (will also host the Gradio UI once mounted)
api_app = FastAPI()
# Place-holders filled by the background loader thread; every endpoint
# checks these before generating, so they must exist at import time.
model = None
tokenizer = None
model_loading = True # flag used by the endpoints to report "still loading"
# ------------------- 2️⃣ MODEL LOADER ------------------------
def load_model():
    """Load tokenizer + model in a background thread so the Space starts instantly.

    Fills the module-level ``model`` / ``tokenizer`` placeholders and always
    clears the ``model_loading`` flag when done (success or failure).
    """
    global model, tokenizer, model_loading
    try:
        # ---- Load tokenizer -------------------------------------------------
        tokenizer = AutoTokenizer.from_pretrained(
            MODEL_ID,
            token=HF_TOKEN,          # `use_auth_token` is deprecated; `token` accepts None (public) or a secret
            trust_remote_code=True,  # some community models ship custom code
        )
        # ---- Load model ------------------------------------------------------
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            token=HF_TOKEN,
            torch_dtype=torch.float16,  # half-precision saves VRAM
            device_map="auto",          # puts layers on GPU/CPU as needed
            trust_remote_code=True,
        )
        print("✅ Model loaded successfully!")
    except Exception as e:
        # Anything that goes wrong is printed to the Space logs for debugging
        print(f"❌ Error loading model: {e}")
    finally:
        model_loading = False  # whether success or failure, we are done loading

# Start the loader as soon as the container boots
threading.Thread(target=load_model, daemon=True).start()
# ------------------- 3️⃣ RESPONSE LOGIC ----------------------
def generate_response(message: str, history: list):
    """Generate a reply to *message*, given the conversation *history*.

    Core function used by both the Gradio UI and the /chat API.
    *history* may be either a list of ``(user, bot)`` tuples (classic Gradio
    format) or a list of ``{"role": ..., "content": ...}`` dicts (the newer
    Gradio "messages" format) — both produce the same prompt layout.
    Returns a plain string (either the model's answer or a status message).
    """
    if model_loading:
        return "⚠️ Model is still loading – please wait a few seconds and try again."
    if model is None or tokenizer is None:
        return "❌ Model failed to load. Check the Space logs for details."
    # Build a prompt that contains the previous turns (if any)
    turns = []
    for item in history or []:
        if isinstance(item, dict):
            # Messages format: one entry per utterance
            role = "User" if item.get("role") == "user" else "Assistant"
            turns.append(f"{role}: {item.get('content', '')}")
        else:
            # Tuple format: one entry per (user, bot) exchange
            u, b = item
            turns.append(f"User: {u}\nAssistant: {b}")
    turns.append(f"User: {message}\nAssistant:")
    prompt = "\n".join(turns)
    # Tokenize on the same device as the model
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
    # Generate (no gradients needed for inference)
    with torch.no_grad():
        output_ids = model.generate(
            input_ids,
            max_new_tokens=100,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            pad_token_id=tokenizer.eos_token_id,
            repetition_penalty=1.1,
        )
    # Strip the prompt tokens so only the newly generated answer remains
    answer = tokenizer.decode(output_ids[0][len(input_ids[0]):],
                              skip_special_tokens=True).strip()
    return answer
# ------------------- 4️⃣ FASTAPI ENDPOINT --------------------
class ChatRequest(BaseModel):
    """Request body for the /chat endpoint."""
    # The user's new message.
    message: str
    # Optional previous turns as [user, bot] pairs. (Pydantic deep-copies
    # field defaults per instance, so the shared [] is safe here.)
    history: list = []
@api_app.post("/chat")  # fixed: `app` does not exist until after mounting; routes belong to `api_app`
async def chat_endpoint(req: ChatRequest):
    """REST endpoint: return the model's reply to ``req.message``.

    Responds 503 while the model is still loading, 500 if it failed to load
    or if generation raises.
    """
    if model_loading:
        raise HTTPException(status_code=503, detail="Model is still loading")
    if model is None or tokenizer is None:
        raise HTTPException(status_code=500, detail="Model failed to load")
    try:
        reply = generate_response(req.message, req.history)
        return {"response": reply}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
@api_app.get("/health")  # fixed: `app` does not exist until after mounting; routes belong to `api_app`
async def health():
    """Simple health-check for monitoring: loading / error / ready."""
    if model_loading:
        return {"status": "loading"}
    if model is None:
        return {"status": "error"}
    return {"status": "ready"}
# ------------------- 5️⃣ GRADIO UI ---------------------------
def gradio_chat(message, history):
    """Callback for gr.ChatInterface.

    ChatInterface manages the chat history itself and expects the callback
    to return ONLY the bot reply string. The previous version returned
    ("", history) and appended to history — that is the gr.Blocks/Chatbot
    pattern and makes ChatInterface render a tuple instead of the answer.
    """
    return generate_response(message, history)
# The shared chat UI. NOTE: `server_port` / `server_name` are launch()
# options, not ChatInterface constructor arguments — passing them here
# raised a TypeError at import time, so they were removed (launch() at the
# bottom of the file already sets both).
iface = gr.ChatInterface(
    fn=gradio_chat,
    title="YAH Tech AI Chatbot",
    description="Ask anything – the model runs completely for free in this Space.",
    examples=[
        "Hello! How can you help me?",
        "What is artificial intelligence?",
        "Tell me about machine learning"
    ],
    theme="soft",
)
# --------------------------------------------------------------
# Mount the Gradio UI onto the same FastAPI app
# --------------------------------------------------------------
# Combined ASGI app: Gradio UI at "/" plus the /chat and /health API routes.
app = gr.mount_gradio_app(api_app, iface, path="/")  # UI lives at https://…/ (root)
# --------------------------------------------------------------
# If you run the script locally (outside a Space) this block fires
# --------------------------------------------------------------
if __name__ == "__main__":
    # `share=False` is fine inside a Space; set to True if you run locally and want a public link.
    # NOTE(review): launching `iface` directly serves only the Gradio UI —
    # the mounted FastAPI routes (/chat, /health) are reachable only when the
    # combined `app` is served by an ASGI server (e.g. `uvicorn app:app`).
    # Confirm which behavior is intended for local runs.
    iface.launch(share=False, server_port=7860, server_name="0.0.0.0")