lahiruchamika27 commited on
Commit
c5ffe98
·
verified ·
1 Parent(s): 64a1d63

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +74 -66
app.py CHANGED
@@ -1,74 +1,78 @@
1
  import os
2
  import torch
3
- from flask import Flask, request, jsonify
 
 
4
  from datasets import load_dataset
5
  from transformers import AutoModelForCausalLM, AutoTokenizer
 
6
 
7
- app = Flask(__name__)
8
 
9
- # Global variables to store model, tokenizer, and dataset
10
  model = None
11
  tokenizer = None
12
  dataset = None
13
 
14
- # Function to load the model and dataset
15
- def load_model_and_data():
16
- global model, tokenizer, dataset
17
-
18
- # Load the base model and tokenizer
19
- model_id = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
20
- tokenizer = AutoTokenizer.from_pretrained(model_id)
21
-
22
- # Load the model with reduced precision for efficiency
23
- model = AutoModelForCausalLM.from_pretrained(
24
- model_id,
25
- torch_dtype=torch.float16,
26
- device_map="auto"
27
- )
28
-
29
- # Load your dataset
30
- dataset = load_dataset("lahiruchamika27/tia")
31
- print("Model, tokenizer, and dataset loaded successfully!")
32
 
33
- # Initialize the model on startup
34
- @app.before_request
35
- def before_request():
36
- global model, tokenizer, dataset
37
- if model is None or tokenizer is None or dataset is None:
38
- load_model_and_data()
39
 
40
- # Define chat endpoint
41
- @app.route('/api/chat', methods=['POST'])
42
- def chat():
 
 
 
 
43
  try:
44
- # Get input from request
45
- data = request.json
46
- if not data or 'message' not in data:
47
- return jsonify({"error": "No message provided"}), 400
48
 
49
- user_message = data['message']
 
 
 
 
50
 
51
- # You can optionally retrieve conversation history
52
- conversation_history = data.get('history', [])
53
-
54
- # Create input for the model
55
- # Format may need adjustment based on your model's expected format
56
- if conversation_history:
 
 
 
 
 
 
 
 
 
 
 
 
57
  full_prompt = ""
58
- for turn in conversation_history:
59
- if 'user' in turn:
60
- full_prompt += f"User: {turn['user']}\n"
61
- if 'assistant' in turn:
62
- full_prompt += f"Assistant: {turn['assistant']}\n"
63
 
64
- full_prompt += f"User: {user_message}\nAssistant:"
65
  else:
66
- full_prompt = f"User: {user_message}\nAssistant:"
67
 
68
  # Tokenize and generate
69
  inputs = tokenizer(full_prompt, return_tensors="pt").to(model.device)
70
 
71
- # Generate response
72
  with torch.no_grad():
73
  outputs = model.generate(
74
  inputs["input_ids"],
@@ -81,29 +85,33 @@ def chat():
81
  # Decode the output
82
  response = tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
83
 
84
- return jsonify({"response": response.strip()})
85
 
86
  except Exception as e:
87
- return jsonify({"error": str(e)}), 500
88
 
89
- # Optional endpoint to get examples from your dataset
90
- @app.route('/api/examples', methods=['GET'])
91
- def get_examples():
 
 
 
 
92
  try:
93
- # Get a sample from the dataset
94
- num_examples = int(request.args.get('count', 5))
95
- split = request.args.get('split', 'train')
96
-
97
  if split in dataset:
98
- examples = dataset[split][:num_examples]
99
- return jsonify({"examples": examples})
 
100
  else:
101
- return jsonify({"error": f"Split '{split}' not found in dataset"}), 400
102
 
103
  except Exception as e:
104
- return jsonify({"error": str(e)}), 500
 
 
 
 
105
 
106
- if __name__ == '__main__':
107
- # Get port from environment variable for HF Spaces compatibility
108
- port = int(os.environ.get('PORT', 7860))
109
- app.run(host='0.0.0.0', port=port)
 
1
  import os
2
  import torch
3
+ from fastapi import FastAPI, HTTPException
4
+ from pydantic import BaseModel
5
+ from typing import List, Dict, Optional
6
  from datasets import load_dataset
7
  from transformers import AutoModelForCausalLM, AutoTokenizer
8
+ import uvicorn
9
 
10
+ app = FastAPI()
11
 
12
+ # Global variables
13
  model = None
14
  tokenizer = None
15
  dataset = None
16
 
17
+ # Pydantic models for request/response
18
+ class ChatTurn(BaseModel):
19
+ user: Optional[str] = None
20
+ assistant: Optional[str] = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
class ChatRequest(BaseModel):
    """Payload for /api/chat: the new user message plus optional prior turns."""

    message: str
    # Pydantic deep-copies field defaults per instance, so this shared-[]
    # default does not suffer the usual Python mutable-default pitfall.
    history: Optional[List[ChatTurn]] = []
 
 
 
25
 
26
class ChatResponse(BaseModel):
    """Response body for /api/chat: the generated assistant text."""

    response: str
28
+
29
# One-time initialization of the model, tokenizer, and dataset globals.
# NOTE(review): @app.on_event is deprecated in newer FastAPI releases in
# favor of lifespan handlers — confirm the pinned FastAPI version before
# migrating.
@app.on_event("startup")
async def startup_event():
    """Load the LLM, its tokenizer, and the example dataset at server start."""
    global model, tokenizer, dataset
    try:
        # fp16 weights + device_map="auto" keep the 1.5B model's footprint low
        # and let transformers place layers on whatever device is available.
        model_id = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.float16,
            device_map="auto",
        )
        dataset = load_dataset("lahiruchamika27/tia")
        print("Model, tokenizer, and dataset loaded successfully!")
    except Exception as e:
        # Deliberate best-effort startup: the globals stay None and the
        # endpoints surface a 500 instead of the whole app failing to boot.
        print(f"Error loading model: {str(e)}")
