Adedoyinjames committed
Commit d8f5f0b · verified · Parent: 7f9fe6c

Update app.py

Files changed (1)
  1. app.py +102 -92
app.py CHANGED
@@ -1,144 +1,154 @@
-import gradio as gr
-from transformers import AutoModelForCausalLM, AutoTokenizer
-import torch
 import os
 from fastapi import FastAPI, HTTPException
 from pydantic import BaseModel
-import threading
-
-# Create FastAPI app
-app = FastAPI()
-
-# Get token from Space secrets
-HF_TOKEN = os.environ.get("HF_TOKEN", None)
-
-# Global variables for model and tokenizer
 model = None
 tokenizer = None
-model_loading = True

 def load_model():
-    """Load model in background thread"""
     global model, tokenizer, model_loading
     try:
-        # Load model and tokenizer explicitly - this is more reliable than pipeline
-        model_name = "Adedoyinjames/YAH_Tech_Ai"
-
-        # Load tokenizer
         tokenizer = AutoTokenizer.from_pretrained(
-            model_name,
-            use_auth_token=HF_TOKEN,
-            trust_remote_code=True
         )
-
-        # Load model
         model = AutoModelForCausalLM.from_pretrained(
-            model_name,
             use_auth_token=HF_TOKEN,
-            torch_dtype=torch.float16,
-            device_map="auto",
             trust_remote_code=True
         )
-
-        model_loading = False
         print("✅ Model loaded successfully!")
-
     except Exception as e:
         print(f"❌ Error loading model: {e}")
-        model_loading = False

-# Start model loading in background
 threading.Thread(target=load_model, daemon=True).start()

-class ChatRequest(BaseModel):
-    message: str
-    history: list = []
-
-def respond(message, history):
-    """Handle chat responses with proper error handling"""
     if model_loading:
-        return "⚠️ Model is still loading. Please wait a moment and try again."
-
     if model is None or tokenizer is None:
-        return "⚠️ Model failed to load. Please check the logs."
-
-    try:
-        # Prepare input - format for chat models
-        if history:
-            # For multi-turn conversation, format the history
-            formatted_history = "\n".join([f"User: {h[0]}\nAssistant: {h[1]}" for h in history])
-            full_prompt = f"{formatted_history}\nUser: {message}\nAssistant:"
-        else:
-            # For first message
-            full_prompt = f"User: {message}\nAssistant:"
-
-        # Tokenize input
-        inputs = tokenizer.encode(full_prompt, return_tensors="pt")
-
-        # Generate response
-        with torch.no_grad():
-            outputs = model.generate(
-                inputs,
-                max_length=len(inputs[0]) + 100,  # Generate up to 100 new tokens
-                max_new_tokens=100,
-                do_sample=True,
-                temperature=0.7,
-                top_p=0.9,
-                pad_token_id=tokenizer.eos_token_id,
-                repetition_penalty=1.1
-            )
-
-        # Decode only the new tokens (remove the input)
-        response = tokenizer.decode(outputs[0][len(inputs[0]):], skip_special_tokens=True)
-
-        # Clean up the response
-        response = response.strip()
-
-        return response
-
-    except Exception as e:
-        return f"❌ Error generating response: {str(e)}"

-# API endpoint for external apps
 @app.post("/chat")
-async def chat_api(request: ChatRequest):
     if model_loading:
         raise HTTPException(status_code=503, detail="Model is still loading")
-
     if model is None or tokenizer is None:
         raise HTTPException(status_code=500, detail="Model failed to load")
-
     try:
-        response = respond(request.message, request.history)
-        return {"response": response}
     except Exception as e:
         raise HTTPException(status_code=500, detail=str(e))

-# Health check endpoint
 @app.get("/health")
-async def health_check():
     if model_loading:
         return {"status": "loading"}
-    elif model is None:
         return {"status": "error"}
-    else:
-        return {"status": "ready"}

-# Create a chat interface for web testing
 iface = gr.ChatInterface(
-    fn=respond,
     title="YAH Tech AI Chatbot",
-    description="Ask YAH Tech AI anything! Powered by advanced language models.",
     examples=[
         "Hello! How can you help me?",
         "What is artificial intelligence?",
         "Tell me about machine learning"
     ],
-    theme="soft"
 )

-# Mount Gradio interface to FastAPI
-app = gr.mount_gradio_app(app, iface, path="/")

 if __name__ == "__main__":
-    iface.launch(share=False)
 
+# --------------------------------------------------------------
+# app.py - A self-contained Gradio + FastAPI chatbot
+# --------------------------------------------------------------
+
 import os
+import threading
+import torch
+import gradio as gr
 from fastapi import FastAPI, HTTPException
 from pydantic import BaseModel
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+# ------------------- 1️⃣ GLOBAL SETTINGS ----------------------
+# Model identifier (change only if you move to another model)
+MODEL_ID = "Adedoyinjames/YAH_Tech_Ai"
+
+# Read token from Space secrets (will be None for public models)
+HF_TOKEN = os.getenv("HF_TOKEN")  # automatically set from the Space's Secrets
+
+# FastAPI app (will also host the Gradio UI)
+api_app = FastAPI()
+
+# Placeholders that will be filled once the model finishes loading
 model = None
 tokenizer = None
+model_loading = True  # flag used by the endpoints
+
+# ------------------- 2️⃣ MODEL LOADER ------------------------
 def load_model():
+    """Run in a background thread so the Space starts instantly."""
     global model, tokenizer, model_loading
     try:
+        # ---- Load tokenizer ------------------------------------------
         tokenizer = AutoTokenizer.from_pretrained(
+            MODEL_ID,
+            use_auth_token=HF_TOKEN,  # works with None (public model) or a token (private)
+            trust_remote_code=True    # some community models need this
         )
+
+        # ---- Load model ----------------------------------------------
         model = AutoModelForCausalLM.from_pretrained(
+            MODEL_ID,
             use_auth_token=HF_TOKEN,
+            torch_dtype=torch.float16,  # half-precision saves VRAM
+            device_map="auto",          # puts layers on GPU/CPU as needed
             trust_remote_code=True
         )
         print("✅ Model loaded successfully!")
     except Exception as e:
+        # Anything that goes wrong is printed to the Space log
         print(f"❌ Error loading model: {e}")
+    finally:
+        model_loading = False  # success or failure, loading is done
+
+# Start the loader as soon as the container boots
 threading.Thread(target=load_model, daemon=True).start()
+
+# ------------------- 3️⃣ RESPONSE LOGIC ----------------------
+def generate_response(message: str, history: list):
+    """Core function used by both the Gradio UI and the API."""
     if model_loading:
+        return "⚠️ Model is still loading; please wait a few seconds and try again."
+
     if model is None or tokenizer is None:
+        return "❌ Model failed to load. Check the Space logs for details."
+
+    # Build a prompt that contains the previous turns (if any)
+    if history:
+        # history is a list of tuples: [(user, bot), (user, bot), ...]
+        formatted = "\n".join([f"User: {u}\nAssistant: {b}" for u, b in history])
+        prompt = f"{formatted}\nUser: {message}\nAssistant:"
+    else:
+        prompt = f"User: {message}\nAssistant:"
+
+    # Tokenize and move the input to the same device as the model
+    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
+
+    # Generate
+    with torch.no_grad():
+        output_ids = model.generate(
+            input_ids,
+            max_new_tokens=100,
+            do_sample=True,
+            temperature=0.7,
+            top_p=0.9,
+            pad_token_id=tokenizer.eos_token_id,
+            repetition_penalty=1.1
+        )
+
+    # Remove the prompt part from the output
+    answer = tokenizer.decode(output_ids[0][len(input_ids[0]):],
+                              skip_special_tokens=True).strip()
+    return answer
+
+# ------------------- 4️⃣ FASTAPI ENDPOINT --------------------
+class ChatRequest(BaseModel):
+    message: str
+    history: list = []  # optional list of [user, bot] pairs
+
+@api_app.post("/chat")  # registered on api_app; `app` is only created at the mount step below
+async def chat_endpoint(req: ChatRequest):
     if model_loading:
         raise HTTPException(status_code=503, detail="Model is still loading")
     if model is None or tokenizer is None:
         raise HTTPException(status_code=500, detail="Model failed to load")
     try:
+        reply = generate_response(req.message, req.history)
+        return {"response": reply}
     except Exception as e:
         raise HTTPException(status_code=500, detail=str(e))
+
+@api_app.get("/health")
+async def health():
+    """Simple health check for monitoring."""
     if model_loading:
         return {"status": "loading"}
+    if model is None:
         return {"status": "error"}
+    return {"status": "ready"}
+
+# ------------------- 5️⃣ GRADIO UI ---------------------------
+def gradio_chat(message, history):
+    """Wrapper used by Gradio; gr.ChatInterface expects just the bot reply
+    string and manages the chat history itself."""
+    return generate_response(message, history)
+
 iface = gr.ChatInterface(
+    fn=gradio_chat,
     title="YAH Tech AI Chatbot",
+    description="Ask anything; the model runs completely for free in this Space.",
     examples=[
         "Hello! How can you help me?",
         "What is artificial intelligence?",
         "Tell me about machine learning"
     ],
+    theme="soft"  # server_name/server_port belong to launch(), not the constructor
 )
 
+# --------------------------------------------------------------
+# Mount the Gradio UI onto the same FastAPI app
+# --------------------------------------------------------------
+app = gr.mount_gradio_app(api_app, iface, path="/")  # UI lives at the Space's root URL
+
+# --------------------------------------------------------------
+# If you run the script locally (outside a Space) this block fires
+# --------------------------------------------------------------
 if __name__ == "__main__":
+    # Serve the combined app so the /chat and /health routes work alongside
+    # the UI; iface.launch() alone would start a Gradio-only server.
+    import uvicorn
+    uvicorn.run(app, host="0.0.0.0", port=7860)
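
For reference, a minimal client sketch for the two endpoints this version exposes. The base URL and the use of the `requests` package are assumptions for illustration, not part of the commit:

import time
import requests

BASE_URL = "https://your-space.hf.space"  # placeholder: substitute your Space's URL

# Poll /health until the background loader finishes; it reports
# "loading", "error", or "ready".
while requests.get(f"{BASE_URL}/health", timeout=10).json()["status"] == "loading":
    time.sleep(5)

# Call /chat; history is an optional list of [user, bot] pairs from earlier turns.
payload = {"message": "What is artificial intelligence?", "history": []}
resp = requests.post(f"{BASE_URL}/chat", json=payload, timeout=60)
resp.raise_for_status()  # /chat itself answers 503 while loading, 500 on a failed load
print(resp.json()["response"])

Either endpoint can be used to wait for readiness; /health is the cheaper probe.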