moslem committed on
Commit
bec1bb2
·
verified ·
1 Parent(s): e218766

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +85 -185
app.py CHANGED
@@ -1,185 +1,85 @@
1
- """
2
- app.py
3
- FastAPI application for serving either:
4
- - a text-generation LLM, or
5
- - a visual-language model (VLM) for image captioning.
6
-
7
- Environment variables:
8
- MODEL_ID — Hugging Face model repo id (default: "gpt2")
9
- MODEL_TYPE — "llm" or "vlm" (default: "llm")
10
- TRUST_REMOTE_CODE — "true"/"false" for custom model code
11
- """
12
-
13
- import os
14
- import io
15
- import asyncio
16
- import logging
17
- from typing import Optional
18
-
19
- import torch
20
- from PIL import Image
21
- from fastapi import FastAPI, UploadFile, File, HTTPException
22
- from pydantic import BaseModel
23
- from transformers import pipeline
24
- from transformers.pipelines import Pipeline
25
-
26
-
27
- # -------------------------------------------------------------------------
28
- # Configuration
29
- # -------------------------------------------------------------------------
30
- MODEL_ID = os.environ.get("MODEL_ID", "gpt2")
31
- MODEL_TYPE = os.environ.get("MODEL_TYPE", "llm").lower() # "llm" or "vlm"
32
- TRUST_REMOTE_CODE = os.environ.get("TRUST_REMOTE_CODE", "false").lower() in (
33
- "1",
34
- "true",
35
- "yes",
36
- )
37
-
38
- # Logging setup
39
- logging.basicConfig(level=logging.INFO)
40
- logger = logging.getLogger("hf-fastapi")
41
-
42
- # FastAPI instance
43
- app = FastAPI(title="Hugging Face FastAPI LLM/VLM Demo")
44
-
45
- # Lazy-loaded model pipeline
46
- pipe: Optional[Pipeline] = None
47
- load_error: Optional[str] = None
48
-
49
-
50
- # -------------------------------------------------------------------------
51
- # Helper functions
52
- # -------------------------------------------------------------------------
53
- def get_device() -> int:
54
- """Return CUDA device index if available, else CPU (-1)."""
55
- return 0 if torch.cuda.is_available() else -1
56
-
57
-
58
- async def run_blocking(func, *args, **kwargs):
59
- """Run blocking pipeline calls in a thread pool."""
60
- loop = asyncio.get_event_loop()
61
- return await loop.run_in_executor(None, lambda: func(*args, **kwargs))
62
-
63
-
64
- # -------------------------------------------------------------------------
65
- # Model loading
66
- # -------------------------------------------------------------------------
67
- @app.on_event("startup")
68
- def load_model():
69
- """Load model pipeline on startup."""
70
- global pipe, load_error
71
-
72
- device = get_device()
73
- try:
74
- logger.info(f"Loading model '{MODEL_ID}' ({MODEL_TYPE}) on device {device}...")
75
-
76
- if MODEL_TYPE == "vlm":
77
- pipe = pipeline(
78
- "image-to-text",
79
- model=MODEL_ID,
80
- device=device,
81
- trust_remote_code=TRUST_REMOTE_CODE,
82
- )
83
- else:
84
- pipe = pipeline(
85
- "text-generation",
86
- model=MODEL_ID,
87
- device=device,
88
- trust_remote_code=TRUST_REMOTE_CODE,
89
- )
90
-
91
- logger.info("✅ Model loaded successfully.")
92
- except Exception as e:
93
- load_error = str(e)
94
- logger.exception("❌ Failed to load model: %s", e)
95
-
96
-
97
- # -------------------------------------------------------------------------
98
- # API models
99
- # -------------------------------------------------------------------------
100
- class TextRequest(BaseModel):
101
- prompt: str
102
- max_new_tokens: Optional[int] = 64
103
- do_sample: Optional[bool] = False
104
- temperature: Optional[float] = 0.7
105
-
106
-
107
- # -------------------------------------------------------------------------
108
- # Routes
109
- # -------------------------------------------------------------------------
110
- @app.get("/", tags=["health"])
111
- def root():
112
- """Root endpoint showing model info."""
113
- return {
114
- "status": "ok",
115
- "model_id": MODEL_ID,
116
- "model_type": MODEL_TYPE,
117
- "device": "cuda" if torch.cuda.is_available() else "cpu",
118
- "model_loaded": pipe is not None,
119
- "load_error": load_error,
120
- }
121
-
122
-
123
- @app.get("/health", tags=["health"])
124
- def health():
125
- """Simple health check."""
126
- if load_error:
127
- return {"status": "error", "detail": load_error}
128
- return {"status": "healthy"}
129
-
130
-
131
- @app.post("/generate-text", tags=["text"])
132
- async def generate_text(req: TextRequest):
133
- """Generate text using an LLM."""
134
- if MODEL_TYPE == "vlm":
135
- raise HTTPException(status_code=400, detail="Model is VLM. Use /image-caption.")
136
- if pipe is None:
137
- raise HTTPException(status_code=503, detail=f"Model not loaded: {load_error or 'loading...'}")
138
-
139
- try:
140
- outputs = await run_blocking(
141
- pipe,
142
- req.prompt,
143
- max_new_tokens=req.max_new_tokens,
144
- do_sample=req.do_sample,
145
- temperature=req.temperature,
146
- return_full_text=False,
147
- )
148
- except Exception as e:
149
- logger.exception("Generation failed: %s", e)
150
- raise HTTPException(status_code=500, detail=f"Generation failed: {e}")
151
-
152
- if isinstance(outputs, list) and outputs:
153
- text_out = outputs[0].get("generated_text") or outputs[0].get("text") or str(outputs[0])
154
- else:
155
- text_out = str(outputs)
156
-
157
- return {"generated_text": text_out}
158
-
159
-
160
- @app.post("/image-caption", tags=["image"])
161
- async def image_caption(file: UploadFile = File(...)):
162
- """Caption an uploaded image using a VLM."""
163
- if MODEL_TYPE != "vlm":
164
- raise HTTPException(status_code=400, detail="Model is LLM. Set MODEL_TYPE=vlm.")
165
- if pipe is None:
166
- raise HTTPException(status_code=503, detail=f"Model not loaded: {load_error or 'loading...'}")
167
-
168
- try:
169
- contents = await file.read()
170
- img = Image.open(io.BytesIO(contents)).convert("RGB")
171
- except Exception as e:
172
- raise HTTPException(status_code=400, detail=f"Invalid image file: {e}")
173
-
174
- try:
175
- outputs = await run_blocking(pipe, img)
176
- except Exception as e:
177
- logger.exception("Captioning failed: %s", e)
178
- raise HTTPException(status_code=500, detail=f"Captioning failed: {e}")
179
-
180
- if isinstance(outputs, list) and outputs:
181
- caption = outputs[0].get("generated_text") or outputs[0].get("caption") or str(outputs[0])
182
- else:
183
- caption = str(outputs)
184
-
185
- return {"caption": caption}
 
