Zeldeo commited on
Commit
3a68ffb
·
verified ·
1 Parent(s): 0d026f1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -40
app.py CHANGED
@@ -1,12 +1,18 @@
1
- from fastapi import FastAPI, File, UploadFile
2
- from fastapi.responses import JSONResponse
3
  from fastapi.middleware.cors import CORSMiddleware
 
 
 
4
  import torch
5
- from PIL import Image
6
- from transformers import AutoImageProcessor, AutoBackbone
7
- import io
8
 
9
- app = FastAPI(title="FAST TextNet API")
 
 
 
 
 
 
10
 
11
  app.add_middleware(
12
  CORSMiddleware,
@@ -15,44 +21,20 @@ app.add_middleware(
15
  allow_headers=["*"],
16
  )
17
 
18
- print("Chargement du modèle...")
19
- processor = AutoImageProcessor.from_pretrained("czczup/textnet-base")
20
- model = AutoBackbone.from_pretrained("czczup/textnet-base")
21
- model.eval()
22
- print("Modèle prêt !")
23
 
24
  @app.get("/")
25
  def health():
26
- return {"status": "ok", "model": "czczup/textnet-base"}
27
 
28
- @app.post("/detect")
29
- async def detect_text(file: UploadFile = File(...)):
30
  try:
31
- image_bytes = await file.read()
32
- image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
33
-
34
- inputs = processor(image, return_tensors="pt")
35
-
36
  with torch.no_grad():
37
- outputs = model(**inputs)
38
-
39
- # Récupère tous les feature maps
40
- feature_maps = []
41
- for fm in outputs.feature_maps:
42
- feature_maps.append({
43
- "shape": list(fm.shape),
44
- "mean": float(fm.mean()),
45
- "std": float(fm.std()),
46
- "min": float(fm.min()),
47
- "max": float(fm.max()),
48
- })
49
-
50
- return JSONResponse({
51
- "success": True,
52
- "image_size": list(image.size),
53
- "feature_maps": feature_maps,
54
- "num_stages": len(feature_maps),
55
- })
56
-
57
  except Exception as e:
58
- return JSONResponse({"success": False, "error": str(e)}, status_code=500)
 
1
+ # app.py
2
+ from fastapi import FastAPI
3
  from fastapi.middleware.cors import CORSMiddleware
4
+ from fastapi.responses import JSONResponse
5
+ from pydantic import BaseModel
6
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
7
  import torch
 
 
 
8
 
9
+ # Modèle HF Flan-T5
10
+ MODEL_NAME = "google/flan-t5-base"
11
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
12
+ model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
13
+ model.eval()
14
+
15
+ app = FastAPI(title="Flan-T5 Service")
16
 
17
  app.add_middleware(
18
  CORSMiddleware,
 
21
  allow_headers=["*"],
22
  )
23
 
24
+ class ImproveRequest(BaseModel):
25
+ text: str
 
 
 
26
 
27
  @app.get("/")
28
  def health():
29
+ return {"status": "ok", "model": MODEL_NAME}
30
 
31
+ @app.post("/improve")
32
+ def improve_text(req: ImproveRequest):
33
  try:
34
+ inputs = tokenizer(req.text, return_tensors="pt")
 
 
 
 
35
  with torch.no_grad():
36
+ outputs = model.generate(**inputs, max_new_tokens=512)
37
+ improved = tokenizer.decode(outputs[0], skip_special_tokens=True)
38
+ return JSONResponse({"success": True, "improved_text": improved})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
  except Exception as e:
40
+ return JSONResponse({"success": False, "error": str(e)}, status_code=500)