Testshded committed on
Commit
f599fda
Β·
verified Β·
1 Parent(s): 47fae8e

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +119 -119
main.py CHANGED
@@ -1,120 +1,120 @@
1
- # main.py
2
-
3
- import logging
4
- from contextlib import asynccontextmanager
5
- import torch
6
- from fastapi import FastAPI, HTTPException
7
- from pydantic import BaseModel
8
- from transformers import AutoTokenizer, AutoModelForCausalLM
9
-
10
- # --- Configuration ---
11
- # The repository ID for your model on the Hugging Face Hub
12
- HF_REPO_ID = "rxmha125/Rx_Codex_V1_Tiny"
13
- # Use GPU if available (CUDA), otherwise fallback to CPU
14
- MODEL_LOAD_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
15
-
16
- # --- Logging Setup ---
17
- logging.basicConfig(level=logging.INFO)
18
- logger = logging.getLogger(__name__)
19
-
20
- # --- Global variables to hold the model and tokenizer ---
21
- model = None
22
- tokenizer = None
23
-
24
- # --- Application Lifespan (Model Loading) ---
25
- @asynccontextmanager
26
- async def lifespan(app: FastAPI):
27
- global model, tokenizer
28
- logger.info(f"API Startup: Loading model '{HF_REPO_ID}' to device '{MODEL_LOAD_DEVICE}'...")
29
-
30
- # Load the tokenizer from the Hub
31
- try:
32
- tokenizer = AutoTokenizer.from_pretrained(HF_REPO_ID)
33
- logger.info("βœ… Tokenizer loaded successfully.")
34
- except Exception as e:
35
- logger.error(f"❌ FATAL: Tokenizer loading failed: {e}")
36
- # In a real app, you might want to handle this more gracefully
37
- # For Spaces, it will just fail to start, which is okay.
38
-
39
- # Load the model from the Hub
40
- try:
41
- model = AutoModelForCausalLM.from_pretrained(HF_REPO_ID)
42
- model.to(MODEL_LOAD_DEVICE)
43
- model.eval() # Set to evaluation mode for inference
44
- logger.info("βœ… Model loaded successfully.")
45
- except Exception as e:
46
- logger.error(f"❌ FATAL: Model loading failed: {e}")
47
-
48
- yield # The API is now running
49
-
50
- # --- Code below this line runs on shutdown ---
51
- logger.info("API Shutting down.")
52
- model = None
53
- tokenizer = None
54
-
55
-
56
- # --- Initialize FastAPI ---
57
- app = FastAPI(
58
- title="Rx Codex V1-Tiny API",
59
- description="An API for generating text with the Rx_Codex_V1_Tiny model.",
60
- lifespan=lifespan
61
- )
62
-
63
- # --- Pydantic Models for API Data Validation ---
64
- class GenerationRequest(BaseModel):
65
- prompt: str
66
- max_new_tokens: int = 150
67
- temperature: float = 0.7
68
- top_k: int = 50
69
-
70
- class GenerationResponse(BaseModel):
71
- generated_text: str
72
-
73
- # --- API Endpoints ---
74
- @app.get("/")
75
- def root():
76
- """A simple endpoint to check if the API is running."""
77
- status = "loaded" if model and tokenizer else "not loaded"
78
- return {"message": "Rx Codex V1-Tiny API is running", "model_status": status}
79
-
80
- @app.post("/generate", response_model=GenerationResponse)
81
- async def generate_text(request: GenerationRequest):
82
- """The main endpoint to generate text from a prompt."""
83
- if not model or not tokenizer:
84
- raise HTTPException(status_code=503, detail="Model is not ready. Please try again later.")
85
-
86
- logger.info(f"Received generation request for prompt: '{request.prompt}'")
87
-
88
- # --- CRITICAL: Format the prompt correctly for the model ---
89
- formatted_prompt = f"### Human:\n{request.prompt}\n\n### Assistant:"
90
-
91
- # Prepare the input text for the model
92
- inputs = tokenizer(formatted_prompt, return_tensors="pt").to(MODEL_LOAD_DEVICE)
93
-
94
- # Generate text using the model
95
- with torch.no_grad():
96
- output_sequences = model.generate(
97
- input_ids=inputs["input_ids"],
98
- attention_mask=inputs["attention_mask"],
99
- max_new_tokens=request.max_new_tokens,
100
- temperature=request.temperature,
101
- top_k=request.top_k,
102
- do_sample=True,
103
- pad_token_id=tokenizer.eos_token_id
104
- )
105
-
106
- # Decode the generated tokens back into text
107
- full_text = tokenizer.decode(output_sequences[0], skip_special_tokens=True)
108
-
109
- # Remove the original formatted prompt from the output to return only the new text
110
- generated_text = full_text[len(formatted_prompt):].strip()
111
-
112
- logger.info("Generation complete.")
113
- return GenerationResponse(generated_text=generated_text)
114
-
115
-
116
- # --- Uvicorn Runner (for local testing) ---
117
- if __name__ == "__main__":
118
- import uvicorn
119
- logger.info("Starting API locally via Uvicorn...")
120
  uvicorn.run(app, host="0.0.0.0", port=8000)
 
