Trigger82 committed on
Commit
5afdd5f
·
verified ·
1 Parent(s): b84920e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +114 -21
app.py CHANGED
@@ -1,41 +1,134 @@
1
- from fastapi import FastAPI, Request
2
  from fastapi.responses import JSONResponse
3
  from transformers import AutoModelForCausalLM, AutoTokenizer
4
  import torch
 
 
5
 
6
- app = FastAPI()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
  # Load model and tokenizer
9
- tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
10
- model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")
 
 
 
 
 
 
11
 
12
- # In-memory history per user
13
  chat_history = {}
14
 
15
- @app.get("/")
16
  async def root():
17
  return {"message": "🟢 API is running. Use /ai?query=Hello&user_id=yourname"}
18
 
19
  @app.get("/ai")
20
  async def chat(request: Request):
21
- query_params = dict(request.query_params)
22
- user_input = query_params.get("query", "")
23
- user_id = query_params.get("user_id", "default")
24
-
25
- if not user_input:
26
- return JSONResponse({"error": "Missing 'query' parameter"}, status_code=400)
 
 
 
 
 
 
 
 
 
 
27
 
28
- new_input_ids = tokenizer.encode(user_input + tokenizer.eos_token, return_tensors='pt')
29
- user_history = chat_history.get(user_id, [])
30
- bot_input_ids = torch.cat(user_history + [new_input_ids], dim=-1) if user_history else new_input_ids
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
 
32
- output_ids = model.generate(bot_input_ids, max_new_tokens=100, pad_token_id=tokenizer.eos_token_id)
33
- response = tokenizer.decode(output_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)
 
 
 
 
 
 
34
 
35
- chat_history[user_id] = [bot_input_ids, output_ids]
36
- return JSONResponse({"reply": response})
 
 
 
37
 
38
- # Only needed if running locally, not in Hugging Face Space
39
  if __name__ == "__main__":
40
  import uvicorn
41
- uvicorn.run(app, host="0.0.0.0", port=7860)
 
 
 
 
 
 
 
1
# Standard library
import logging
import os

# Third-party
import torch
from fastapi import FastAPI, Request, HTTPException
from fastapi.responses import JSONResponse
from transformers import AutoModelForCausalLM, AutoTokenizer

# Configure logging so startup and request errors show up in the
# Space's container logs.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Hugging Face Space configuration. SPACE_ID is set by the Spaces
# runtime; empty when running locally. NOTE(review): SPACE_ID is
# usually "owner/space-name" — confirm the resulting path is intended.
HF_SPACE = os.getenv("SPACE_ID", "")
BASE_PATH = f"/spaces/{HF_SPACE}" if HF_SPACE else ""

# Initialize FastAPI. `root_path` declares the prefix the fronting
# proxy strips; routes themselves must stay unprefixed.
#
# BUG FIX: docs_url was previously set to f"{BASE_PATH}/docs" whenever
# BASE_PATH was non-empty. Since root_path is already prepended by the
# proxy, that double-prefixed the Swagger UI URL
# (e.g. /spaces/x/spaces/x/docs). The docs route must remain "/docs".
app = FastAPI(
    title="DialoGPT API",
    description="Chatbot API using Microsoft's DialoGPT-medium model",
    version="1.0",
    root_path=BASE_PATH,
    docs_url="/docs",
    redoc_url=None
)
25
 