1
# app.py
import asyncio
import io
import logging
import os

import torch
from PIL import Image
from fastapi import FastAPI, UploadFile, File, Request, HTTPException
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from transformers import pipeline
13
# Disable the auto-generated API docs (Swagger / ReDoc / OpenAPI schema) for end users.
app = FastAPI(docs_url=None, redoc_url=None, openapi_url=None)

# Static assets and Jinja2 templates backing the HTML front end.
app.mount("/static", StaticFiles(directory="static"), name="static")
templates = Jinja2Templates(directory="templates")

# Model configuration from the environment.
MODEL_ID = os.environ.get("MODEL_ID", "gpt2")  # Hugging Face model repo id
MODEL_TYPE = os.environ.get("MODEL_TYPE", "llm").lower()  # "llm" or "vlm"
# Opt-in flag for models that ship custom code (security-sensitive; default off).
TRUST_REMOTE_CODE = os.environ.get("TRUST_REMOTE_CODE", "false").lower() in ("1","true","yes")

# Pipeline is lazily populated at startup; load_error records a failed load
# so the routes can report 503 instead of crashing.
pipe = None
load_error = None
27
def get_device():
    """Pick the pipeline device: CUDA device index 0 when available, else CPU (-1)."""
    if torch.cuda.is_available():
        return 0
    return -1