50
+
51
+ @app.post("/api/chat", response_model=ChatResponse)
52
+ async def chat(request: ChatRequest):
53
+ global model, tokenizer
54
+
55
+ # Ensure model is loaded
56
+ if model is None or tokenizer is None:
57
+ raise HTTPException(status_code=500, detail="Model or tokenizer not loaded")
58
+
59
+ try:
60
+ # Format conversation
61
+ if request.history:
62
  full_prompt = ""
63
+ for turn in request.history:
64
+ if turn.user:
65
+ full_prompt += f"User: {turn.user}\n"
66
+ if turn.assistant:
67
+ full_prompt += f"Assistant: {turn.assistant}\n"
68
 
69
+ full_prompt += f"User: {request.message}\nAssistant:"
70
  else:
71
+ full_prompt = f"User: {request.message}\nAssistant:"
72
 
73
  # Tokenize and generate
74
  inputs = tokenizer(full_prompt, return_tensors="pt").to(model.device)
75
 
 
76
  with torch.no_grad():
77
  outputs = model.generate(
78
  inputs["input_ids"],
 
85
  # Decode the output
86
  response = tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
87
 
88
+ return ChatResponse(response=response.strip())
89
 
90
  except Exception as e:
91
+ raise HTTPException(status_code=500, detail=str(e))
92
 
93
@app.get("/api/examples")
async def get_examples(count: int = 5, split: str = "train"):
    """Return up to `count` rows from the requested dataset split.

    Query params:
        count: maximum number of examples to return (clamped to split size).
        split: dataset split name, e.g. "train".

    Raises:
        HTTPException 500 if the dataset failed to load at startup or a
        read error occurs; 400 if the split does not exist.
    """
    global dataset

    if dataset is None:
        raise HTTPException(status_code=500, detail="Dataset not loaded")

    try:
        if split in dataset:
            # BUG FIX: `dataset[split][:count]` on a HF Dataset returns a
            # dict of column -> list, so iterating it yields column-name
            # strings and `dict(item)` raised. Integer indexing returns one
            # row as a plain dict, which serializes cleanly to JSON.
            ds = dataset[split]
            limit = max(0, min(count, len(ds)))
            examples = [ds[i] for i in range(limit)]
            return {"examples": examples}
        else:
            raise HTTPException(status_code=400, detail=f"Split '{split}' not found in dataset")
    except HTTPException:
        # BUG FIX: HTTPException subclasses Exception, so the 400 above was
        # being swallowed by the broad handler below and re-emitted as a 500.
        # Re-raise it untouched.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
110
+
111
@app.get("/health")
async def health_check():
    """Liveness probe: reports whether the model and tokenizer are in memory."""
    return {
        "status": "ok",
        "model_loaded": model is not None,
        "tokenizer_loaded": tokenizer is not None,
    }
114
 
115
if __name__ == "__main__":
    # Hugging Face Spaces injects PORT; 7860 is the Spaces default locally.
    listen_port = int(os.environ.get("PORT", 7860))
    uvicorn.run("app:app", host="0.0.0.0", port=listen_port, reload=False)