natalieparker committed
Commit 3512ce0 · verified · 1 Parent(s): d393ff4

Update app.py

Files changed (1): app.py +67 -19
app.py CHANGED
@@ -1,41 +1,89 @@
 from fastapi import FastAPI
 from transformers import AutoTokenizer, AutoModelForCausalLM
 import torch
+import os
+
+# ==========================================
+# 1. SETUP & MODEL LOADING (executed once on startup)
+# ==========================================

 app = FastAPI()

-MODEL = "natalieparker/LumaAI-160M-v3"
+# CRITICAL: the weights were merged into the root folder, so the model
+# path inside the Hugging Face container is always '.'
+MODEL_DIR = "."
+
+# The model is loaded into CPU memory; no device transfer happens unless a GPU is available.
+# CPU inference is forced here, as intended for this Space.
+DEVICE = "cpu"
+
+try:
+    print("🔄 Loading tokenizer...")
+    # Use MODEL_DIR ('.') because the files sit in the root of the deployed app
+    tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR, use_fast=True)
+
+    print("🔄 Loading model on CPU...")
+    model = AutoModelForCausalLM.from_pretrained(
+        MODEL_DIR,
+        torch_dtype=torch.float32,
+        low_cpu_mem_usage=True  # reduces peak RAM usage while loading
+    )
+    model.to(DEVICE)
+    print("✅ Model loaded successfully on CPU.")
+
+except Exception as e:
+    # If loading fails, log the error and mark the model as unavailable
+    print(f"FATAL MODEL LOAD ERROR: {e}")
+    model = None
+    tokenizer = None

-print("🔄 Loading tokenizer...")
-tokenizer = AutoTokenizer.from_pretrained(MODEL, use_fast=True)

-print("🔄 Loading model on CPU...")
-model = AutoModelForCausalLM.from_pretrained(
-    MODEL,
-    torch_dtype=torch.float32,  # CPU only
-    low_cpu_mem_usage=True
-)
-model.to("cpu")
+# ==========================================
+# 2. ENDPOINTS
+# ==========================================

 @app.get("/")
 def root():
-    return {"status": "LumaAI API is live on CPU"}
+    """Health check endpoint."""
+    return {"status": "LumaAI API is live", "model_loaded": model is not None}

 @app.post("/generate")
 def generate(prompt: str):
-    inputs = tokenizer(prompt, return_tensors="pt")
+    """Generate a text response from the model."""
+    if model is None:
+        return {"error": "Model failed to load during startup."}
+
+    # --- 1. PROMPT PREPARATION ---
+    # Use the dialogue format the model was trained on
+    formatted_prompt = f"User: {prompt}\nCharacter:"
+
+    # --- 2. INFERENCE ---
+    inputs = tokenizer(formatted_prompt, return_tensors="pt").to(DEVICE)
+
     with torch.no_grad():
         output = model.generate(
             **inputs,
             max_new_tokens=150,
-            temperature=0.9,
+            temperature=0.75,  # balanced creativity
             top_p=0.9,
-            repetition_penalty=1.05,
-            do_sample=True
+            repetition_penalty=1.2,
+            do_sample=True,
+            pad_token_id=tokenizer.eos_token_id
         )
+
+    # --- 3. CLEANING ---
     text = tokenizer.decode(output[0], skip_special_tokens=True)
-    return {"response": text}
+
+    # Keep only the text after the "Character:" tag
+    if "Character:" in text:
+        response = text.split("Character:")[-1]
+    else:
+        response = text
+
+    # Cut off any follow-on user turn and trim whitespace
+    response = response.split("User:")[0].strip()
+
+    # Final punctuation polish: re-attach space-separated punctuation
+    response = response.replace(" .", ".").replace(" ,", ",").replace(" ?", "?").replace(" !", "!")

-if __name__ == "__main__":
-    import uvicorn
-    uvicorn.run(app, host="0.0.0.0", port=7860)
+    return {"response": response}