yukee1992 committed on
Commit
04a5c1f
·
verified ·
1 Parent(s): c56dee0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -44
app.py CHANGED
@@ -2,8 +2,8 @@ import os
2
  import torch
3
  from transformers import AutoTokenizer, AutoModelForCausalLM
4
  import gradio as gr
5
- from fastapi import FastAPI
6
- from typing import Dict, Any
7
 
8
  # Configuration
9
  MODEL_ID = "google/gemma-1.1-2b-it"
@@ -11,7 +11,6 @@ HF_TOKEN = os.getenv("HF_TOKEN", "")
11
  MAX_TOKENS = 80
12
  MAX_INPUT_LENGTH = 100
13
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
14
- PORT = int(os.getenv("PORT", 7860)) # Default port with override
15
 
16
  class ScriptGenerator:
17
  def __init__(self):
@@ -20,31 +19,29 @@ class ScriptGenerator:
20
  self.loaded = False
21
 
22
  def load_model(self):
23
- """Safe model loading with progress tracking"""
24
  if self.loaded:
25
  return
26
 
27
  print("🔄 Loading model...")
28
  try:
29
  self.tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN)
30
-
31
  self.model = AutoModelForCausalLM.from_pretrained(
32
  MODEL_ID,
33
  torch_dtype=torch.float32 if DEVICE == "cpu" else torch.float16,
34
  device_map="auto" if DEVICE == "cuda" else None,
35
  token=HF_TOKEN
36
  ).to(DEVICE)
37
-
38
  self.loaded = True
39
- print("✅ Model loaded successfully!")
40
  except Exception as e:
41
- print(f"❌ Model loading failed: {str(e)}")
42
  raise
43
 
44
  generator = ScriptGenerator()
45
 
46
  def predict(topic: str) -> str:
47
- """Generate script with proper error handling"""
48
  try:
49
  if not topic or len(topic) > MAX_INPUT_LENGTH:
50
  return f"Topic must be 1-{MAX_INPUT_LENGTH} characters"
@@ -62,52 +59,44 @@ def predict(topic: str) -> str:
62
  temperature=0.7,
63
  do_sample=True
64
  )
65
-
66
  return generator.tokenizer.decode(outputs[0], skip_special_tokens=True)
67
 
68
- except torch.cuda.OutOfMemoryError:
69
- return "Error: GPU out of memory - try a shorter input"
70
  except Exception as e:
71
  return f"Error: {str(e)}"
72
 
73
- # Create Gradio interface
74
- interface = gr.Interface(
75
- fn=predict,
76
- inputs=gr.Textbox(label="Topic", placeholder="Enter your topic..."),
77
- outputs=gr.Textbox(label="Generated Script", lines=5),
78
- title="Gemma Script Generator",
79
- allow_flagging="never"
80
- )
81
-
82
- # Create FastAPI app
83
  app = FastAPI()
84
 
85
  # Add API endpoint
86
  @app.post("/api/predict")
87
- async def api_predict(topic: str):
88
- return {
89
- "success": True,
90
- "result": predict(topic),
91
- "error": None
92
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
 
94
- # Mount Gradio interface
95
  app = gr.mount_gradio_app(app, interface, path="/")
96
 
97
- # Launch configuration
98
  if __name__ == "__main__":
99
  generator.load_model()
100
-
101
- # Disable Gradio's internal port scanning
102
- os.environ["GRADIO_SERVER_PORT"] = str(PORT)
103
- os.environ["GRADIO_SERVER_NAME"] = "0.0.0.0"
104
-
105
- interface.launch(
106
- server_name="0.0.0.0",
107
- server_port=PORT,
108
- share=False,
109
- prevent_thread_lock=True, # Required for Hugging Face Spaces
110
- show_error=True,
111
- debug=False, # Disable debug mode to prevent port scanning
112
- ssl_verify=False # Disable SSL verification for internal calls
113
- )
 
2
  import torch
3
  from transformers import AutoTokenizer, AutoModelForCausalLM
4
  import gradio as gr
5
+ from fastapi import FastAPI, Request
6
+ from fastapi.responses import JSONResponse
7
 
8
  # Configuration
9
  MODEL_ID = "google/gemma-1.1-2b-it"
 
11
  MAX_TOKENS = 80
12
  MAX_INPUT_LENGTH = 100
13
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 
14
 
15
  class ScriptGenerator:
16
  def __init__(self):
 
19
  self.loaded = False
20
 
21
  def load_model(self):
22
+ """Safe model loading"""
23
  if self.loaded:
24
  return
25
 
26
  print("🔄 Loading model...")
27
  try:
28
  self.tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN)
 
29
  self.model = AutoModelForCausalLM.from_pretrained(
30
  MODEL_ID,
31
  torch_dtype=torch.float32 if DEVICE == "cpu" else torch.float16,
32
  device_map="auto" if DEVICE == "cuda" else None,
33
  token=HF_TOKEN
34
  ).to(DEVICE)
 
35
  self.loaded = True
36
+ print("✅ Model loaded!")
37
  except Exception as e:
38
+ print(f"❌ Loading failed: {str(e)}")
39
  raise
40
 
41
  generator = ScriptGenerator()
42
 
43
  def predict(topic: str) -> str:
44
+ """Generate script with error handling"""
45
  try:
46
  if not topic or len(topic) > MAX_INPUT_LENGTH:
47
  return f"Topic must be 1-{MAX_INPUT_LENGTH} characters"
 
59
  temperature=0.7,
60
  do_sample=True
61
  )
 
62
  return generator.tokenizer.decode(outputs[0], skip_special_tokens=True)
63
 
 
 
64
  except Exception as e:
65
  return f"Error: {str(e)}"
66
 
67
+ # Create FastAPI app first
 
 
 
 
 
 
 
 
 
68
  app = FastAPI()
69
 
70
  # Add API endpoint
71
@app.post("/api/predict")
async def api_predict(request: Request):
    """POST /api/predict — expects a JSON body like {"topic": str}.

    Returns a uniform envelope: {"success", "result", "error"}. Any failure
    while reading/parsing the body yields a 500 with the error message.
    """
    try:
        body = await request.json()
        result = predict(body.get("topic", ""))
        envelope = {"success": True, "result": result, "error": None}
        return JSONResponse(envelope)
    except Exception as exc:
        failure = {"success": False, "result": None, "error": str(exc)}
        return JSONResponse(failure, status_code=500)
87
+
88
# Gradio UI that drives the same predict() used by the JSON endpoint
interface = gr.Interface(
    fn=predict,
    inputs=gr.Textbox(label="Topic"),
    outputs=gr.Textbox(label="Script", lines=5),
    title="Gemma Script Generator",
)

# Serve the Gradio UI at the root path of the FastAPI app
app = gr.mount_gradio_app(app, interface, path="/")

if __name__ == "__main__":
    generator.load_model()
    import uvicorn

    # PORT env var overrides the Spaces default of 7860
    serve_port = int(os.getenv("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=serve_port)