26
# Load tokenizer and model at import time; failing fast here aborts
# startup instead of leaving the endpoints serving errors.
MODEL_NAME = "microsoft/DialoGPT-medium"
try:
    logger.info("Loading tokenizer and model...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
    logger.info("Model loaded successfully!")
except Exception as exc:
    logger.error(f"Model loading failed: {str(exc)}")
    raise RuntimeError("Model initialization failed") from exc
35
 
36
# In-memory chat history storage: maps user_id -> [bot_input_ids, output_ids]
# tensors from the most recent exchange. NOTE(review): unbounded across
# distinct users and lost on restart — fine for a demo, not for production.
chat_history = {}
38
 
39
@app.get("/", include_in_schema=False)
async def root():
    """Liveness message with a usage hint; hidden from the OpenAPI schema."""
    usage_hint = {"message": "🟢 API is running. Use /ai?query=Hello&user_id=yourname"}
    return usage_hint
42
 
43
@app.get("/ai")
async def chat(request: Request):
    """Generate a DialoGPT reply for `query`, keeping per-user context.

    Query parameters:
        query:   the user's message (required, max 200 characters).
        user_id: conversation key for the history store (default "default").

    Returns:
        {"reply": <generated text>}

    Raises:
        HTTPException: 400 for missing/over-long input, 500 on
            generation failures (history is cleared on CUDA OOM).
    """
    try:
        # Get query parameters
        user_input = request.query_params.get("query", "").strip()
        user_id = request.query_params.get("user_id", "default").strip()

        # Validate input
        if not user_input:
            raise HTTPException(
                status_code=400,
                detail="Missing 'query' parameter. Usage: /ai?query=Hello&user_id=yourname"
            )
        if len(user_input) > 200:
            raise HTTPException(
                status_code=400,
                detail="Query too long (max 200 characters)"
            )

        # Encode the query, terminated with EOS as DialoGPT expects.
        new_input_ids = tokenizer.encode(
            user_input + tokenizer.eos_token,
            return_tensors='pt'
        )

        # Prepend the user's stored conversation context, if any.
        user_history = chat_history.get(user_id, [])
        bot_input_ids = torch.cat(user_history + [new_input_ids], dim=-1) if user_history else new_input_ids

        # Inference only: no_grad avoids building an autograd graph,
        # reducing memory use during generation.
        with torch.no_grad():
            output_ids = model.generate(
                bot_input_ids,
                max_new_tokens=100,
                pad_token_id=tokenizer.eos_token_id,
                do_sample=True,
                top_k=50,
                top_p=0.95
            )

        # Decode only the newly generated tokens (everything after the prompt).
        response = tokenizer.decode(
            output_ids[:, bot_input_ids.shape[-1]:][0],
            skip_special_tokens=True
        ).strip()

        # Update chat history
        chat_history[user_id] = [bot_input_ids, output_ids]

        return {"reply": response}

    except HTTPException:
        # BUG FIX: HTTPException subclasses Exception, so the generic
        # handler below was converting the 400 validation errors above
        # into 500 "Processing error" responses. Re-raise so FastAPI
        # returns the intended status code.
        raise

    except torch.cuda.OutOfMemoryError:
        logger.error("CUDA out of memory error")
        # Clear this user's history to free memory before failing.
        if user_id in chat_history:
            del chat_history[user_id]
        raise HTTPException(
            status_code=500,
            detail="Memory error. Conversation history cleared. Please try again."
        )

    except Exception as e:
        logger.error(f"Processing error: {str(e)}")
        raise HTTPException(
            status_code=500,
            detail=f"Processing error: {str(e)}"
        ) from e
109
 
110
@app.get("/health")
async def health_check():
    """Report service status, the served model, and active conversation count."""
    status_report = {
        "status": "healthy",
        "model": "microsoft/DialoGPT-medium",
        "users": len(chat_history),
        "space_id": HF_SPACE,
    }
    return status_report
118
 
119
@app.get("/reset")
async def reset_history(user_id: str = "default"):
    """Drop the stored conversation for `user_id` (no-op if none exists)."""
    # pop with a default replaces the membership-test-then-delete dance.
    chat_history.pop(user_id, None)
    return {"status": "success", "message": f"History cleared for user {user_id}"}
124
 
125
+ # Only run with uvicorn when executing locally
126
# Only run with uvicorn when executing locally
if __name__ == "__main__":
    import uvicorn

    # Serve on the port Hugging Face Spaces expects (7860).
    server_options = {
        "host": "0.0.0.0",
        "port": 7860,
        "log_level": "info",
        "timeout_keep_alive": 30,
    }
    uvicorn.run(app, **server_options)