samdak93 commited on
Commit
9cc4ddd
·
1 Parent(s): 63bd460

chnage reqs

Browse files
Files changed (2) hide show
  1. app.py +23 -41
  2. requirements.txt +2 -2
app.py CHANGED
@@ -1,7 +1,6 @@
1
  from flask import Flask, request, jsonify
2
  from flask_cors import CORS
3
- import torch
4
- from transformers import GPT2Tokenizer, GPT2LMHeadModel
5
  import os
6
  import re
7
  import logging
@@ -16,57 +15,40 @@ CORS(app)
16
 
17
  # Model loading
18
  MODEL_PATH = os.environ.get('MODEL_PATH', 'samdak93/qrit-2')
19
- tokenizer = None
20
- model = None
21
- device = None
22
 
23
  def load_model():
24
- global tokenizer, model, device
25
  logger.info(f"Loading model from {MODEL_PATH}")
26
  try:
27
- tokenizer = GPT2Tokenizer.from_pretrained(MODEL_PATH)
28
- model = GPT2LMHeadModel.from_pretrained(MODEL_PATH)
 
29
  except Exception as e:
30
  logger.error(f"Error loading custom model: {e}")
31
  logger.info("Falling back to default GPT-2")
32
- tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
33
- model = GPT2LMHeadModel.from_pretrained("gpt2")
34
-
35
- tokenizer.pad_token = tokenizer.eos_token
36
- model.config.pad_token_id = model.config.eos_token_id
37
-
38
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
39
- model.to(device)
40
- logger.info(f"Model loaded on device: {device}")
41
 
42
  def generate_recipe(goal_prompt, max_length=300, temperature=0.7):
43
- global tokenizer, model, device
44
- if tokenizer is None or model is None:
45
  load_model()
46
 
47
  full_prompt = f"### Goal: {goal_prompt}\n### Recipe: "
48
- input_ids = tokenizer.encode(full_prompt, return_tensors="pt").to(device)
49
- attention_mask = torch.ones_like(input_ids).to(device)
50
-
51
- with torch.no_grad():
52
- output = model.generate(
53
- input_ids=input_ids,
54
- attention_mask=attention_mask,
55
- max_length=max_length,
56
- temperature=temperature,
57
- do_sample=True,
58
- top_k=50,
59
- top_p=0.95,
60
- no_repeat_ngram_size=2,
61
- pad_token_id=tokenizer.eos_token_id
62
- )
63
-
64
- generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
65
- recipe_part = generated_text[len(full_prompt):]
66
- if "### Goal:" in recipe_part:
67
- recipe_part = recipe_part.split("### Goal:")[0].strip()
68
-
69
- return full_prompt + recipe_part
70
 
71
  def parse_recipe(recipe_text):
72
  recipe_data = {
 
1
  from flask import Flask, request, jsonify
2
  from flask_cors import CORS
3
+ from transformers import pipeline
 
4
  import os
5
  import re
6
  import logging
 
15
 
16
  # Model loading
17
  MODEL_PATH = os.environ.get('MODEL_PATH', 'samdak93/qrit-2')
18
+ generator = None
 
 
19
 
20
  def load_model():
21
+ global generator
22
  logger.info(f"Loading model from {MODEL_PATH}")
23
  try:
24
+ # Use Hugging Face pipeline for text generation
25
+ generator = pipeline("text-generation", model=MODEL_PATH)
26
+ logger.info("Model loaded successfully")
27
  except Exception as e:
28
  logger.error(f"Error loading custom model: {e}")
29
  logger.info("Falling back to default GPT-2")
30
+ # Fallback to default GPT-2 if model loading fails
31
+ generator = pipeline("text-generation", model="gpt2")
32
+ logger.info("Default GPT-2 model loaded")
 
 
 
 
 
 
33
 
34
  def generate_recipe(goal_prompt, max_length=300, temperature=0.7):
35
+ global generator
36
+ if generator is None:
37
  load_model()
38
 
39
  full_prompt = f"### Goal: {goal_prompt}\n### Recipe: "
40
+
41
+ # Generate recipe using the Hugging Face pipeline
42
+ try:
43
+ generated_text = generator(full_prompt, max_length=max_length, temperature=temperature, num_return_sequences=1)[0]['generated_text']
44
+ recipe_part = generated_text[len(full_prompt):]
45
+ if "### Goal:" in recipe_part:
46
+ recipe_part = recipe_part.split("### Goal:")[0].strip()
47
+
48
+ return full_prompt + recipe_part
49
+ except Exception as e:
50
+ logger.error(f"Error generating recipe: {str(e)}")
51
+ return str(e)
 
 
 
 
 
 
 
 
 
 
52
 
53
  def parse_recipe(recipe_text):
54
  recipe_data = {
requirements.txt CHANGED
@@ -1,6 +1,6 @@
1
- flask==2.2.2 # Adjusted Flask version to a stable, widely used one
 
2
  flask_cors==5.2.0 # Updated flask_cors to the latest version
3
  transformers==4.30.0 # Updated transformers version (compatible with pipeline)
4
  torch==2.1.0 # Updated torch to a stable version (optional based on need for GPU support)
5
  numpy<2 # No change to numpy
6
- pyngrok==5.1.0 # Optional if you use pyngrok for ngrok tunneling
 
1
+ flask==3.1.0
2
+ flask_cors==5.0.1
3
  flask_cors==5.2.0 # Updated flask_cors to the latest version
4
  transformers==4.30.0 # Updated transformers version (compatible with pipeline)
5
  torch==2.1.0 # Updated torch to a stable version (optional based on need for GPU support)
6
  numpy<2 # No change to numpy