rui3000 committed on
Commit
f17dc57
·
verified ·
1 Parent(s): fc269b4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +82 -5
app.py CHANGED
@@ -1,10 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
  import spaces
3
 
4
- # Import the service - this should trigger GPU function registration
5
  from minimal_service import service, generate_text_gpu
6
 
7
- # Additional GPU function at app level for extra safety
8
  @spaces.GPU
9
  def app_gpu_test():
10
  """Test GPU function at app level"""
@@ -30,8 +107,8 @@ def generate_response(user_input):
30
  return f"Error: {str(e)}"
31
 
32
  # Create Gradio interface
33
- with gr.Blocks(title="Minimal GPU Test with FastAPI") as demo:
34
- gr.Markdown("# Minimal GPU Test with FastAPI")
35
  gr.Markdown("Testing if adding FastAPI breaks GPU detection.")
36
 
37
  with gr.Row():
@@ -53,7 +130,7 @@ with gr.Blocks(title="Minimal GPU Test with FastAPI") as demo:
53
  outputs=[output_text]
54
  )
55
 
56
- # ADD FASTAPI MOUNTING - Step 2 change
57
  app = FastAPI()
58
 
59
  @app.get("/")
 
1
+ # FILE 1: minimal_service.py (same as Step 1)
2
+ import spaces
3
+ import torch
4
+ from transformers import AutoModelForCausalLM, AutoTokenizer
5
+
6
+ # Global variables
7
+ _model = None
8
+ _tokenizer = None
9
+ _model_name = "microsoft/DialoGPT-small"
10
+
11
def initialize_tokenizer():
    """Lazily load and cache the module-level tokenizer (idempotent).

    Returns:
        The cached ``AutoTokenizer`` instance for ``_model_name``.
    """
    global _tokenizer
    if _tokenizer is not None:
        # Already loaded on a previous call — nothing to do.
        return _tokenizer
    print("[MinimalService] Loading tokenizer...")
    _tokenizer = AutoTokenizer.from_pretrained(_model_name)
    if _tokenizer.pad_token is None:
        # DialoGPT ships without a pad token; reuse EOS so padding works.
        _tokenizer.pad_token = _tokenizer.eos_token
    print("[MinimalService] Tokenizer loaded successfully.")
    return _tokenizer
21
+
22
@spaces.GPU
def generate_text_gpu(prompt: str, max_tokens: int = 50):
    """GPU entry point: lazily load model/tokenizer, then generate text.

    Args:
        prompt: Input text to continue.
        max_tokens: Maximum number of new tokens to sample.

    Returns:
        The full decoded sequence (prompt + generated continuation),
        with special tokens stripped.
    """
    global _model, _tokenizer

    print("[MinimalService] GPU function called")

    # Ensure the tokenizer exists (cheap, idempotent).
    if _tokenizer is None:
        initialize_tokenizer()

    # Load the model inside the GPU context so device_map="auto" can
    # place it on the accelerator allocated by @spaces.GPU.
    if _model is None:
        print("[MinimalService] Loading model...")
        _model = AutoModelForCausalLM.from_pretrained(
            _model_name,
            torch_dtype=torch.float16,
            device_map="auto",
        )
        print("[MinimalService] Model loaded.")

    # Tokenize with an explicit attention mask: pad_token == eos_token here,
    # so generate() cannot reliably infer the mask from input_ids alone and
    # emits a warning / may mis-handle padding without it.
    encoded = _tokenizer(prompt, return_tensors="pt")
    device = next(_model.parameters()).device
    input_ids = encoded["input_ids"].to(device)
    attention_mask = encoded["attention_mask"].to(device)

    with torch.no_grad():
        outputs = _model.generate(
            input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_tokens,
            temperature=0.7,
            do_sample=True,
            pad_token_id=_tokenizer.eos_token_id,
        )

    response = _tokenizer.decode(outputs[0], skip_special_tokens=True)
    return response
59
+
60
class MinimalService:
    """Thin service wrapper: warms the tokenizer and exposes generation."""

    def __init__(self):
        # Eagerly load the tokenizer at construction time so the first
        # generation request does not pay the download/parse cost.
        print("[MinimalService] Service initialized")
        initialize_tokenizer()

    def generate(self, prompt: str):
        """Generate text for *prompt* via the GPU-decorated function."""
        return generate_text_gpu(prompt)
68
+
69
# Instantiate the shared service at import time (also warms the tokenizer).
service = MinimalService()

# Confirm at import time that the @spaces.GPU-decorated function exists,
# so Spaces GPU registration problems surface in the startup logs.
print(f"[MinimalService] GPU function available: {generate_text_gpu.__name__}")
74
+
75
+ # ====================================
76
+
77
+ # FILE 2: app.py (Step 2 - with FastAPI)
78
  import gradio as gr
79
  import spaces
80
 
81
+ # Import the service
82
  from minimal_service import service, generate_text_gpu
83
 
84
+ # Additional GPU function at app level
85
  @spaces.GPU
86
  def app_gpu_test():
87
  """Test GPU function at app level"""
 
107
  return f"Error: {str(e)}"
108
 
109
  # Create Gradio interface
110
+ with gr.Blocks(title="Step 2: FastAPI Test") as demo:
111
+ gr.Markdown("# Step 2: Testing FastAPI + GPU")
112
  gr.Markdown("Testing if adding FastAPI breaks GPU detection.")
113
 
114
  with gr.Row():
 
130
  outputs=[output_text]
131
  )
132
 
133
+ # ADD FASTAPI MOUNTING
134
  app = FastAPI()
135
 
136
  @app.get("/")