hydffgg committed
Commit 585b80d · verified · 1 parent: ece78af

Update app.py

Files changed (1): app.py (+31, -10)
app.py CHANGED
@@ -1,19 +1,40 @@
-from fastapi import FastAPI
+import gradio as gr
 from transformers import AutoTokenizer, AutoModelForCausalLM
+import torch
+import os
 
-app = FastAPI()
+MODEL_ID = "google/gemma-3-270m-it"
+HF_TOKEN = os.getenv("HF_TOKEN")
 
-MODEL_ID = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
+tokenizer = None
+model = None
 
-tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
-model = AutoModelForCausalLM.from_pretrained(MODEL_ID)
+def load_model():
+    global tokenizer, model
+    if tokenizer is None or model is None:
+        tokenizer = AutoTokenizer.from_pretrained(
+            MODEL_ID,
+            token=HF_TOKEN
+        )
+        model = AutoModelForCausalLM.from_pretrained(
+            MODEL_ID,
+            token=HF_TOKEN,
+            torch_dtype=torch.float32,
+            low_cpu_mem_usage=True
+        )
 
-@app.post("/chat")
-def chat(prompt: str):
+def chat(prompt):
+    load_model()
     inputs = tokenizer(prompt, return_tensors="pt")
     outputs = model.generate(
         **inputs,
-        max_new_tokens=200
+        max_new_tokens=128
     )
-    reply = tokenizer.decode(outputs[0], skip_special_tokens=True)
-    return {"reply": reply}
+    return tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+gr.Interface(
+    fn=chat,
+    inputs="textbox",
+    outputs="textbox",
+    title="Gemma3 270M Cloud"
+).launch(server_name="0.0.0.0")
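
For reference, a minimal sketch (not part of the commit) of calling the updated app from Python with gradio_client, assuming the app runs on Gradio's default port 7860 and exposes the default "/predict" API name that gr.Interface assigns to a single-function app:

    # Minimal sketch, assuming a local run on Gradio's default port 7860 and
    # the default "/predict" endpoint of a single-function gr.Interface.
    from gradio_client import Client

    client = Client("http://localhost:7860")  # or the public Space URL
    reply = client.predict("Hello, Gemma!", api_name="/predict")
    print(reply)

Note that this commit replaces the FastAPI POST /chat endpoint entirely: clients of the old JSON API would need to switch to the Gradio client (or Gradio's HTTP API) after this change.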