29
+
30
@app.on_event("startup")
def load_model():
    """Load the Hugging Face pipeline once at application startup.

    Selects the task from MODEL_TYPE: "vlm" -> image-to-text, anything
    else -> text-generation. On failure the exception is logged and its
    message stored in `load_error` instead of crashing the server, so the
    routes can return 503 with a useful detail.
    """
    global pipe, load_error
    # Both branches differ only in the task name, so compute it once.
    task = "image-to-text" if MODEL_TYPE == "vlm" else "text-generation"
    try:
        pipe = pipeline(
            task,
            model=MODEL_ID,
            device=get_device(),
            trust_remote_code=TRUST_REMOTE_CODE,
        )
    except Exception as e:
        # Record and log the failure (previously it was swallowed silently).
        load_error = str(e)
        logging.getLogger(__name__).exception("Failed to load model '%s': %s", MODEL_ID, e)
40
+
41
async def run_blocking(func, *args, **kwargs):
    """Run a blocking callable in the default thread pool so the event loop stays free.

    Fix: use `asyncio.get_running_loop()` — `get_event_loop()` inside a
    coroutine is deprecated since Python 3.10.
    """
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, lambda: func(*args, **kwargs))
44
+
45
# Main user-facing page — renders a simple form.
@app.get("/", response_class=HTMLResponse)
def index(request: Request):
    """Render the index template with current model status."""
    context = {
        "request": request,
        "model_type": MODEL_TYPE,
        "model_id": MODEL_ID,
        "model_loaded": pipe is not None,
    }
    return templates.TemplateResponse("index.html", context)
49
+
50
# Text generation endpoint (async; body is a raw JSON object).
@app.post("/generate-text")
async def generate_text(payload: dict):
    """Generate text with the loaded LLM.

    Expected JSON payload keys (all optional except prompt):
      prompt, max_new_tokens (default 64), do_sample (default False),
      temperature (default 0.7).

    Raises:
      400 if the configured model is a VLM, 503 while the model is not
      loaded, 500 if generation itself fails.
    """
    if MODEL_TYPE == "vlm":
        raise HTTPException(status_code=400, detail="Model is VLM. Use /image-caption.")
    if pipe is None:
        raise HTTPException(status_code=503, detail=f"Model not loaded: {load_error or 'loading'}")

    prompt = payload.get("prompt", "")
    try:
        outputs = await run_blocking(
            pipe,
            prompt,
            max_new_tokens=payload.get("max_new_tokens", 64),
            do_sample=payload.get("do_sample", False),
            temperature=payload.get("temperature", 0.7),
            return_full_text=False,
        )
    except Exception as e:
        # Fix: surface model failures as a clean 500 instead of an unhandled traceback.
        raise HTTPException(status_code=500, detail=f"Generation failed: {e}") from e

    # Pipelines return a list of dicts; fall back to str() for unexpected shapes.
    if isinstance(outputs, list) and outputs:
        text_out = outputs[0].get("generated_text") or outputs[0].get("text") or str(outputs[0])
    else:
        text_out = str(outputs)
    return {"generated_text": text_out}
67
+
68
# Image captioning endpoint (VLM only).
@app.post("/image-caption")
async def image_caption(file: UploadFile = File(...)):
    """Caption an uploaded image with the loaded VLM.

    Raises:
      400 if the model is an LLM or the upload is not a decodable image,
      503 while the model is not loaded, 500 if captioning itself fails.
    """
    if MODEL_TYPE != "vlm":
        raise HTTPException(status_code=400, detail="Model is LLM. Set MODEL_TYPE=vlm.")
    if pipe is None:
        raise HTTPException(status_code=503, detail=f"Model not loaded: {load_error or 'loading'}")

    contents = await file.read()
    try:
        img = Image.open(io.BytesIO(contents)).convert("RGB")
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Invalid image: {e}")

    try:
        outputs = await run_blocking(pipe, img)
    except Exception as e:
        # Fix: surface model failures as a clean 500 instead of an unhandled traceback.
        raise HTTPException(status_code=500, detail=f"Captioning failed: {e}") from e

    # Pipelines return a list of dicts; fall back to str() for unexpected shapes.
    if isinstance(outputs, list) and outputs:
        caption = outputs[0].get("generated_text") or outputs[0].get("caption") or str(outputs[0])
    else:
        caption = str(outputs)
    return {"caption": caption}