Arni1ntares committed on
Commit
a78f386
·
1 Parent(s): 345716b

Deploy FastAPI model app

Browse files
Files changed (4) hide show
  1. app.py +18 -0
  2. model_inference.py +26 -0
  3. requirements.txt +4 -0
  4. space.yaml +8 -0
app.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, Request
2
+ from pydantic import BaseModel
3
+ from model_inference import generate_code
4
+
5
# ASGI application object; the route decorators in this module register on it.
app = FastAPI()
6
+
7
class Prompt(BaseModel):
    """Request body for text generation: the prompt plus a token budget."""

    # Text handed to the model as the generation prompt.
    prompt: str
    # Cap on newly generated tokens; 128 when the client omits the field.
    max_tokens: int = 128
10
+
11
@app.get("/")
def home():
    """Return a fixed status message at the root path."""
    status = {"message": "🧠 Model is online!"}
    return status
14
+
15
@app.post("/generate")
def generate(prompt: Prompt):
    """Generate text for the posted prompt and return it under ``output``.

    The request body supplies both the prompt text and the maximum number
    of tokens, which are passed straight through to ``generate_code``.
    """
    return {"output": generate_code(prompt.prompt, prompt.max_tokens)}
model_inference.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoTokenizer, AutoModelForCausalLM
2
+ import torch
3
+
4
# Hugging Face Hub repository to load. The published checkpoint id carries a
# "-7B" suffix; without it from_pretrained() cannot resolve the repository and
# the module fails at import time.
# NOTE(review): confirm the exact id against the model card before deploying.
MODEL_NAME = "NousResearch/Hermes-2-Pro-Mistral-7B"  # ✅ Uncensored & efficient

# Loaded once at import so every call to generate_code() reuses the same
# tokenizer and weights.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    # Half precision only when CUDA is available; full precision otherwise.
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    # Automatic device placement; transformers needs `accelerate` for this.
    device_map="auto",
)
model.eval()  # switch to evaluation mode before serving generations
13
+
14
def generate_code(prompt: str, max_tokens: int = 256) -> str:
    """Sample a completion for ``prompt`` from the module-level model.

    Args:
        prompt: Text to feed to the model.
        max_tokens: Maximum number of new tokens to generate.

    Returns:
        The first generated sequence decoded to text with special tokens
        removed. NOTE(review): for a causal LM, ``generate()`` returns the
        prompt tokens followed by the continuation, so the text presumably
        starts with the prompt itself — confirm callers expect that.
    """
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        output = model.generate(
            inputs.input_ids,
            # pad_token_id is set to the EOS id below, so generate() cannot
            # reliably infer the mask from the ids; pass the tokenizer's own.
            attention_mask=inputs.get("attention_mask"),
            max_new_tokens=max_tokens,
            do_sample=True,
            temperature=0.7,
            top_p=0.95,
            repetition_penalty=1.1,
            pad_token_id=tokenizer.eos_token_id,
        )
    return tokenizer.decode(output[0], skip_special_tokens=True)
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
transformers>=4.38.0
torch
# Required by from_pretrained(..., device_map="auto") in model_inference.py.
accelerate
fastapi
uvicorn
space.yaml ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
---
# (Optional) Space configuration for Hugging Face Spaces.
# NOTE(review): Spaces normally read these keys from the YAML front matter of
# README.md — confirm that this standalone file is actually picked up.
sdk: docker
app_file: app.py
python_version: "3.9"