aledraa commited on
Commit
a1b4668
·
verified ·
1 Parent(s): 1ac97be

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +48 -11
app.py CHANGED
@@ -4,20 +4,47 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
4
  import torch
5
  import json
6
  import random
 
7
  from typing import List, Optional
8
 
9
  app = FastAPI(title="Qwen Data Generator API")
10
 
11
- # Load model and tokenizer
 
 
12
  model_name = "Qwen/Qwen2.5-3B-Instruct"
13
- print("Loading model...")
14
- model = AutoModelForCausalLM.from_pretrained(
15
- model_name,
16
- torch_dtype="auto",
17
- device_map="auto"
18
- )
19
- tokenizer = AutoTokenizer.from_pretrained(model_name)
20
- print("Model loaded successfully!")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
  class GenerationRequest(BaseModel):
23
  llm_commands: List[str]
@@ -47,6 +74,11 @@ JSON Array:"""
47
 
48
  @app.post("/generate", response_model=GenerationResponse)
49
  async def generate_data(request: GenerationRequest):
 
 
 
 
 
50
  try:
51
  # Set seed for reproducibility if provided
52
  if request.seed:
@@ -70,7 +102,11 @@ async def generate_data(request: GenerationRequest):
70
  )
71
 
72
  # Tokenize and generate
73
- model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
 
 
 
 
74
 
75
  with torch.no_grad():
76
  generated_ids = model.generate(
@@ -78,7 +114,8 @@ async def generate_data(request: GenerationRequest):
78
  max_new_tokens=2048,
79
  temperature=0.8,
80
  do_sample=True,
81
- pad_token_id=tokenizer.eos_token_id
 
82
  )
83
 
84
  # Decode response
 
4
  import torch
5
  import json
6
  import random
7
+ import os
8
  from typing import List, Optional
9
 
10
  app = FastAPI(title="Qwen Data Generator API")
11
 
12
+ # Global variables for model and tokenizer
13
+ model = None
14
+ tokenizer = None
15
  model_name = "Qwen/Qwen2.5-3B-Instruct"
16
+
17
+ def load_model():
18
+ """Load model and tokenizer with proper error handling"""
19
+ global model, tokenizer
20
+
21
+ try:
22
+ print("Loading model...")
23
+ print(f"Cache directory: {os.environ.get('HF_HOME', 'Not set')}")
24
+
25
+ # Load tokenizer first (smaller download)
26
+ tokenizer = AutoTokenizer.from_pretrained(
27
+ model_name,
28
+ trust_remote_code=True
29
+ )
30
+ print("Tokenizer loaded successfully!")
31
+
32
+ # Load model with specific configurations for better compatibility
33
+ model = AutoModelForCausalLM.from_pretrained(
34
+ model_name,
35
+ torch_dtype=torch.float16, # Use float16 to save memory
36
+ device_map="auto",
37
+ trust_remote_code=True,
38
+ low_cpu_mem_usage=True
39
+ )
40
+ print("Model loaded successfully!")
41
+
42
+ except Exception as e:
43
+ print(f"Error loading model: {str(e)}")
44
+ raise e
45
+
46
+ # Load model on startup
47
+ load_model()
48
 
49
  class GenerationRequest(BaseModel):
50
  llm_commands: List[str]
 
74
 
75
  @app.post("/generate", response_model=GenerationResponse)
76
  async def generate_data(request: GenerationRequest):
77
+ global model, tokenizer
78
+
79
+ if model is None or tokenizer is None:
80
+ raise HTTPException(status_code=503, detail="Model not loaded")
81
+
82
  try:
83
  # Set seed for reproducibility if provided
84
  if request.seed:
 
102
  )
103
 
104
  # Tokenize and generate
105
+ model_inputs = tokenizer([text], return_tensors="pt")
106
+
107
+ # Move inputs to same device as model
108
+ if torch.cuda.is_available():
109
+ model_inputs = model_inputs.to('cuda')
110
 
111
  with torch.no_grad():
112
  generated_ids = model.generate(
 
114
  max_new_tokens=2048,
115
  temperature=0.8,
116
  do_sample=True,
117
+ pad_token_id=tokenizer.eos_token_id,
118
+ eos_token_id=tokenizer.eos_token_id
119
  )
120
 
121
  # Decode response