ElPremOoO committed
Commit e803dc8 · verified · 1 Parent(s): faeb5f2

Update main.py

Files changed (1): main.py +26 -10
main.py CHANGED
@@ -1,22 +1,38 @@
-from fastapi import FastAPI
+from flask import Flask, request, jsonify
 from transformers import AutoModelForCausalLM, AutoTokenizer
 import torch
 
-app = FastAPI()
+app = Flask(__name__)
 
 # Load model and tokenizer
 model_name = "mistralai/Mistral-7B-v0.1"
+
+# Load in half precision (float16) to reduce memory use on free-tier Spaces
+model = AutoModelForCausalLM.from_pretrained(
+    model_name,
+    torch_dtype=torch.float16,
+    device_map="auto"
+)
+
 tokenizer = AutoTokenizer.from_pretrained(model_name)
-# model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map="auto")
 
-# If these 3 lines don't work, use the one above
-from transformers import BitsAndBytesConfig
-quant_config = BitsAndBytesConfig(load_in_8bit=True)
-model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=quant_config, device_map="auto")
+@app.route("/")
+def home():
+    return request.url
+
+@app.route("/generate", methods=["POST"])
+def generate_text():
+    data = request.get_json()
+    prompt = data.get("prompt", "")
 
+    if not prompt:
+        return jsonify({"error": "No prompt provided"}), 400
 
-@app.post("/generate")
-async def generate_text(prompt: str):
     inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
     outputs = model.generate(**inputs, max_length=200)
-    return {"response": tokenizer.decode(outputs[0], skip_special_tokens=True)}
+    response_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+    return jsonify({"response": response_text})
+
+if __name__ == "__main__":
+    app.run(host="0.0.0.0", port=7860)
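
This commit swaps the 8-bit bitsandbytes loading path for plain float16 weights. If half-precision weights still exceed the available memory, the removed quantized path can be restored. A minimal sketch, assuming the bitsandbytes package is installed in the Space (as the previous revision required):

from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Quantize weights to 8-bit at load time, as the previous revision did
quant_config = BitsAndBytesConfig(load_in_8bit=True)
model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.1",
    quantization_config=quant_config,
    device_map="auto",
)

Relative to float16, 8-bit loading roughly halves the weight memory footprint, usually at some cost in generation speed and possibly a small cost in output quality.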
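Once the server is up, the /generate route expects a POST request with a JSON body containing a "prompt" field. A quick smoke test, assuming the app is reachable at http://localhost:7860 (a placeholder; substitute the deployed Space's URL):

import requests

# Placeholder address; replace with the actual Space URL
BASE_URL = "http://localhost:7860"

resp = requests.post(
    f"{BASE_URL}/generate",
    json={"prompt": "Write a haiku about GPUs."},
    timeout=300,  # generation with a 7B model can take a while
)
resp.raise_for_status()
print(resp.json()["response"])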