# NOTE(review): the lines "Spaces: / Runtime error / Runtime error" were
# Hugging Face Spaces page residue, not code — removed so the file parses.
# --------------------------------------------------------------
# app.py – A self-contained Gradio + FastAPI chatbot
# --------------------------------------------------------------
import os
import threading

import torch
import gradio as gr
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer

# ------------------- 1. GLOBAL SETTINGS ----------------------
# Hub identifier of the model served by this Space.
MODEL_ID = "Adedoyinjames/YAH_Tech_Ai"

# Read token from the Space secrets (None for public models).
HF_TOKEN = os.getenv("HF_TOKEN")

# FastAPI app (also hosts the Gradio UI once it is mounted below).
api_app = FastAPI()

# Placeholders filled by the background loader thread.
model = None
tokenizer = None
model_loading = True  # flag consulted by the UI and API endpoints
# ------------------- 2. MODEL LOADER ------------------------
def load_model():
    """Load tokenizer + model in a background thread so the Space starts instantly.

    Fills the module-level ``model`` and ``tokenizer`` globals and clears the
    ``model_loading`` flag whether loading succeeds or fails.
    """
    global model, tokenizer, model_loading
    try:
        # `token` replaces the deprecated (and since removed) `use_auth_token`
        # kwarg; None is fine for public models, a token for private ones.
        tokenizer = AutoTokenizer.from_pretrained(
            MODEL_ID,
            token=HF_TOKEN,
            trust_remote_code=True,  # some community models ship custom code
        )
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            token=HF_TOKEN,
            torch_dtype=torch.float16,  # half-precision saves VRAM
            device_map="auto",          # spread layers over GPU/CPU as needed
            trust_remote_code=True,
        )
        print("✅ Model loaded successfully!")
    except Exception as e:
        # Surface the failure in the Space logs; the endpoints report it to users.
        print(f"❌ Error loading model: {e}")
    finally:
        model_loading = False  # done loading, success or failure


# Start the loader as soon as the container boots.
threading.Thread(target=load_model, daemon=True).start()
# ------------------- 3. RESPONSE LOGIC ----------------------
def generate_response(message: str, history: list) -> str:
    """Generate a reply for *message*, conditioning on prior (user, bot) turns.

    Core function shared by the Gradio UI and the REST endpoint.  Returns a
    plain string; while the model is still loading (or failed to load) a
    human-readable status message is returned instead of a generation.
    """
    if model_loading:
        return "⚠️ Model is still loading – please wait a few seconds and try again."
    if model is None or tokenizer is None:
        return "❌ Model failed to load. Check the Space logs for details."

    # Flatten previous turns into a "User:/Assistant:" transcript and append
    # the new user message with an open Assistant turn for the model to fill.
    turns = [f"User: {u}\nAssistant: {b}" for u, b in history] if history else []
    turns.append(f"User: {message}\nAssistant:")
    prompt = "\n".join(turns)

    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)

    with torch.no_grad():
        output_ids = model.generate(
            input_ids,
            max_new_tokens=100,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            pad_token_id=tokenizer.eos_token_id,
            repetition_penalty=1.1,
        )

    # Decode only the newly generated tokens (strip the prompt prefix).
    answer = tokenizer.decode(
        output_ids[0][input_ids.shape[1]:], skip_special_tokens=True
    ).strip()
    return answer
# ------------------- 4. FASTAPI ENDPOINT --------------------
class ChatRequest(BaseModel):
    # Incoming user message plus an optional list of [user, bot] pairs.
    message: str
    history: list = []


# NOTE(review): without a route decorator this handler was never registered,
# so the REST API returned 404. Path "/chat" is assumed — confirm with clients.
@api_app.post("/chat")
async def chat_endpoint(req: ChatRequest):
    """REST entry point: POST /chat with {"message": ..., "history": [...]}."""
    if model_loading:
        raise HTTPException(status_code=503, detail="Model is still loading")
    if model is None or tokenizer is None:
        raise HTTPException(status_code=500, detail="Model failed to load")
    try:
        reply = generate_response(req.message, req.history)
        return {"response": reply}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
# NOTE(review): the missing route decorator meant this check was unreachable.
@api_app.get("/health")
async def health():
    """Simple health check for monitoring: loading / error / ready."""
    if model_loading:
        return {"status": "loading"}
    if model is None:
        return {"status": "error"}
    return {"status": "ready"}
# ------------------- 5. GRADIO UI ---------------------------
def gradio_chat(message, history):
    """Gradio ChatInterface callback.

    ChatInterface manages the conversation history itself, so this must
    return just the bot reply string — returning a ("", history) tuple
    breaks the chat display.
    """
    return generate_response(message, history)


iface = gr.ChatInterface(
    fn=gradio_chat,
    title="YAH Tech AI Chatbot",
    description="Ask anything – the model runs completely for free in this Space.",
    examples=[
        "Hello! How can you help me?",
        "What is artificial intelligence?",
        "Tell me about machine learning",
    ],
    theme="soft",
    # NOTE: server_port / server_name belong to launch(), not to the
    # ChatInterface constructor — passing them here raised a TypeError
    # at startup (the Space's "Runtime error").
)
# --------------------------------------------------------------
# Mount the Gradio UI onto the same FastAPI app (UI at "/").
# --------------------------------------------------------------
app = gr.mount_gradio_app(api_app, iface, path="/")

# --------------------------------------------------------------
# If you run the script locally (outside a Space) this block fires.
# --------------------------------------------------------------
if __name__ == "__main__":
    # share=False is fine inside a Space; set True locally for a public link.
    # NOTE(review): launching iface directly serves only the Gradio UI —
    # the /chat and /health FastAPI routes are exposed only when the
    # combined `app` object is served (e.g. `uvicorn app:app`) — confirm
    # which entry point the deployment actually uses.
    iface.launch(share=False, server_port=7860, server_name="0.0.0.0")