1
+ # main.py
2
+
3
+ import logging
4
+ from contextlib import asynccontextmanager
5
+ import torch
6
+ from fastapi import FastAPI, HTTPException
7
+ from pydantic import BaseModel
8
+ from transformers import AutoTokenizer, AutoModelForCausalLM
9
+
10
# --- Configuration ---
# The repository ID for your model on the Hugging Face Hub
HF_REPO_ID = "rxmha125/Rx_Codex_V1_Tiny_test"
# Use GPU if available (CUDA), otherwise fallback to CPU
MODEL_LOAD_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# --- Logging Setup ---
# Root-level INFO logging; module logger named after this file.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# --- Global variables to hold the model and tokenizer ---
# Populated by the lifespan handler at startup; stay None if loading fails,
# in which case /generate responds 503 until a restart.
model = None
tokenizer = None
23
+
24
# --- Application Lifespan (Model Loading) ---
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the tokenizer and model at startup; drop them again at shutdown."""
    global model, tokenizer
    logger.info(f"API Startup: Loading model '{HF_REPO_ID}' to device '{MODEL_LOAD_DEVICE}'...")

    # Tokenizer first -- generation cannot run without it.
    try:
        tokenizer = AutoTokenizer.from_pretrained(HF_REPO_ID)
        logger.info("βœ… Tokenizer loaded successfully.")
    except Exception as e:
        # Deliberate best-effort: the API still starts; /generate will return 503.
        logger.error(f"❌ FATAL: Tokenizer loading failed: {e}")

    # Model next; the chained .to() moves the weights and returns the module.
    try:
        model = AutoModelForCausalLM.from_pretrained(HF_REPO_ID).to(MODEL_LOAD_DEVICE)
        model.eval()  # inference mode (disables dropout etc.)
        logger.info("βœ… Model loaded successfully.")
    except Exception as e:
        logger.error(f"❌ FATAL: Model loading failed: {e}")

    yield  # serve requests while the context stays open

    # --- Shutdown: release references so memory can be reclaimed ---
    logger.info("API Shutting down.")
    model = None
    tokenizer = None
54
+
55
+
56
# --- Initialize FastAPI ---
# The lifespan handler above owns model loading/unloading.
app = FastAPI(
    lifespan=lifespan,
    title="Rx Codex V1-Tiny API",
    description="An API for generating text with the Rx_Codex_V1_Tiny model.",
)
62
+
63
# --- Pydantic Models for API Data Validation ---
class GenerationRequest(BaseModel):
    """Request body for POST /generate."""
    # Raw user prompt; wrapped in the Human/Assistant template before inference.
    prompt: str
    # Upper bound on newly generated tokens (passed to model.generate).
    max_new_tokens: int = 150
    # Sampling temperature; do_sample=True is always set in /generate.
    temperature: float = 0.7
    # Top-k sampling cutoff.
    top_k: int = 50
69
+
70
class GenerationResponse(BaseModel):
    """Response body for POST /generate."""
    # The model's continuation with the prompt removed and whitespace stripped.
    generated_text: str
72
+
73
# --- API Endpoints ---
@app.get("/")
def root():
    """Health-check endpoint: reports whether the model is ready to serve."""
    if model and tokenizer:
        status = "loaded"
    else:
        status = "not loaded"
    return {"message": "Rx Codex V1-Tiny API is running", "model_status": status}
79
+
80
@app.post("/generate", response_model=GenerationResponse)
async def generate_text(request: GenerationRequest):
    """Generate a completion for the given prompt.

    Returns:
        GenerationResponse with only the newly generated text (prompt removed).

    Raises:
        HTTPException: 503 if the model/tokenizer have not finished loading.
    """
    if not model or not tokenizer:
        raise HTTPException(status_code=503, detail="Model is not ready. Please try again later.")

    logger.info(f"Received generation request for prompt: '{request.prompt}'")

    # --- CRITICAL: Format the prompt correctly for the model ---
    formatted_prompt = f"### Human:\n{request.prompt}\n\n### Assistant:"

    # Tokenize and move input tensors to the model's device.
    inputs = tokenizer(formatted_prompt, return_tensors="pt").to(MODEL_LOAD_DEVICE)

    # Sample a continuation; no_grad avoids building an autograd graph.
    with torch.no_grad():
        output_sequences = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_new_tokens=request.max_new_tokens,
            temperature=request.temperature,
            top_k=request.top_k,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )

    # FIX: the previous code decoded the whole sequence and sliced off
    # len(formatted_prompt) characters. That is fragile: decoding with
    # skip_special_tokens can change the prompt's surface form (whitespace
    # normalization, removed special tokens), misaligning the character cut.
    # Instead drop the prompt at the *token* level and decode only new tokens.
    prompt_token_count = inputs["input_ids"].shape[1]
    new_tokens = output_sequences[0][prompt_token_count:]
    generated_text = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

    logger.info("Generation complete.")
    return GenerationResponse(generated_text=generated_text)
114
+
115
+
116
# --- Uvicorn Runner (for local testing) ---
if __name__ == "__main__":
    # Imported lazily: uvicorn is only needed when running this file directly
    # (Spaces serves the app through its own process manager).
    import uvicorn

    logger.info("Starting API locally via Uvicorn...")
    uvicorn.run(app, host="0.0.0.0", port=8000)