Danaasa committed on
Commit
33b8d0a
·
verified ·
1 Parent(s): 6a414d8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -37
app.py CHANGED
@@ -1,3 +1,4 @@
 
1
  from fastapi import FastAPI, HTTPException
2
  from fastapi.middleware.cors import CORSMiddleware
3
  from fastapi.responses import StreamingResponse
@@ -13,11 +14,8 @@ import os
13
  from huggingface_hub import login
14
  from peft import PeftModel, PeftConfig
15
 
16
-
17
- # Create FastAPI app
18
  app = FastAPI()
19
 
20
- # CORS middleware setup
21
  app.add_middleware(
22
  CORSMiddleware,
23
  allow_origins=["*"],
@@ -26,40 +24,30 @@ app.add_middleware(
26
  allow_headers=["*"],
27
  )
28
 
29
- # Pydantic models
30
  class ChatRequest(BaseModel):
31
  message: str
32
- history: list = [] # List of [user_msg, assistant_msg] pairs
33
 
34
  class ChatResponse(BaseModel):
35
  response: str
36
 
37
- # Load model and tokenizer
38
- from peft import PeftModel, PeftConfig
39
-
40
-
41
  def load_model_and_tokenizer(base_model_name="mistralai/Mistral-7B-Instruct-v0.3", adapter_name="Danaasa/bible_mistral"):
42
- # Get the Hugging Face token from environment variable
43
  hf_token = os.environ.get("HUGGING_FACE_HUB_TOKEN")
44
-
45
- # Log in with the token if available
46
  if hf_token:
47
  login(token=hf_token)
48
  print("Successfully logged in with Hugging Face token")
49
  else:
50
  print("No Hugging Face token found in environment variables")
51
 
52
- # Load tokenizer with token for authentication
53
  tokenizer = AutoTokenizer.from_pretrained(
54
  base_model_name,
55
  trust_remote_code=True,
56
- token=hf_token # Pass token here
57
  )
58
 
59
  if tokenizer.pad_token_id is None:
60
  tokenizer.pad_token_id = tokenizer.eos_token_id
61
 
62
- # Set up quantization
63
  quantization_config = BitsAndBytesConfig(
64
  load_in_4bit=True,
65
  bnb_4bit_quant_type="nf4",
@@ -67,28 +55,25 @@ def load_model_and_tokenizer(base_model_name="mistralai/Mistral-7B-Instruct-v0.3
67
  bnb_4bit_compute_dtype=torch.float16
68
  )
69
 
70
- # Load the base model with token for authentication
71
  base_model = AutoModelForCausalLM.from_pretrained(
72
  base_model_name,
73
  quantization_config=quantization_config,
74
  device_map="auto",
75
  trust_remote_code=True,
76
- token=hf_token # Pass token here
77
  )
78
 
79
- # Load the adapter with token for authentication
80
  model = PeftModel.from_pretrained(
81
  base_model,
82
  adapter_name,
83
- token=hf_token # Pass token here
84
  )
85
 
86
  model.eval()
87
  return model, tokenizer
88
- # Global variables for model and tokenizer
89
  model, tokenizer = load_model_and_tokenizer()
90
 
91
- # Response generator
92
  def generate_response(question, conversation_history, model, tokenizer):
93
  system_prompt = """
94
  - You are a truthful Christian AI assistant.
@@ -105,9 +90,7 @@ def generate_response(question, conversation_history, model, tokenizer):
105
 
106
  input_text = f"[INST] {system_prompt} [/INST]\n"
107
 
108
- # Add conversation history if available
109
  if conversation_history:
110
- # Use the last 3 exchanges for context (can adjust as needed)
111
  recent_history = conversation_history[-3:]
112
  input_text += "Previous context (for reference only, do not repeat):\n"
113
  for user_msg, assistant_msg in recent_history:
@@ -133,8 +116,6 @@ def generate_response(question, conversation_history, model, tokenizer):
133
 
134
  try:
135
  answer = full_response.split("[/INST]")[-1].strip()
136
-
137
- # Clean known pieces
138
  if system_prompt in answer:
139
  answer = answer.replace(system_prompt, "").strip()
140
  if "Previous context" in answer:
@@ -143,13 +124,10 @@ def generate_response(question, conversation_history, model, tokenizer):
143
  answer = answer.split("Current question")[-1].strip()
144
  if question in answer[:len(question) + 10]:
145
  answer = answer.split(question)[-1].strip()
146
-
147
  if answer.startswith(("The assistant", "*The assistant")):
148
  answer = answer.split(".", 1)[-1].strip() if "." in answer else answer
149
-
150
  if answer.startswith('"') and answer.endswith('"'):
151
  answer = answer[1:-1].strip()
152
-
153
  except IndexError:
154
  print(f"Warning: Parsing failed, raw response: {full_response}")
155
  answer = full_response
@@ -159,22 +137,18 @@ def generate_response(question, conversation_history, model, tokenizer):
159
  for word in words:
160
  current_response += word + " "
161
  yield current_response.strip()
162
- time.sleep(0.05) # This controls typing speed
163
 
164
- # Stream response to client
165
  async def stream_response(message: str, conversation_history: List[Tuple[str, str]]):
166
  for response_chunk in generate_response(message, conversation_history, model, tokenizer):
167
- # Send each chunk as a server-sent event
168
  yield f"data: {json.dumps({'text': response_chunk})}\n\n"
169
- await asyncio.sleep(0.05) # Small delay to control flow
170
 
171
  @app.post("/chat")
172
  async def chat(request: ChatRequest):
173
  message = request.message
174
 
175
- # Process conversation history safely
176
  try:
177
- # Make sure each history item has exactly two elements (user_msg, assistant_msg)
178
  conversation_history = [
179
  (h[0], h[1]) for h in request.history
180
  if isinstance(h, list) and len(h) >= 2
@@ -183,7 +157,6 @@ async def chat(request: ChatRequest):
183
  print(f"Error processing history: {e}")
184
  conversation_history = []
185
 
186
- # Return a streaming response
187
  return StreamingResponse(
188
  stream_response(message, conversation_history),
189
  media_type="text/event-stream"
@@ -191,7 +164,6 @@ async def chat(request: ChatRequest):
191
 
192
  @app.post("/chat-full", response_model=ChatResponse)
193
  async def chat_full(request: ChatRequest):
194
- """Non-streaming endpoint as fallback"""
195
  message = request.message
196
 
197
  try:
@@ -203,7 +175,6 @@ async def chat_full(request: ChatRequest):
203
  print(f"Error processing history: {e}")
204
  conversation_history = []
205
 
206
- # Generate complete response
207
  response_text = ""
208
  for partial in generate_response(message, conversation_history, model, tokenizer):
209
  response_text = partial
 
1
+ # main.py (your code, unchanged except for the port in the CMD of the Dockerfile)
2
  from fastapi import FastAPI, HTTPException
3
  from fastapi.middleware.cors import CORSMiddleware
4
  from fastapi.responses import StreamingResponse
 
14
  from huggingface_hub import login
15
  from peft import PeftModel, PeftConfig
16
 
 
 
17
  app = FastAPI()
18
 
 
19
  app.add_middleware(
20
  CORSMiddleware,
21
  allow_origins=["*"],
 
24
  allow_headers=["*"],
25
  )
26
 
 
27
class ChatRequest(BaseModel):
    """Request body for the chat endpoints."""
    message: str
    # Prior turns as [user_msg, assistant_msg] pairs; defaults to no history.
    history: list = []

class ChatResponse(BaseModel):
    """Response body for the non-streaming /chat-full endpoint."""
    response: str
33
 
 
 
 
 
34
  def load_model_and_tokenizer(base_model_name="mistralai/Mistral-7B-Instruct-v0.3", adapter_name="Danaasa/bible_mistral"):
 
35
  hf_token = os.environ.get("HUGGING_FACE_HUB_TOKEN")
 
 
36
  if hf_token:
37
  login(token=hf_token)
38
  print("Successfully logged in with Hugging Face token")
39
  else:
40
  print("No Hugging Face token found in environment variables")
41
 
 
42
  tokenizer = AutoTokenizer.from_pretrained(
43
  base_model_name,
44
  trust_remote_code=True,
45
+ token=hf_token
46
  )
47
 
48
  if tokenizer.pad_token_id is None:
49
  tokenizer.pad_token_id = tokenizer.eos_token_id
50
 
 
51
  quantization_config = BitsAndBytesConfig(
52
  load_in_4bit=True,
53
  bnb_4bit_quant_type="nf4",
 
55
  bnb_4bit_compute_dtype=torch.float16
56
  )
57
 
 
58
  base_model = AutoModelForCausalLM.from_pretrained(
59
  base_model_name,
60
  quantization_config=quantization_config,
61
  device_map="auto",
62
  trust_remote_code=True,
63
+ token=hf_token
64
  )
65
 
 
66
  model = PeftModel.from_pretrained(
67
  base_model,
68
  adapter_name,
69
+ token=hf_token
70
  )
71
 
72
  model.eval()
73
  return model, tokenizer
74
+
75
  model, tokenizer = load_model_and_tokenizer()
76
 
 
77
  def generate_response(question, conversation_history, model, tokenizer):
78
  system_prompt = """
79
  - You are a truthful Christian AI assistant.
 
90
 
91
  input_text = f"[INST] {system_prompt} [/INST]\n"
92
 
 
93
  if conversation_history:
 
94
  recent_history = conversation_history[-3:]
95
  input_text += "Previous context (for reference only, do not repeat):\n"
96
  for user_msg, assistant_msg in recent_history:
 
116
 
117
  try:
118
  answer = full_response.split("[/INST]")[-1].strip()
 
 
119
  if system_prompt in answer:
120
  answer = answer.replace(system_prompt, "").strip()
121
  if "Previous context" in answer:
 
124
  answer = answer.split("Current question")[-1].strip()
125
  if question in answer[:len(question) + 10]:
126
  answer = answer.split(question)[-1].strip()
 
127
  if answer.startswith(("The assistant", "*The assistant")):
128
  answer = answer.split(".", 1)[-1].strip() if "." in answer else answer
 
129
  if answer.startswith('"') and answer.endswith('"'):
130
  answer = answer[1:-1].strip()
 
131
  except IndexError:
132
  print(f"Warning: Parsing failed, raw response: {full_response}")
133
  answer = full_response
 
137
  for word in words:
138
  current_response += word + " "
139
  yield current_response.strip()
140
+ time.sleep(0.05)
141
 
 
142
async def stream_response(message: str, conversation_history: List[Tuple[str, str]]):
    """Stream model output as server-sent events (one `data:` frame per chunk)."""
    for chunk in generate_response(message, conversation_history, model, tokenizer):
        payload = json.dumps({'text': chunk})
        yield f"data: {payload}\n\n"
        # Brief pause between frames so the client renders incrementally.
        await asyncio.sleep(0.05)
146
 
147
  @app.post("/chat")
148
  async def chat(request: ChatRequest):
149
  message = request.message
150
 
 
151
  try:
 
152
  conversation_history = [
153
  (h[0], h[1]) for h in request.history
154
  if isinstance(h, list) and len(h) >= 2
 
157
  print(f"Error processing history: {e}")
158
  conversation_history = []
159
 
 
160
  return StreamingResponse(
161
  stream_response(message, conversation_history),
162
  media_type="text/event-stream"
 
164
 
165
  @app.post("/chat-full", response_model=ChatResponse)
166
  async def chat_full(request: ChatRequest):
 
167
  message = request.message
168
 
169
  try:
 
175
  print(f"Error processing history: {e}")
176
  conversation_history = []
177
 
 
178
  response_text = ""
179
  for partial in generate_response(message, conversation_history, model, tokenizer):
180
  response_text = partial